You are viewing a plain-text version of this content. The hyperlink to the canonical version is not preserved in this plain-text rendering.
Posted to commits@arrow.apache.org by pr...@apache.org on 2020/04/18 12:50:44 UTC

[arrow] branch master updated: ARROW-8443: [Gandiva][C++] Fix Trunc and Round output types.

This is an automated email from the ASF dual-hosted git repository.

praveenbingo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new ca7418f  ARROW-8443: [Gandiva][C++] Fix Trunc and Round output types.
ca7418f is described below

commit ca7418f0c2c2ae1083140fd6f1c5ea6f73b11a6a
Author: Praveen <pr...@dremio.com>
AuthorDate: Sat Apr 18 18:18:42 2020 +0530

    ARROW-8443: [Gandiva][C++] Fix Trunc and Round output types.
    
    - Changed trunc and round to honor output precision and scale.
    - Make both a no-op when the output scale equals the input scale and the rounding scale is non-negative.
    
    Closes #6942 from praveenbingo/ARROW-8443-1 and squashes the following commits:
    
    c94f60642 <Praveen> ARROW-8443:  Fix Trunc and Round output types.
    
    Authored-by: Praveen <pr...@dremio.com>
    Signed-off-by: Praveen <pr...@dremio.com>
---
 cpp/src/gandiva/precompiled/decimal_ops.cc       |  33 ++++--
 cpp/src/gandiva/precompiled/decimal_ops.h        |   7 +-
 cpp/src/gandiva/precompiled/decimal_ops_test.cc  | 123 ++++++++++++++---------
 cpp/src/gandiva/precompiled/decimal_wrapper.cc   |  10 +-
 cpp/src/gandiva/precompiled/extended_math_ops.cc |   5 +-
 5 files changed, 112 insertions(+), 66 deletions(-)

diff --git a/cpp/src/gandiva/precompiled/decimal_ops.cc b/cpp/src/gandiva/precompiled/decimal_ops.cc
index 8b39dac..03b72de 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops.cc
@@ -658,24 +658,35 @@ static BasicDecimal128 RoundWithNegativeScale(const BasicDecimalScalar128& x,
   return scaled + delta;
 }
 
