You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/07/27 19:33:39 UTC

[arrow-datafusion] branch master updated: add Atan2 (#2942)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 176f4329d add Atan2 (#2942)
176f4329d is described below

commit 176f4329dad5800c2f0c29edd21086f899bef676
Author: Wei-Ting Kuo <wa...@gmail.com>
AuthorDate: Thu Jul 28 03:33:34 2022 +0800

    add Atan2 (#2942)
    
    * add atan -> f64
    
    * make atan2 support f32
    
    * add test case for null input
    
    * add math in mod.rs
    
    * fix proto
    
    * add sql test for atan2
    
    * add test case in math_expressions
    
    * cargo fmt
    
    * fix error from clippy
    
    * remove useless comment
    
    * apply cargo fmt
---
 datafusion/core/src/logical_plan/mod.rs          |  4 +-
 datafusion/core/tests/sql/expr.rs                |  3 ++
 datafusion/core/tests/sql/math.rs                | 57 ++++++++++++++++++++
 datafusion/core/tests/sql/mod.rs                 |  1 +
 datafusion/expr/src/built_in_function.rs         |  4 ++
 datafusion/expr/src/expr_fn.rs                   |  2 +
 datafusion/expr/src/function.rs                  | 12 +++++
 datafusion/physical-expr/src/functions.rs        |  3 ++
 datafusion/physical-expr/src/math_expressions.rs | 69 +++++++++++++++++++++++-
 datafusion/proto/proto/datafusion.proto          |  1 +
 datafusion/proto/src/from_proto.rs               | 11 ++--
 datafusion/proto/src/to_proto.rs                 |  1 +
 12 files changed, 162 insertions(+), 6 deletions(-)

diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs
index e4e26ad54..9b3919837 100644
--- a/datafusion/core/src/logical_plan/mod.rs
+++ b/datafusion/core/src/logical_plan/mod.rs
@@ -28,8 +28,8 @@ pub use datafusion_common::{
 };
 pub use datafusion_expr::{
     abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan,
-    avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce, col,
-    combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count,
+    atan2, avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce,
+    col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count,
     count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exists, exp,
     expr_rewriter,
     expr_rewriter::{
diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index 93347ee41..c9c5d955a 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -505,6 +505,9 @@ async fn test_mathematical_expressions_with_null() -> Result<()> {
     test_expression!("power(NULL, 2)", "NULL");
     test_expression!("power(NULL, NULL)", "NULL");
     test_expression!("power(2, NULL)", "NULL");
+    test_expression!("atan2(NULL, NULL)", "NULL");
+    test_expression!("atan2(1, NULL)", "NULL");
+    test_expression!("atan2(NULL, 1)", "NULL");
     Ok(())
 }
 
diff --git a/datafusion/core/tests/sql/math.rs b/datafusion/core/tests/sql/math.rs
new file mode 100644
index 000000000..cff7120a2
--- /dev/null
+++ b/datafusion/core/tests/sql/math.rs
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use arrow::array::Float64Array;
+
+#[tokio::test]
+async fn test_atan2() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    let t1_schema = Arc::new(Schema::new(vec![
+        Field::new("x", DataType::Float64, true),
+        Field::new("y", DataType::Float64, true),
+    ]));
+
+    let t1_data = RecordBatch::try_new(
+        t1_schema.clone(),
+        vec![
+            Arc::new(Float64Array::from(vec![1.0, 1.0, -1.0, -1.0])),
+            Arc::new(Float64Array::from(vec![2.0, -2.0, 2.0, -2.0])),
+        ],
+    )?;
+    let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?;
+    ctx.register_table("t1", Arc::new(t1_table))?;
+
+    let sql = "SELECT atan2(y, x) FROM t1";
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+---------------------+",
+        "| atan2(t1.y,t1.x)    |",
+        "+---------------------+",
+        "| 1.1071487177940904  |",
+        "| -1.1071487177940904 |",
+        "| 2.0344439357957027  |",
+        "| -2.0344439357957027 |",
+        "+---------------------+",
+    ];
+
+    assert_batches_eq!(expected, &actual);
+
+    Ok(())
+}
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 7f235b1ba..f4153757f 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -92,6 +92,7 @@ pub mod intersection;
 pub mod joins;
 pub mod json;
 pub mod limit;
+pub mod math;
 pub mod order;
 pub mod parquet;
 pub mod predicates;
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 663888e2e..ffac07ca5 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -34,6 +34,8 @@ pub enum BuiltinScalarFunction {
     Asin,
     /// atan
     Atan,
+    /// atan2
+    Atan2,
     /// ceil
     Ceil,
     /// coalesce
@@ -181,6 +183,7 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::Acos => Volatility::Immutable,
             BuiltinScalarFunction::Asin => Volatility::Immutable,
             BuiltinScalarFunction::Atan => Volatility::Immutable,
+            BuiltinScalarFunction::Atan2 => Volatility::Immutable,
             BuiltinScalarFunction::Ceil => Volatility::Immutable,
             BuiltinScalarFunction::Coalesce => Volatility::Immutable,
             BuiltinScalarFunction::Cos => Volatility::Immutable,
@@ -268,6 +271,7 @@ impl FromStr for BuiltinScalarFunction {
             "acos" => BuiltinScalarFunction::Acos,
             "asin" => BuiltinScalarFunction::Asin,
             "atan" => BuiltinScalarFunction::Atan,
+            "atan2" => BuiltinScalarFunction::Atan2,
             "ceil" => BuiltinScalarFunction::Ceil,
             "cos" => BuiltinScalarFunction::Cos,
             "exp" => BuiltinScalarFunction::Exp,
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index abfd37a7c..97bbd419e 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -304,6 +304,7 @@ unary_scalar_expr!(Log10, log10);
 unary_scalar_expr!(Ln, ln);
 unary_scalar_expr!(NullIf, nullif);
 scalar_expr!(Power, power, base, exponent);
+scalar_expr!(Atan2, atan2, y, x);
 
 // string functions
 scalar_expr!(Ascii, ascii, string);
@@ -546,6 +547,7 @@ mod test {
         test_unary_scalar_expr!(Log2, log2);
         test_unary_scalar_expr!(Log10, log10);
         test_unary_scalar_expr!(Ln, ln);
+        test_scalar_expr!(Atan2, atan2, y, x);
 
         test_scalar_expr!(Ascii, ascii, input);
         test_scalar_expr!(BitLength, bit_length, string);
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 331756f8d..29158e234 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -229,6 +229,11 @@ pub fn return_type(
 
         BuiltinScalarFunction::Struct => Ok(DataType::Struct(vec![])),
 
+        BuiltinScalarFunction::Atan2 => match &input_expr_types[0] {
+            DataType::Float32 => Ok(DataType::Float32),
+            _ => Ok(DataType::Float64),
+        },
+
         BuiltinScalarFunction::Abs
         | BuiltinScalarFunction::Acos
         | BuiltinScalarFunction::Asin
@@ -540,6 +545,13 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
             ],
             fun.volatility(),
         ),
+        BuiltinScalarFunction::Atan2 => Signature::one_of(
+            vec![
+                TypeSignature::Exact(vec![DataType::Float32, DataType::Float32]),
+                TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]),
+            ],
+            fun.volatility(),
+        ),
         // math expressions expect 1 argument of type f64 or f32
         // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
         // return the best approximation for it (in f64).
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 5f0e711f8..a84b00bf1 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -308,6 +308,9 @@ pub fn create_physical_fun(
         BuiltinScalarFunction::Power => {
             Arc::new(|args| make_scalar_function(math_expressions::power)(args))
         }
+        BuiltinScalarFunction::Atan2 => {
+            Arc::new(|args| make_scalar_function(math_expressions::atan2)(args))
+        }
 
         // string functions
         BuiltinScalarFunction::Array => Arc::new(array_expressions::array),
diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs
index 7f4126815..16dda93dd 100644
--- a/datafusion/physical-expr/src/math_expressions.rs
+++ b/datafusion/physical-expr/src/math_expressions.rs
@@ -176,11 +176,38 @@ pub fn power(args: &[ArrayRef]) -> Result<ArrayRef> {
     }
 }
 
+pub fn atan2(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args[0].data_type() {
+        DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
+            &args[0],
+            &args[1],
+            "y",
+            "x",
+            Float64Array,
+            { f64::atan2 }
+        )) as ArrayRef),
+
+        DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
+            &args[0],
+            &args[1],
+            "y",
+            "x",
+            Float32Array,
+            { f32::atan2 }
+        )) as ArrayRef),
+
+        other => Err(DataFusionError::Internal(format!(
+            "Unsupported data type {:?} for function atan2",
+            other
+        ))),
+    }
+}
+
 #[cfg(test)]
 mod tests {
 
     use super::*;
-    use arrow::array::{Float64Array, NullArray};
+    use arrow::array::{Array, Float64Array, NullArray};
 
     #[test]
     fn test_random_expression() {
@@ -191,4 +218,44 @@ mod tests {
         assert_eq!(floats.len(), 1);
         assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0);
     }
+
+    #[test]
+    fn test_atan2_f64() {
+        let args: Vec<ArrayRef> = vec![
+            Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y
+            Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x
+        ];
+
+        let result = atan2(&args).expect("fail");
+        let floats = result
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("fail");
+
+        assert_eq!(floats.len(), 4);
+        assert_eq!(floats.value(0), (2.0_f64).atan2(1.0));
+        assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0));
+        assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0));
+        assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0));
+    }
+
+    #[test]
+    fn test_atan2_f32() {
+        let args: Vec<ArrayRef> = vec![
+            Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y
+            Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x
+        ];
+
+        let result = atan2(&args).expect("fail");
+        let floats = result
+            .as_any()
+            .downcast_ref::<Float32Array>()
+            .expect("fail");
+
+        assert_eq!(floats.len(), 4);
+        assert_eq!(floats.value(0), (2.0_f32).atan2(1.0));
+        assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0));
+        assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0));
+        assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0));
+    }
 }
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 39c254ea7..ec816a419 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -439,6 +439,7 @@ enum ScalarFunction {
   Power=64;
   StructFun=65;
   FromUnixtime=66;
+  Atan2=67;
 }
 
 message ScalarFunctionNode {
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index cb7b11189..40ea1bd02 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -32,9 +32,9 @@ use datafusion_common::{
 use datafusion_expr::expr::GroupingSet;
 use datafusion_expr::expr::GroupingSet::GroupingSets;
 use datafusion_expr::{
-    abs, acos, array, ascii, asin, atan, bit_length, btrim, ceil, character_length, chr,
-    coalesce, concat_expr, concat_ws_expr, cos, date_part, date_trunc, digest, exp,
-    floor, from_unixtime, left, ln, log10, log2,
+    abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil,
+    character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_part,
+    date_trunc, digest, exp, floor, from_unixtime, left, ln, log10, log2,
     logical_plan::{PlanType, StringifiedPlan},
     lower, lpad, ltrim, md5, now_expr, nullif, octet_length, power, random, regexp_match,
     regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256,
@@ -474,6 +474,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
             ScalarFunction::Power => Self::Power,
             ScalarFunction::StructFun => Self::Struct,
             ScalarFunction::FromUnixtime => Self::FromUnixtime,
+            ScalarFunction::Atan2 => Self::Atan2,
         }
     }
 }
@@ -1132,6 +1133,10 @@ pub fn parse_expr(
                 ScalarFunction::FromUnixtime => {
                     Ok(from_unixtime(parse_expr(&args[0], registry)?))
                 }
+                ScalarFunction::Atan2 => Ok(atan2(
+                    parse_expr(&args[0], registry)?,
+                    parse_expr(&args[1], registry)?,
+                )),
                 _ => Err(proto_error(
                     "Protobuf deserialization error: Unsupported scalar function",
                 )),
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 60f4079da..323e2186d 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -1124,6 +1124,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
             BuiltinScalarFunction::Power => Self::Power,
             BuiltinScalarFunction::Struct => Self::StructFun,
             BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime,
+            BuiltinScalarFunction::Atan2 => Self::Atan2,
         };
 
         Ok(scalar_function)