You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2020/08/26 02:15:40 UTC
[arrow] branch master updated: ARROW-9849: [Rust] [DataFusion]
Simplified argument types of ScalarFunctions.
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new d02e166 ARROW-9849: [Rust] [DataFusion] Simplified argument types of ScalarFunctions.
d02e166 is described below
commit d02e16646cb3ee2a568dfeeaa68a0f75986592c8
Author: Jorge C. Leitao <jo...@gmail.com>
AuthorDate: Tue Aug 25 20:15:16 2020 -0600
ARROW-9849: [Rust] [DataFusion] Simplified argument types of ScalarFunctions.
Deprecates "Field" as argument to the UDF declaration, since we are only using its type.
This is a spin-off of #8032 with a much smaller scope, as the other one is getting too large to handle.
Closes #8045 from jorgecarleitao/clean_args
Authored-by: Jorge C. Leitao <jo...@gmail.com>
Signed-off-by: Andy Grove <an...@gmail.com>
---
rust/datafusion/src/execution/context.rs | 5 +----
.../datafusion/src/execution/physical_plan/math_expressions.rs | 6 +++---
rust/datafusion/src/execution/physical_plan/mod.rs | 4 ++--
rust/datafusion/src/execution/physical_plan/udf.rs | 10 +++++-----
rust/datafusion/src/optimizer/type_coercion.rs | 3 +--
rust/datafusion/src/sql/planner.rs | 8 +++-----
rust/datafusion/tests/sql.rs | 2 +-
7 files changed, 16 insertions(+), 22 deletions(-)
diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index 7b0c8e9..e3d5820 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -996,10 +996,7 @@ mod tests {
let my_add = ScalarFunction::new(
"my_add",
- vec![
- Field::new("a", DataType::Int32, true),
- Field::new("b", DataType::Int32, true),
- ],
+ vec![DataType::Int32, DataType::Int32],
DataType::Int32,
myfunc,
);
diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs
index 97098d6..ea40ac5 100644
--- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs
+++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs
@@ -21,7 +21,7 @@ use crate::error::ExecutionError;
use crate::execution::physical_plan::udf::ScalarFunction;
use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder};
-use arrow::datatypes::{DataType, Field};
+use arrow::datatypes::DataType;
use std::sync::Arc;
@@ -29,7 +29,7 @@ macro_rules! math_unary_function {
($NAME:expr, $FUNC:ident) => {
ScalarFunction::new(
$NAME,
- vec![Field::new("n", DataType::Float64, true)],
+ vec![DataType::Float64],
DataType::Float64,
Arc::new(|args: &[ArrayRef]| {
let n = &args[0].as_any().downcast_ref::<Float64Array>();
@@ -86,7 +86,7 @@ mod tests {
execution::context::ExecutionContext,
logicalplan::{col, sqrt, LogicalPlanBuilder},
};
- use arrow::datatypes::Schema;
+ use arrow::datatypes::{Field, Schema};
#[test]
fn cast_i8_input() -> Result<()> {
diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs
index f79907d..780cdba 100644
--- a/rust/datafusion/src/execution/physical_plan/mod.rs
+++ b/rust/datafusion/src/execution/physical_plan/mod.rs
@@ -26,7 +26,7 @@ use crate::error::Result;
use crate::execution::context::ExecutionContextState;
use crate::logicalplan::{LogicalPlan, ScalarValue};
use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::datatypes::{DataType, Schema, SchemaRef};
use arrow::{
compute::kernels::length::length,
record_batch::{RecordBatch, RecordBatchReader},
@@ -138,7 +138,7 @@ pub trait Accumulator: Debug {
pub fn scalar_functions() -> Vec<ScalarFunction> {
let mut udfs = vec![ScalarFunction::new(
"length",
- vec![Field::new("n", DataType::Utf8, true)],
+ vec![DataType::Utf8],
DataType::UInt32,
Arc::new(|args: &[ArrayRef]| Ok(Arc::new(length(args[0].as_ref())?))),
)];
diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs
index fb777f9..ca59087 100644
--- a/rust/datafusion/src/execution/physical_plan/udf.rs
+++ b/rust/datafusion/src/execution/physical_plan/udf.rs
@@ -20,7 +20,7 @@
use std::fmt;
use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Field, Schema};
+use arrow::datatypes::{DataType, Schema};
use crate::error::Result;
use crate::execution::physical_plan::PhysicalExpr;
@@ -38,7 +38,7 @@ pub struct ScalarFunction {
/// Function name
pub name: String,
/// Function argument meta-data
- pub args: Vec<Field>,
+ pub arg_types: Vec<DataType>,
/// Return type
pub return_type: DataType,
/// UDF implementation
@@ -61,7 +61,7 @@ impl Debug for ScalarFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("ScalarFunction")
.field("name", &self.name)
- .field("args", &self.args)
+ .field("arg_types", &self.arg_types)
.field("return_type", &self.return_type)
.field("fun", &"<FUNC>")
.finish()
@@ -72,13 +72,13 @@ impl ScalarFunction {
/// Create a new ScalarFunction
pub fn new(
name: &str,
- args: Vec<Field>,
+ arg_types: Vec<DataType>,
return_type: DataType,
fun: ScalarUdf,
) -> Self {
Self {
name: name.to_owned(),
- args,
+ arg_types,
return_type,
fun,
}
diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs
index 22d49b5..65940ba 100644
--- a/rust/datafusion/src/optimizer/type_coercion.rs
+++ b/rust/datafusion/src/optimizer/type_coercion.rs
@@ -69,9 +69,8 @@ where
match self.scalar_functions.lookup(name) {
Some(func_meta) => {
for i in 0..expressions.len() {
- let field = &func_meta.args[i];
let actual_type = expressions[i].get_type(schema)?;
- let required_type = field.data_type();
+ let required_type = &func_meta.arg_types[i];
if &actual_type != required_type {
// attempt to coerce using numerical coercion
// todo: also try string coercion.
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 8f366de..f627d05 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -520,10 +520,8 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
let mut safe_args: Vec<Expr> = vec![];
for i in 0..rex_args.len() {
- safe_args.push(
- rex_args[i]
- .cast_to(fm.args[i].data_type(), schema)?,
- );
+ safe_args
+ .push(rex_args[i].cast_to(&fm.arg_types[i], schema)?);
}
Ok(Expr::ScalarFunction {
@@ -908,7 +906,7 @@ mod tests {
match name {
"sqrt" => Some(Arc::new(ScalarFunction::new(
"sqrt",
- vec![Field::new("n", DataType::Float64, false)],
+ vec![DataType::Float64],
DataType::Float64,
Arc::new(|_| Err(ExecutionError::NotImplemented("".to_string()))),
))),
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 43ac27a..15120d7 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -220,7 +220,7 @@ fn create_ctx() -> Result<ExecutionContext> {
// register a custom UDF
ctx.register_udf(ScalarFunction::new(
"custom_sqrt",
- vec![Field::new("n", DataType::Float64, true)],
+ vec![DataType::Float64],
DataType::Float64,
Arc::new(custom_sqrt),
));