-BasicDecimal128 Round(const BasicDecimalScalar128& x, int32_t out_scale, bool* overflow) {
-  if (out_scale < 0) {
-    return RoundWithNegativeScale(x, x.precision(), out_scale,
+BasicDecimal128 Round(const BasicDecimalScalar128& x, int32_t out_precision,
+                      int32_t out_scale, int32_t rounding_scale, bool* overflow) {
+  // no-op if target scale is same as arg scale
+  if (x.scale() == out_scale && rounding_scale >= 0) {
+    return x.value();
+  }
+
+  if (rounding_scale < 0) {
+    return RoundWithNegativeScale(x, out_precision, rounding_scale,
                                   RoundType::kRoundTypeHalfRoundUp, overflow);
   } else {
-    return RoundWithPositiveScale(x, x.precision(), out_scale,
+    return RoundWithPositiveScale(x, out_precision, rounding_scale,
                                   RoundType::kRoundTypeHalfRoundUp, overflow);
   }
 }
 
-BasicDecimal128 Truncate(const BasicDecimalScalar128& x, int32_t out_scale,
-                         bool* overflow) {
-  if (out_scale < 0) {
-    return RoundWithNegativeScale(x, x.precision(), out_scale, RoundType::kRoundTypeTrunc,
-                                  overflow);
+BasicDecimal128 Truncate(const BasicDecimalScalar128& x, int32_t out_precision,
+                         int32_t out_scale, int32_t rounding_scale, bool* overflow) {
+  // no-op if target scale is same as arg scale
+  if (x.scale() == out_scale && rounding_scale >= 0) {
+    return x.value();
+  }
+
+  if (rounding_scale < 0) {
+    return RoundWithNegativeScale(x, out_precision, rounding_scale,
+                                  RoundType::kRoundTypeTrunc, overflow);
   } else {
-    return RoundWithPositiveScale(x, x.precision(), out_scale, RoundType::kRoundTypeTrunc,
-                                  overflow);
+    return RoundWithPositiveScale(x, out_precision, rounding_scale,
+                                  RoundType::kRoundTypeTrunc, overflow);
   }
 }
 
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.h b/cpp/src/gandiva/precompiled/decimal_ops.h
index b342943..292dce2 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.h
+++ b/cpp/src/gandiva/precompiled/decimal_ops.h
@@ -73,11 +73,12 @@ BasicDecimal128 Convert(const BasicDecimalScalar128& x, int32_t out_precision,
                         int32_t out_scale, bool* overflow);
 
 /// round decimal.
-BasicDecimal128 Round(const BasicDecimalScalar128& x, int32_t out_scale, bool* overflow);
+BasicDecimal128 Round(const BasicDecimalScalar128& x, int32_t out_precision,
+                      int32_t out_scale, int32_t rounding_scale, bool* overflow);
 
 /// truncate decimal.
-BasicDecimal128 Truncate(const BasicDecimalScalar128& x, int32_t out_scale,
-                         bool* overflow);
+BasicDecimal128 Truncate(const BasicDecimalScalar128& x, int32_t out_precision,
+                         int32_t out_scale, int32_t rounding_scale, bool* overflow);
 
 /// ceil decimal
 BasicDecimal128 Ceil(const BasicDecimalScalar128& x, bool* overflow);
diff --git a/cpp/src/gandiva/precompiled/decimal_ops_test.cc b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
index 3bc77a9..be8a1fe 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
@@ -552,54 +552,70 @@ TEST_F(TestDecimalSql, Compare) {
 
 TEST_F(TestDecimalSql, Round) {
   // expected, input, rounding_scale, overflow
-  using TupleType = std::tuple<BasicDecimal128, DecimalScalar128, int32_t, bool>;
+  using TupleType = std::tuple<DecimalScalar128, DecimalScalar128, int32_t, bool>;
   std::vector<TupleType> test_values = {
       // examples from
       // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_round
-      std::make_tuple(BasicDecimal128{-1}, DecimalScalar128{-123, 38, 2}, 0, false),
-      std::make_tuple(BasicDecimal128{-2}, DecimalScalar128{-158, 38, 2}, 0, false),
-      std::make_tuple(BasicDecimal128{2}, DecimalScalar128{158, 38, 2}, 0, false),
-      std::make_tuple(BasicDecimal128{-13}, DecimalScalar128{-1298, 38, 3}, 1, false),
-      std::make_tuple(BasicDecimal128{-1}, DecimalScalar128{-1298, 38, 3}, 0, false),
-      std::make_tuple(BasicDecimal128{20}, DecimalScalar128{23298, 38, 3}, -1, false),
-      std::make_tuple(BasicDecimal128{3}, DecimalScalar128{25, 38, 1}, 0, false),
+      std::make_tuple(DecimalScalar128{-1, 36, 0}, DecimalScalar128{-123, 38, 2}, 0,
+                      false),
+      std::make_tuple(DecimalScalar128{-2, 36, 0}, DecimalScalar128{-158, 38, 2}, 0,
+                      false),
+      std::make_tuple(DecimalScalar128{2, 36, 0}, DecimalScalar128{158, 38, 2}, 0, false),
+      std::make_tuple(DecimalScalar128{-13, 36, 1}, DecimalScalar128{-1298, 38, 3}, 1,
+                      false),
+      std::make_tuple(DecimalScalar128{-1, 35, 0}, DecimalScalar128{-1298, 38, 3}, 0,
+                      false),
+      std::make_tuple(DecimalScalar128{20, 35, 0}, DecimalScalar128{23298, 38, 3}, -1,
+                      false),
+      std::make_tuple(DecimalScalar128{100, 38, 0}, DecimalScalar128{122, 38, 0}, -2,
+                      false),
+      std::make_tuple(DecimalScalar128{3, 37, 0}, DecimalScalar128{25, 38, 1}, 0, false),
 
       // border cases
-      std::make_tuple(BasicDecimal128{INT64_MIN / 100},
+      std::make_tuple(DecimalScalar128{INT64_MIN / 100, 36, 0},
                       DecimalScalar128{INT64_MIN, 38, 2}, 0, false),
 
-      std::make_tuple(INT64_MIN, DecimalScalar128{INT64_MIN, 38, 0}, 0, false),
-      std::make_tuple(BasicDecimal128{0, 0}, DecimalScalar128{0, 0, 38, 2}, 0, false),
-      std::make_tuple(INT64_MAX, DecimalScalar128{INT64_MAX, 38, 0}, 0, false),
+      std::make_tuple(DecimalScalar128{INT64_MIN, 38, 0},
+                      DecimalScalar128{INT64_MIN, 38, 0}, 0, false),
+      std::make_tuple(DecimalScalar128{0, 0, 36, 0}, DecimalScalar128{0, 0, 38, 2}, 0,
+                      false),
+      std::make_tuple(DecimalScalar128{INT64_MAX, 38, 0},
+                      DecimalScalar128{INT64_MAX, 38, 0}, 0, false),
 
-      std::make_tuple(BasicDecimal128(INT64_MAX / 100),
+      std::make_tuple(DecimalScalar128{INT64_MAX / 100, 36, 0},
                       DecimalScalar128{INT64_MAX, 38, 2}, 0, false),
 
       // large scales
-      std::make_tuple(BasicDecimal128{0, 0}, DecimalScalar128{12345, 38, 16}, 0, false),
+      std::make_tuple(DecimalScalar128{0, 0, 22, 0}, DecimalScalar128{12345, 38, 16}, 0,
+                      false),
+
       std::make_tuple(
-          BasicDecimal128{124},
+          DecimalScalar128{BasicDecimal128{124}, 22, 0},
           DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(14), 38, 16}, 0, false),
       std::make_tuple(
-          BasicDecimal128{-124},
+          DecimalScalar128{BasicDecimal128{-124}, 22, 0},
           DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(14), 38, 16}, 0,
           false),
       std::make_tuple(
-          BasicDecimal128{124},
+          DecimalScalar128{BasicDecimal128{124}, 6, 0},
           DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(30), 38, 32}, 0, false),
       std::make_tuple(
-          BasicDecimal128{-124},
+          DecimalScalar128{BasicDecimal128{-124}, 6, 0},
           DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(30), 38, 32}, 0,
           false),
 
-      // overflow
+      // scale bigger than arg
       std::make_tuple(
-          BasicDecimal128{0, 0},
-          DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32}, 35, true),
+          DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32},
+          DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32}, 35,
+          false),
       std::make_tuple(
-          BasicDecimal128{0, 0},
+          DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32},
           DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32}, 35,
-          true),
+          false),
+
+      // overflow
+      std::make_tuple(DecimalScalar128{0, 0, 1, 0}, DecimalScalar128{99, 2, 1}, 0, true),
   };
 
   for (auto iter : test_values) {
@@ -609,7 +625,9 @@ TEST_F(TestDecimalSql, Round) {
     auto expected_overflow = std::get<3>(iter);
     bool overflow = false;
 
-    EXPECT_EQ(expected, decimalops::Round(input, rounding_scale, &overflow))
+    EXPECT_EQ(expected.value(),
+              decimalops::Round(input, expected.precision(), expected.scale(),
+                                rounding_scale, &overflow))
         << "  failed on input " << input << "  rounding scale " << rounding_scale;
     if (expected_overflow) {
       ASSERT_TRUE(overflow) << "overflow expected for input " << input;
@@ -621,53 +639,64 @@ TEST_F(TestDecimalSql, Round) {
 
 TEST_F(TestDecimalSql, Truncate) {
   // expected, input, rounding_scale, overflow
-  using TupleType = std::tuple<BasicDecimal128, DecimalScalar128, int32_t, bool>;
+  using TupleType = std::tuple<DecimalScalar128, DecimalScalar128, int32_t, bool>;
   std::vector<TupleType> test_values = {
       // examples from
       // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_truncate
-      std::make_tuple(BasicDecimal128{12}, DecimalScalar128{1223, 38, 3}, 1, false),
-      std::make_tuple(BasicDecimal128{19}, DecimalScalar128{1999, 38, 3}, 1, false),
-      std::make_tuple(BasicDecimal128{1}, DecimalScalar128{1999, 38, 3}, 0, false),
-      std::make_tuple(BasicDecimal128{-19}, DecimalScalar128{-1999, 38, 3}, 1, false),
-      std::make_tuple(BasicDecimal128{100}, DecimalScalar128{122, 38, 0}, -2, false),
-      std::make_tuple(BasicDecimal128{1028}, DecimalScalar128{1028, 38, 0}, 0, false),
+      std::make_tuple(DecimalScalar128{12, 36, 1}, DecimalScalar128{1223, 38, 3}, 1,
+                      false),
+      std::make_tuple(DecimalScalar128{19, 36, 1}, DecimalScalar128{1999, 38, 3}, 1,
+                      false),
+      std::make_tuple(DecimalScalar128{1, 35, 0}, DecimalScalar128{1999, 38, 3}, 0,
+                      false),
+      std::make_tuple(DecimalScalar128{-19, 36, 1}, DecimalScalar128{-1999, 38, 3}, 1,
+                      false),
+      std::make_tuple(DecimalScalar128{100, 38, 0}, DecimalScalar128{122, 38, 0}, -2,
+                      false),
+      std::make_tuple(DecimalScalar128{1028, 38, 0}, DecimalScalar128{1028, 38, 0}, 0,
+                      false),
 
       // border cases
-      std::make_tuple(BasicDecimal128{INT64_MIN / 100},
+      std::make_tuple(DecimalScalar128{BasicDecimal128{INT64_MIN / 100}, 36, 0},
                       DecimalScalar128{INT64_MIN, 38, 2}, 0, false),
 
-      std::make_tuple(INT64_MIN, DecimalScalar128{INT64_MIN, 38, 0}, 0, false),
-      std::make_tuple(BasicDecimal128{0, 0}, DecimalScalar128{0, 0, 38, 2}, 0, false),
-      std::make_tuple(INT64_MAX, DecimalScalar128{INT64_MAX, 38, 0}, 0, false),
+      std::make_tuple(DecimalScalar128{INT64_MIN, 38, 0},
+                      DecimalScalar128{INT64_MIN, 38, 0}, 0, false),
+      std::make_tuple(DecimalScalar128{0, 0, 38, 0}, DecimalScalar128{0, 0, 38, 2}, 0,
+                      false),
+      std::make_tuple(DecimalScalar128{INT64_MAX, 38, 0},
+                      DecimalScalar128{INT64_MAX, 38, 0}, 0, false),
 
-      std::make_tuple(BasicDecimal128(INT64_MAX / 100),
+      std::make_tuple(DecimalScalar128{BasicDecimal128(INT64_MAX / 100), 36, 0},
                       DecimalScalar128{INT64_MAX, 38, 2}, 0, false),
 
       // large scales
-      std::make_tuple(BasicDecimal128{0, 0}, DecimalScalar128{12345, 38, 16}, 0, false),
+      std::make_tuple(DecimalScalar128{BasicDecimal128{0, 0}, 22, 0},
+                      DecimalScalar128{12345, 38, 16}, 0, false),
       std::make_tuple(
-          BasicDecimal128{123},
+          DecimalScalar128{BasicDecimal128{123}, 22, 0},
           DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(14), 38, 16}, 0, false),
       std::make_tuple(
-          BasicDecimal128{-123},
+          DecimalScalar128{BasicDecimal128{-123}, 22, 0},
           DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(14), 38, 16}, 0,
           false),
       std::make_tuple(
-          BasicDecimal128{123},
+          DecimalScalar128{BasicDecimal128{123}, 6, 0},
           DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(30), 38, 32}, 0, false),
       std::make_tuple(
-          BasicDecimal128{-123},
+          DecimalScalar128{BasicDecimal128{-123}, 6, 0},
           DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(30), 38, 32}, 0,
           false),
 
       // overflow
       std::make_tuple(
-          BasicDecimal128{0, 0},
-          DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32}, 35, true),
+          DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32},
+          DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32}, 35,
+          false),
       std::make_tuple(
-          BasicDecimal128{0, 0},
+          DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32},
           DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32}, 35,
