You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ko...@apache.org on 2023/01/02 13:36:29 UTC

[arrow] branch master updated: ARROW-17144: [C++][Gandiva] Add sqrt function (#13656)

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 1e8ca94fc3 ARROW-17144: [C++][Gandiva] Add sqrt function (#13656)
1e8ca94fc3 is described below

commit 1e8ca94fc3682eb97bcf243545dcb282c1aaa0b4
Author: Sahaj Gupta <10...@users.noreply.github.com>
AuthorDate: Mon Jan 2 19:06:18 2023 +0530

    ARROW-17144: [C++][Gandiva] Add sqrt function (#13656)
    
    Lead-authored-by: SG011 <sa...@dremio.com>
    Co-authored-by: Sutou Kouhei <ko...@clear-code.com>
    Signed-off-by: Sutou Kouhei <ko...@clear-code.com>
---
 cpp/src/gandiva/function_registry_arithmetic.cc    |  5 ++
 cpp/src/gandiva/precompiled/arithmetic_ops.cc      | 15 ++++
 cpp/src/gandiva/precompiled/arithmetic_ops_test.cc | 29 ++++++++
 cpp/src/gandiva/precompiled/types.h                |  4 ++
 cpp/src/gandiva/tests/projector_test.cc            | 81 ++++++++++++++++++++++
 5 files changed, 134 insertions(+)

diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc
index 0c1b64f61a..320dd5ded1 100644
--- a/cpp/src/gandiva/function_registry_arithmetic.cc
+++ b/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -203,6 +203,11 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
       // floor functions
       UNARY_SAFE_NULL_IF_NULL(floor, {}, float32, float32),
       UNARY_SAFE_NULL_IF_NULL(floor, {}, float64, float64),
+      // sqrt functions (int32, int64, float32 and float64 inputs; float64 result)
+      UNARY_SAFE_NULL_IF_NULL(sqrt, {}, int32, float64),
+      UNARY_SAFE_NULL_IF_NULL(sqrt, {}, int64, float64),
+      UNARY_SAFE_NULL_IF_NULL(sqrt, {}, float32, float64),
+      UNARY_SAFE_NULL_IF_NULL(sqrt, {}, float64, float64),
 
       // compare functions
       BINARY_RELATIONAL_BOOL_FN(equal, ({"eq", "same"})),
diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops.cc b/cpp/src/gandiva/precompiled/arithmetic_ops.cc
index f186f24b9d..b4959e9d7a 100644
--- a/cpp/src/gandiva/precompiled/arithmetic_ops.cc
+++ b/cpp/src/gandiva/precompiled/arithmetic_ops.cc
@@ -511,6 +511,21 @@ FLOOR(float32)
 FLOOR(float64)
 
 #undef FLOOR
+#define SQRT(TYPE) /* sqrt_<TYPE>: float64 square root, NaN if in1 < 0 */ \
+  FORCE_INLINE                                                            \
+  gdv_float64 sqrt_##TYPE(gdv_##TYPE in1) {                               \
+    if (in1 < 0) {                                                        \
+      return NAN;                                                         \
+    }                                                                     \
+    return sqrt(static_cast<gdv_float64>(in1));                           \
+  }
+// in1 is widened first so float32 input is never rooted in single precision.
+SQRT(int32)
+SQRT(int64)
+SQRT(float32)
+SQRT(float64)
+
+#undef SQRT
 
 #undef NUMERIC_FUNCTION
 #undef NUMERIC_TYPES
diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
index 77e1d65f3b..c0de758215 100644
--- a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
@@ -669,4 +669,33 @@ TEST(TestArithmeticOps, TestFloorFloatDouble) {
   EXPECT_EQ(floor_float64(-2147483647), -2147483647.0);
 }
 
+TEST(TestArithmeticOps, TestSqrtIntFloatDouble) {
+  // sqrt from int32: perfect squares have exact roots, negative input gives NaN
+  EXPECT_EQ(sqrt_int32(36), 6.0);
+  EXPECT_EQ(sqrt_int32(49), 7.0);
+  EXPECT_EQ(sqrt_int32(64), 8.0);
+  EXPECT_EQ(sqrt_int32(81), 9.0);
+  EXPECT_TRUE(std::isnan(sqrt_int32(-1)));
+  // sqrt from int64
+  EXPECT_EQ(sqrt_int64(4), 2.0);
+  EXPECT_EQ(sqrt_int64(9), 3.0);
+  EXPECT_EQ(sqrt_int64(64), 8.0);
+  EXPECT_EQ(sqrt_int64(81), 9.0);
+  EXPECT_TRUE(std::isnan(sqrt_int64(-1)));
+  // sqrt from floats
+  EXPECT_EQ(sqrt_float32(16.0f), 4.0);
+  EXPECT_EQ(sqrt_float32(49.0f), 7.0);
+  EXPECT_EQ(sqrt_float32(36.0f), 6.0);
+  EXPECT_EQ(sqrt_float32(0.0f), 0.0);
+  EXPECT_TRUE(std::isnan(sqrt_float32(-1.0f)));
+  // sqrt from doubles; inexact roots are compared with ULP tolerance, not operator==
+  EXPECT_EQ(sqrt_float64(16.0), 4.0);
+  EXPECT_DOUBLE_EQ(sqrt_float64(11.0889), 3.33);
+  EXPECT_DOUBLE_EQ(sqrt_float64(1.522756), 1.234);
+  EXPECT_EQ(sqrt_float64(49.0), 7.0);
+  EXPECT_EQ(sqrt_float64(36.0), 6.0);
+  EXPECT_EQ(sqrt_float64(0.0), 0.0);
+  EXPECT_TRUE(std::isnan(sqrt_float64(-1.0)));
+}
+
 }  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h
index 93545778d3..3855f64fd6 100644
--- a/cpp/src/gandiva/precompiled/types.h
+++ b/cpp/src/gandiva/precompiled/types.h
@@ -235,6 +235,10 @@ gdv_float32 ceiling_float32(gdv_float32 in);
 gdv_float64 ceiling_float64(gdv_float64 in);
 gdv_float32 floor_float32(gdv_float32 in);
 gdv_float64 floor_float64(gdv_float64 in);
+gdv_float64 sqrt_int32(gdv_int32 in);  // sqrt_<type>: float64 root; NaN when in < 0
+gdv_float64 sqrt_int64(gdv_int64 in);
+gdv_float64 sqrt_float32(gdv_float32 in);
+gdv_float64 sqrt_float64(gdv_float64 in);
 
 gdv_float32 round_float32(gdv_float32);
 gdv_float64 round_float64(gdv_float64);
diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc
index 0191f5ebaa..fd2f2a12af 100644
--- a/cpp/src/gandiva/tests/projector_test.cc
+++ b/cpp/src/gandiva/tests/projector_test.cc
@@ -3419,4 +3419,85 @@ TEST_F(TestProjector, TestMaskDefault) {
   EXPECT_ARROW_ARRAY_EQUALS(exp_mask, outputs.at(0));
 }
 
+TEST_F(TestProjector, TestSqrtInt32) {
+  auto input_field = field("in", arrow::int32());
+  auto schema = arrow::schema({input_field});
+  auto output_field = field("out", arrow::float64());
+  auto sqrt_expr = TreeExprBuilder::MakeExpression("sqrt", {input_field}, output_field);
+  // Build a projector for the single sqrt expression.
+  std::shared_ptr<Projector> projector;
+  ARROW_EXPECT_OK(Projector::Make(schema, {sqrt_expr}, TestConfiguration(), &projector));
+  // Four valid int32 perfect squares and the float64 roots expected for them.
+  const int num_rows = 4;
+  auto input = MakeArrowArrayInt32({1, 4, 9, 16}, {true, true, true, true});
+  auto batch = arrow::RecordBatch::Make(schema, num_rows, {input});
+  auto expected = MakeArrowArrayFloat64({1.0, 2.0, 3.0, 4.0}, {true, true, true, true});
+  // Evaluate the batch, then compare the first output column with the expectation.
+  arrow::ArrayVector outputs;
+  ARROW_EXPECT_OK(projector->Evaluate(*batch, pool_, &outputs));
+
+  EXPECT_ARROW_ARRAY_EQUALS(expected, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestSqrtInt64) {
+  auto input_field = field("in", arrow::int64());
+  auto schema = arrow::schema({input_field});
+  auto output_field = field("out", arrow::float64());
+  auto sqrt_expr = TreeExprBuilder::MakeExpression("sqrt", {input_field}, output_field);
+  // Build a projector for the single sqrt expression.
+  std::shared_ptr<Projector> projector;
+  ARROW_EXPECT_OK(Projector::Make(schema, {sqrt_expr}, TestConfiguration(), &projector));
+  // Four valid int64 perfect squares and the float64 roots expected for them.
+  const int num_rows = 4;
+  auto input = MakeArrowArrayInt64({1, 9, 16, 25}, {true, true, true, true});
+  auto batch = arrow::RecordBatch::Make(schema, num_rows, {input});
+  auto expected = MakeArrowArrayFloat64({1.0, 3.0, 4.0, 5.0}, {true, true, true, true});
+  // Evaluate the batch, then compare the first output column with the expectation.
+  arrow::ArrayVector outputs;
+  ARROW_EXPECT_OK(projector->Evaluate(*batch, pool_, &outputs));
+
+  EXPECT_ARROW_ARRAY_EQUALS(expected, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestSqrtFloat32) {
+  auto input_field = field("in", arrow::float32());
+  auto schema = arrow::schema({input_field});
+  auto output_field = field("out", arrow::float64());
+  auto sqrt_expr = TreeExprBuilder::MakeExpression("sqrt", {input_field}, output_field);
+  // Build a projector for the single sqrt expression.
+  std::shared_ptr<Projector> projector;
+  ARROW_EXPECT_OK(Projector::Make(schema, {sqrt_expr}, TestConfiguration(), &projector));
+  // Four valid float32 perfect squares and the float64 roots expected for them.
+  const int num_rows = 4;
+  auto input =
+      MakeArrowArrayFloat32({1.0f, 4.0f, 25.0f, 36.0f}, {true, true, true, true});
+  auto batch = arrow::RecordBatch::Make(schema, num_rows, {input});
+  auto expected = MakeArrowArrayFloat64({1.0, 2.0, 5.0, 6.0}, {true, true, true, true});
+  // Evaluate the batch, then compare the first output column with the expectation.
+  arrow::ArrayVector outputs;
+  ARROW_EXPECT_OK(projector->Evaluate(*batch, pool_, &outputs));
+
+  EXPECT_ARROW_ARRAY_EQUALS(expected, outputs.at(0));
+}
+
+TEST_F(TestProjector, TestSqrtFloat64) {
+  auto input_field = field("in", arrow::float64());
+  auto schema = arrow::schema({input_field});
+  auto output_field = field("out", arrow::float64());
+  auto sqrt_expr = TreeExprBuilder::MakeExpression("sqrt", {input_field}, output_field);
+  // Build a projector for the single sqrt expression.
+  std::shared_ptr<Projector> projector;
+  ARROW_EXPECT_OK(Projector::Make(schema, {sqrt_expr}, TestConfiguration(), &projector));
+  // Four valid float64 perfect squares and the float64 roots expected for them.
+  const int num_rows = 4;
+  auto input = MakeArrowArrayFloat64({1.0, 4.0, 9.0, 16.0}, {true, true, true, true});
+  auto batch = arrow::RecordBatch::Make(schema, num_rows, {input});
+  auto expected = MakeArrowArrayFloat64({1.0, 2.0, 3.0, 4.0}, {true, true, true, true});
+  // Evaluate the batch, then compare the first output column with the expectation.
+  arrow::ArrayVector outputs;
+  ARROW_EXPECT_OK(projector->Evaluate(*batch, pool_, &outputs));
+
+  EXPECT_ARROW_ARRAY_EQUALS(expected, outputs.at(0));
+}
+
 }  // namespace gandiva