You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/07/27 19:33:39 UTC
[arrow-datafusion] branch master updated: add Atan2 (#2942)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 176f4329d add Atan2 (#2942)
176f4329d is described below
commit 176f4329dad5800c2f0c29edd21086f899bef676
Author: Wei-Ting Kuo <wa...@gmail.com>
AuthorDate: Thu Jul 28 03:33:34 2022 +0800
add Atan2 (#2942)
* add atan -> f64
* make atan2 support f32
* add test case for null input
* add math in mod.rs
* fix proto
* add sql test for atan2
* add test case in math_expressions
* cargo fmt
* fix error from clippy
* remove useless comment
* apply cargo fmt
---
datafusion/core/src/logical_plan/mod.rs | 4 +-
datafusion/core/tests/sql/expr.rs | 3 ++
datafusion/core/tests/sql/math.rs | 57 ++++++++++++++++++++
datafusion/core/tests/sql/mod.rs | 1 +
datafusion/expr/src/built_in_function.rs | 4 ++
datafusion/expr/src/expr_fn.rs | 2 +
datafusion/expr/src/function.rs | 12 +++++
datafusion/physical-expr/src/functions.rs | 3 ++
datafusion/physical-expr/src/math_expressions.rs | 69 +++++++++++++++++++++++-
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/from_proto.rs | 11 ++--
datafusion/proto/src/to_proto.rs | 1 +
12 files changed, 162 insertions(+), 6 deletions(-)
diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs
index e4e26ad54..9b3919837 100644
--- a/datafusion/core/src/logical_plan/mod.rs
+++ b/datafusion/core/src/logical_plan/mod.rs
@@ -28,8 +28,8 @@ pub use datafusion_common::{
};
pub use datafusion_expr::{
abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan,
- avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce, col,
- combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count,
+ atan2, avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce,
+ col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count,
count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exists, exp,
expr_rewriter,
expr_rewriter::{
diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index 93347ee41..c9c5d955a 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -505,6 +505,9 @@ async fn test_mathematical_expressions_with_null() -> Result<()> {
test_expression!("power(NULL, 2)", "NULL");
test_expression!("power(NULL, NULL)", "NULL");
test_expression!("power(2, NULL)", "NULL");
+ test_expression!("atan2(NULL, NULL)", "NULL");
+ test_expression!("atan2(1, NULL)", "NULL");
+ test_expression!("atan2(NULL, 1)", "NULL");
Ok(())
}
diff --git a/datafusion/core/tests/sql/math.rs b/datafusion/core/tests/sql/math.rs
new file mode 100644
index 000000000..cff7120a2
--- /dev/null
+++ b/datafusion/core/tests/sql/math.rs
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use arrow::array::Float64Array;
+
+#[tokio::test]
+async fn test_atan2() -> Result<()> {
+ let ctx = SessionContext::new();
+
+ let t1_schema = Arc::new(Schema::new(vec![
+ Field::new("x", DataType::Float64, true),
+ Field::new("y", DataType::Float64, true),
+ ]));
+
+ let t1_data = RecordBatch::try_new(
+ t1_schema.clone(),
+ vec![
+ Arc::new(Float64Array::from(vec![1.0, 1.0, -1.0, -1.0])),
+ Arc::new(Float64Array::from(vec![2.0, -2.0, 2.0, -2.0])),
+ ],
+ )?;
+ let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?;
+ ctx.register_table("t1", Arc::new(t1_table))?;
+
+ let sql = "SELECT atan2(y, x) FROM t1";
+ let actual = execute_to_batches(&ctx, sql).await;
+
+ let expected = vec![
+ "+---------------------+",
+ "| atan2(t1.y,t1.x) |",
+ "+---------------------+",
+ "| 1.1071487177940904 |",
+ "| -1.1071487177940904 |",
+ "| 2.0344439357957027 |",
+ "| -2.0344439357957027 |",
+ "+---------------------+",
+ ];
+
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 7f235b1ba..f4153757f 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -92,6 +92,7 @@ pub mod intersection;
pub mod joins;
pub mod json;
pub mod limit;
+pub mod math;
pub mod order;
pub mod parquet;
pub mod predicates;
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 663888e2e..ffac07ca5 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -34,6 +34,8 @@ pub enum BuiltinScalarFunction {
Asin,
/// atan
Atan,
+ /// atan2
+ Atan2,
/// ceil
Ceil,
/// coalesce
@@ -181,6 +183,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Acos => Volatility::Immutable,
BuiltinScalarFunction::Asin => Volatility::Immutable,
BuiltinScalarFunction::Atan => Volatility::Immutable,
+ BuiltinScalarFunction::Atan2 => Volatility::Immutable,
BuiltinScalarFunction::Ceil => Volatility::Immutable,
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Cos => Volatility::Immutable,
@@ -268,6 +271,7 @@ impl FromStr for BuiltinScalarFunction {
"acos" => BuiltinScalarFunction::Acos,
"asin" => BuiltinScalarFunction::Asin,
"atan" => BuiltinScalarFunction::Atan,
+ "atan2" => BuiltinScalarFunction::Atan2,
"ceil" => BuiltinScalarFunction::Ceil,
"cos" => BuiltinScalarFunction::Cos,
"exp" => BuiltinScalarFunction::Exp,
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index abfd37a7c..97bbd419e 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -304,6 +304,7 @@ unary_scalar_expr!(Log10, log10);
unary_scalar_expr!(Ln, ln);
unary_scalar_expr!(NullIf, nullif);
scalar_expr!(Power, power, base, exponent);
+scalar_expr!(Atan2, atan2, y, x);
// string functions
scalar_expr!(Ascii, ascii, string);
@@ -546,6 +547,7 @@ mod test {
test_unary_scalar_expr!(Log2, log2);
test_unary_scalar_expr!(Log10, log10);
test_unary_scalar_expr!(Ln, ln);
+ test_scalar_expr!(Atan2, atan2, y, x);
test_scalar_expr!(Ascii, ascii, input);
test_scalar_expr!(BitLength, bit_length, string);
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 331756f8d..29158e234 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -229,6 +229,11 @@ pub fn return_type(
BuiltinScalarFunction::Struct => Ok(DataType::Struct(vec![])),
+ BuiltinScalarFunction::Atan2 => match &input_expr_types[0] {
+ DataType::Float32 => Ok(DataType::Float32),
+ _ => Ok(DataType::Float64),
+ },
+
BuiltinScalarFunction::Abs
| BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
@@ -540,6 +545,13 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
],
fun.volatility(),
),
+ BuiltinScalarFunction::Atan2 => Signature::one_of(
+ vec![
+ TypeSignature::Exact(vec![DataType::Float32, DataType::Float32]),
+ TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]),
+ ],
+ fun.volatility(),
+ ),
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 5f0e711f8..a84b00bf1 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -308,6 +308,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Power => {
Arc::new(|args| make_scalar_function(math_expressions::power)(args))
}
+ BuiltinScalarFunction::Atan2 => {
+ Arc::new(|args| make_scalar_function(math_expressions::atan2)(args))
+ }
// string functions
BuiltinScalarFunction::Array => Arc::new(array_expressions::array),
diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs
index 7f4126815..16dda93dd 100644
--- a/datafusion/physical-expr/src/math_expressions.rs
+++ b/datafusion/physical-expr/src/math_expressions.rs
@@ -176,11 +176,38 @@ pub fn power(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}
+pub fn atan2(args: &[ArrayRef]) -> Result<ArrayRef> {
+ match args[0].data_type() {
+ DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
+ &args[0],
+ &args[1],
+ "y",
+ "x",
+ Float64Array,
+ { f64::atan2 }
+ )) as ArrayRef),
+
+ DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
+ &args[0],
+ &args[1],
+ "y",
+ "x",
+ Float32Array,
+ { f32::atan2 }
+ )) as ArrayRef),
+
+ other => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {:?} for function atan2",
+ other
+ ))),
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
- use arrow::array::{Float64Array, NullArray};
+ use arrow::array::{Array, Float64Array, NullArray};
#[test]
fn test_random_expression() {
@@ -191,4 +218,44 @@ mod tests {
assert_eq!(floats.len(), 1);
assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0);
}
+
+ #[test]
+ fn test_atan2_f64() {
+ let args: Vec<ArrayRef> = vec![
+ Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y
+ Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x
+ ];
+
+ let result = atan2(&args).expect("fail");
+ let floats = result
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .expect("fail");
+
+ assert_eq!(floats.len(), 4);
+ assert_eq!(floats.value(0), (2.0_f64).atan2(1.0));
+ assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0));
+ assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0));
+ assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0));
+ }
+
+ #[test]
+ fn test_atan2_f32() {
+ let args: Vec<ArrayRef> = vec![
+ Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y
+ Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x
+ ];
+
+ let result = atan2(&args).expect("fail");
+ let floats = result
+ .as_any()
+ .downcast_ref::<Float32Array>()
+ .expect("fail");
+
+ assert_eq!(floats.len(), 4);
+ assert_eq!(floats.value(0), (2.0_f32).atan2(1.0));
+ assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0));
+ assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0));
+ assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0));
+ }
}
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 39c254ea7..ec816a419 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -439,6 +439,7 @@ enum ScalarFunction {
Power=64;
StructFun=65;
FromUnixtime=66;
+ Atan2=67;
}
message ScalarFunctionNode {
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index cb7b11189..40ea1bd02 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -32,9 +32,9 @@ use datafusion_common::{
use datafusion_expr::expr::GroupingSet;
use datafusion_expr::expr::GroupingSet::GroupingSets;
use datafusion_expr::{
- abs, acos, array, ascii, asin, atan, bit_length, btrim, ceil, character_length, chr,
- coalesce, concat_expr, concat_ws_expr, cos, date_part, date_trunc, digest, exp,
- floor, from_unixtime, left, ln, log10, log2,
+ abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil,
+ character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_part,
+ date_trunc, digest, exp, floor, from_unixtime, left, ln, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, now_expr, nullif, octet_length, power, random, regexp_match,
regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256,
@@ -474,6 +474,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Power => Self::Power,
ScalarFunction::StructFun => Self::Struct,
ScalarFunction::FromUnixtime => Self::FromUnixtime,
+ ScalarFunction::Atan2 => Self::Atan2,
}
}
}
@@ -1132,6 +1133,10 @@ pub fn parse_expr(
ScalarFunction::FromUnixtime => {
Ok(from_unixtime(parse_expr(&args[0], registry)?))
}
+ ScalarFunction::Atan2 => Ok(atan2(
+ parse_expr(&args[0], registry)?,
+ parse_expr(&args[1], registry)?,
+ )),
_ => Err(proto_error(
"Protobuf deserialization error: Unsupported scalar function",
)),
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 60f4079da..323e2186d 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -1124,6 +1124,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Power => Self::Power,
BuiltinScalarFunction::Struct => Self::StructFun,
BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime,
+ BuiltinScalarFunction::Atan2 => Self::Atan2,
};
Ok(scalar_function)