-          true),
+          false),
   };
 
   for (auto iter : test_values) {
@@ -677,7 +706,9 @@ TEST_F(TestDecimalSql, Truncate) {
     auto expected_overflow = std::get<3>(iter);
     bool overflow = false;
 
-    EXPECT_EQ(expected, decimalops::Truncate(input, rounding_scale, &overflow))
+    EXPECT_EQ(expected.value(),
+              decimalops::Truncate(input, expected.precision(), expected.scale(),
+                                   rounding_scale, &overflow))
         << "  failed on input " << input << "  rounding scale " << rounding_scale;
     if (expected_overflow) {
       ASSERT_TRUE(overflow) << "overflow expected for input " << input;
diff --git a/cpp/src/gandiva/precompiled/decimal_wrapper.cc b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
index 80d4832..082d583 100644
--- a/cpp/src/gandiva/precompiled/decimal_wrapper.cc
+++ b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
@@ -137,7 +137,7 @@ void round_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
   gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
 
   bool overflow = false;
-  auto out = gandiva::decimalops::Round(x, 0, &overflow);
+  auto out = gandiva::decimalops::Round(x, out_precision, 0, 0, &overflow);
   *out_high = out.high_bits();
   *out_low = out.low_bits();
 }
@@ -150,7 +150,8 @@ void round_decimal128_int32(int64_t x_high, uint64_t x_low, int32_t x_precision,
   gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
 
   bool overflow = false;
-  auto out = gandiva::decimalops::Round(x, rounding_scale, &overflow);
+  auto out =
+      gandiva::decimalops::Round(x, out_precision, out_scale, rounding_scale, &overflow);
   *out_high = out.high_bits();
   *out_low = out.low_bits();
 }
@@ -162,7 +163,7 @@ void truncate_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision,
   gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
 
   bool overflow = false;
-  auto out = gandiva::decimalops::Truncate(x, 0, &overflow);
+  auto out = gandiva::decimalops::Truncate(x, out_precision, 0, 0, &overflow);
   *out_high = out.high_bits();
   *out_low = out.low_bits();
 }
@@ -175,7 +176,8 @@ void truncate_decimal128_int32(int64_t x_high, uint64_t x_low, int32_t x_precisi
   gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale);
 
   bool overflow = false;
-  auto out = gandiva::decimalops::Truncate(x, rounding_scale, &overflow);
+  auto out = gandiva::decimalops::Truncate(x, out_precision, out_scale, rounding_scale,
+                                           &overflow);
   *out_high = out.high_bits();
   *out_low = out.low_bits();
 }
diff --git a/cpp/src/gandiva/precompiled/extended_math_ops.cc b/cpp/src/gandiva/precompiled/extended_math_ops.cc
index e6b6a6e..78a3993 100644
--- a/cpp/src/gandiva/precompiled/extended_math_ops.cc
+++ b/cpp/src/gandiva/precompiled/extended_math_ops.cc
@@ -115,8 +115,9 @@ FORCE_INLINE
 gdv_int64 truncate_int64_int32(gdv_int64 in, gdv_int32 out_scale) {
   bool overflow = false;
   arrow::BasicDecimal128 decimal = gandiva::decimalops::FromInt64(in, 38, 0, &overflow);
-  arrow::BasicDecimal128 decimal_with_outscale = gandiva::decimalops::Truncate(
-      gandiva::BasicDecimalScalar128(decimal, 38, 0), out_scale, &overflow);
+  arrow::BasicDecimal128 decimal_with_outscale =
+      gandiva::decimalops::Truncate(gandiva::BasicDecimalScalar128(decimal, 38, 0), 38,
+                                    out_scale, out_scale, &overflow);
   if (out_scale < 0) {
     out_scale = 0;
   }