You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ra...@apache.org on 2019/04/29 07:29:23 UTC

[arrow] branch master updated: ARROW-5226: [Gandiva] Add cmp functions for decimals

This is an automated email from the ASF dual-hosted git repository.

ravindra pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 6106ac9  ARROW-5226: [Gandiva] Add cmp functions for decimals
6106ac9 is described below

commit 6106ac9d6c4db0a1747e2caa79a4e4974b455e96
Author: Pindikura Ravindra <ra...@dremio.com>
AuthorDate: Mon Apr 29 12:58:49 2019 +0530

    ARROW-5226: [Gandiva] Add cmp functions for decimals
    
    Author: Pindikura Ravindra <ra...@dremio.com>
    
    Closes #4219 from pravindra/ARROW-5226 and squashes the following commits:
    
    bb1cb3dc <Pindikura Ravindra> ARROW-5226:  Add cmp functions for decimals
---
 cpp/src/gandiva/decimal_ir.cc                   | 55 ++++++++++++++++++++
 cpp/src/gandiva/decimal_ir.h                    |  3 ++
 cpp/src/gandiva/decimal_xlarge.cc               | 33 ++++++++++++
 cpp/src/gandiva/decimal_xlarge.h                |  3 ++
 cpp/src/gandiva/function_registry_arithmetic.cc |  6 +++
 cpp/src/gandiva/precompiled/decimal_ops.cc      | 40 +++++++++++++++
 cpp/src/gandiva/precompiled/decimal_ops.h       |  6 +++
 cpp/src/gandiva/precompiled/decimal_ops_test.cc | 68 +++++++++++++++++++++++++
 cpp/src/gandiva/precompiled/decimal_wrapper.cc  | 11 ++++
 cpp/src/gandiva/tests/decimal_test.cc           | 63 +++++++++++++++++++++++
 10 files changed, 288 insertions(+)

diff --git a/cpp/src/gandiva/decimal_ir.cc b/cpp/src/gandiva/decimal_ir.cc
index 47e60cf..6344332 100644
--- a/cpp/src/gandiva/decimal_ir.cc
+++ b/cpp/src/gandiva/decimal_ir.cc
@@ -460,6 +460,48 @@ Status DecimalIR::BuildDivideOrMod(const std::string& function_name,
   return Status::OK();
 }
 
+Status DecimalIR::BuildCompare(const std::string& function_name,
+                               llvm::ICmpInst::Predicate cmp_instruction) {
+  // Create fn prototype :
+  // bool
+  // function_name(int128_t x_value, int32_t x_precision, int32_t x_scale,
+  //               int128_t y_value, int32_t y_precision, int32_t y_scale)
+
+  auto i32 = types()->i32_type();
+  auto i128 = types()->i128_type();
+  auto function = BuildFunction(function_name, types()->i1_type(),
+                                {
+                                    {"x_value", i128},
+                                    {"x_precision", i32},
+                                    {"x_scale", i32},
+                                    {"y_value", i128},
+                                    {"y_precision", i32},
+                                    {"y_scale", i32},
+                                });
+
+  auto arg_iter = function->arg_begin();
+  ValueFull x(&arg_iter[0], &arg_iter[1], &arg_iter[2]);
+  ValueFull y(&arg_iter[3], &arg_iter[4], &arg_iter[5]);
+
+  auto entry = llvm::BasicBlock::Create(*context(), "entry", function);
+  ir_builder()->SetInsertPoint(entry);
+
+  // Make call to pre-compiled IR function.
+  auto x_split = ValueSplit::MakeFromInt128(this, x.value());
+  auto y_split = ValueSplit::MakeFromInt128(this, y.value());
+
+  std::vector<llvm::Value*> args = {
+      x_split.high(), x_split.low(), x.precision(), x.scale(),
+      y_split.high(), y_split.low(), y.precision(), y.scale(),
+  };
+  auto cmp_value = ir_builder()->CreateCall(
+      module()->getFunction("compare_internal_decimal128_decimal128"), args);
+  auto result =
+      ir_builder()->CreateICmp(cmp_instruction, cmp_value, types()->i32_constant(0));
+  ir_builder()->CreateRet(result);
+  return Status::OK();
+}
+
 Status DecimalIR::AddFunctions(Engine* engine) {
   auto decimal_ir = std::make_shared<DecimalIR>(engine);
 
@@ -476,6 +518,19 @@ Status DecimalIR::AddFunctions(Engine* engine) {
       "divide_decimal128_decimal128", "divide_internal_decimal128_decimal128"));
   ARROW_RETURN_NOT_OK(decimal_ir->BuildDivideOrMod("mod_decimal128_decimal128",
                                                    "mod_internal_decimal128_decimal128"));
+
+  ARROW_RETURN_NOT_OK(
+      decimal_ir->BuildCompare("equal_decimal128_decimal128", llvm::ICmpInst::ICMP_EQ));
+  ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare("not_equal_decimal128_decimal128",
+                                               llvm::ICmpInst::ICMP_NE));
+  ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare("less_than_decimal128_decimal128",
+                                               llvm::ICmpInst::ICMP_SLT));
+  ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(
+      "less_than_or_equal_to_decimal128_decimal128", llvm::ICmpInst::ICMP_SLE));
+  ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare("greater_than_decimal128_decimal128",
+                                               llvm::ICmpInst::ICMP_SGT));
+  ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(
+      "greater_than_or_equal_to_decimal128_decimal128", llvm::ICmpInst::ICMP_SGE));
   return Status::OK();
 }
 
diff --git a/cpp/src/gandiva/decimal_ir.h b/cpp/src/gandiva/decimal_ir.h
index 048b9d3..b1bf38d 100644
--- a/cpp/src/gandiva/decimal_ir.h
+++ b/cpp/src/gandiva/decimal_ir.h
@@ -153,6 +153,9 @@ class DecimalIR : public FunctionIRBuilder {
   Status BuildDivideOrMod(const std::string& function_name,
                           const std::string& internal_name);
 
+  Status BuildCompare(const std::string& function_name,
+                      llvm::ICmpInst::Predicate cmp_instruction);
+
   // Add a trace in IR code.
   void AddTrace(const std::string& fmt, std::vector<llvm::Value*> args);
 
diff --git a/cpp/src/gandiva/decimal_xlarge.cc b/cpp/src/gandiva/decimal_xlarge.cc
index 4a8f3e5..60917ed 100644
--- a/cpp/src/gandiva/decimal_xlarge.cc
+++ b/cpp/src/gandiva/decimal_xlarge.cc
@@ -82,6 +82,17 @@ void ExportedDecimalFunctions::AddMappings(Engine* engine) const {
 
   engine->AddGlobalMappingForFunc("gdv_xlarge_mod", types->void_type() /*return_type*/,
                                   args, reinterpret_cast<void*>(gdv_xlarge_mod));
+
+  // gdv_xlarge_compare : returns the comparison result directly as an int32.
+  args = {types->i64_type(),   // int64_t x_high
+          types->i64_type(),   // uint64_t x_low
+          types->i32_type(),   // int32_t x_scale
+          types->i64_type(),   // int64_t y_high
+          types->i64_type(),   // uint64_t y_low
+          types->i32_type()};  // int32_t y_scale
+
+  engine->AddGlobalMappingForFunc("gdv_xlarge_compare", types->i32_type() /*return_type*/,
+                                  args, reinterpret_cast<void*>(gdv_xlarge_compare));
 }
 
 }  // namespace gandiva
@@ -248,4 +259,26 @@ void gdv_xlarge_mod(int64_t x_high, uint64_t x_low, int32_t x_scale, int64_t y_h
   *out_low = result.low_bits();
 }
 
+int32_t gdv_xlarge_compare(int64_t x_high, uint64_t x_low, int32_t x_scale,
+                           int64_t y_high, uint64_t y_low, int32_t y_scale) {
+  BasicDecimal128 x{x_high, x_low};
+  BasicDecimal128 y{y_high, y_low};
+
+  int256_t x_large = gandiva::internal::ConvertToInt256(x);
+  int256_t y_large = gandiva::internal::ConvertToInt256(y);
+  if (x_scale < y_scale) {
+    x_large = gandiva::internal::IncreaseScaleBy(x_large, y_scale - x_scale);
+  } else {
+    y_large = gandiva::internal::IncreaseScaleBy(y_large, x_scale - y_scale);
+  }
+
+  if (x_large == y_large) {
+    return 0;
+  } else if (x_large < y_large) {
+    return -1;
+  } else {
+    return 1;
+  }
+}
+
 }  // extern "C"
diff --git a/cpp/src/gandiva/decimal_xlarge.h b/cpp/src/gandiva/decimal_xlarge.h
index c2e2dd8..2643297 100644
--- a/cpp/src/gandiva/decimal_xlarge.h
+++ b/cpp/src/gandiva/decimal_xlarge.h
@@ -35,4 +35,7 @@ void gdv_xlarge_scale_up_and_divide(int64_t x_high, uint64_t x_low, int64_t y_hi
 void gdv_xlarge_mod(int64_t x_high, uint64_t x_low, int32_t x_scale, int64_t y_high,
                     uint64_t y_low, int32_t y_scale, int64_t* out_high,
                     uint64_t* out_low);
+
+int32_t gdv_xlarge_compare(int64_t x_high, uint64_t x_low, int32_t x_scale,
+                           int64_t y_high, uint64_t y_low, int32_t y_scale);
 }
diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc
index ad8445b..04e9113 100644
--- a/cpp/src/gandiva/function_registry_arithmetic.cc
+++ b/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -59,6 +59,12 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
       BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(multiply, decimal128),
       BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(divide, decimal128),
       BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, decimal128),
+      BINARY_RELATIONAL_SAFE_NULL_IF_NULL(equal, decimal128),
+      BINARY_RELATIONAL_SAFE_NULL_IF_NULL(not_equal, decimal128),
+      BINARY_RELATIONAL_SAFE_NULL_IF_NULL(less_than, decimal128),
+      BINARY_RELATIONAL_SAFE_NULL_IF_NULL(less_than_or_equal_to, decimal128),
+      BINARY_RELATIONAL_SAFE_NULL_IF_NULL(greater_than, decimal128),
+      BINARY_RELATIONAL_SAFE_NULL_IF_NULL(greater_than_or_equal_to, decimal128),
 
       BINARY_RELATIONAL_BOOL_FN(equal),
       BINARY_RELATIONAL_BOOL_FN(not_equal),
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.cc b/cpp/src/gandiva/precompiled/decimal_ops.cc
index e13a5d8..9bf643f 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops.cc
@@ -423,5 +423,45 @@ BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
   return result;
 }
 
+int32_t CompareSameScale(const BasicDecimal128& x, const BasicDecimal128& y) {
+  if (x == y) {
+    return 0;
+  } else if (x < y) {
+    return -1;
+  } else {
+    return 1;
+  }
+}
+
+int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y) {
+  int32_t delta_scale = x.scale() - y.scale();
+
+  // Fast path: both operands have the same scale, so compare the raw values directly.
+  if (delta_scale == 0) {
+    return CompareSameScale(x.value(), y.value());
+  }
+
+  // If rescaling would exceed the max decimal128 precision, compare using 256-bit ints.
+  bool need256 =
+      (delta_scale < 0 && x.precision() - delta_scale > DecimalTypeUtil::kMaxPrecision) ||
+      (y.precision() + delta_scale > DecimalTypeUtil::kMaxPrecision);
+  if (need256) {
+    return gdv_xlarge_compare(x.value().high_bits(), x.value().low_bits(), x.scale(),
+                              y.value().high_bits(), y.value().low_bits(), y.scale());
+  } else {
+    BasicDecimal128 x_scaled;
+    BasicDecimal128 y_scaled;
+
+    if (delta_scale < 0) {
+      x_scaled = x.value().IncreaseScaleBy(-delta_scale);
+      y_scaled = y.value();
+    } else {
+      x_scaled = x.value();
+      y_scaled = y.value().IncreaseScaleBy(delta_scale);
+    }
+    return CompareSameScale(x_scaled, y_scaled);
+  }
+}
+
 }  // namespace decimalops
 }  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.h b/cpp/src/gandiva/precompiled/decimal_ops.h
index e0aea7e..19417a2 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.h
+++ b/cpp/src/gandiva/precompiled/decimal_ops.h
@@ -50,5 +50,11 @@ arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
                            const BasicDecimalScalar128& y, int32_t out_precision,
                            int32_t out_scale, bool* overflow);
 
+/// Compare two decimals. Returns :
+///  0 if x == y
+///  1 if x > y
+/// -1 if x < y
+int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y);
+
 }  // namespace decimalops
 }  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_ops_test.cc b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
index 05289e2..f3b22ff 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
@@ -479,4 +479,72 @@ TEST_F(TestDecimalSql, DivideByZero) {
   EXPECT_FALSE(context.has_error());
 }
 
+TEST_F(TestDecimalSql, Compare) {
+  // x.scale == y.scale
+  EXPECT_EQ(
+      0, decimalops::Compare(DecimalScalar128{100, 38, 6}, DecimalScalar128{100, 38, 6}));
+  EXPECT_EQ(
+      1, decimalops::Compare(DecimalScalar128{200, 38, 6}, DecimalScalar128{100, 38, 6}));
+  EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{100, 38, 6},
+                                    DecimalScalar128{200, 38, 6}));
+
+  // x.scale == y.scale, with -ve.
+  EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-100, 38, 6},
+                                   DecimalScalar128{-100, 38, 6}));
+  EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-200, 38, 6},
+                                    DecimalScalar128{-100, 38, 6}));
+  EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-100, 38, 6},
+                                   DecimalScalar128{-200, 38, 6}));
+  EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{100, 38, 6},
+                                   DecimalScalar128{-200, 38, 6}));
+
+  for (int32_t precision : {16, 36, 38}) {
+    // x_scale > y_scale
+    EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{10000, precision, 6},
+                                     DecimalScalar128{100, precision, 4}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{20000, precision, 6},
+                                     DecimalScalar128{100, precision, 4}));
+    EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{10000, precision, 6},
+                                      DecimalScalar128{200, precision, 4}));
+
+    // x.scale > y.scale, with -ve
+    EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-10000, precision, 6},
+                                     DecimalScalar128{-100, precision, 4}));
+    EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-20000, precision, 6},
+                                      DecimalScalar128{-100, precision, 4}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-10000, precision, 6},
+                                     DecimalScalar128{-200, precision, 4}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{10000, precision, 6},
+                                     DecimalScalar128{-200, precision, 4}));
+
+    // x.scale < y.scale
+    EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{100, precision, 4},
+                                     DecimalScalar128{10000, precision, 6}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{200, precision, 4},
+                                     DecimalScalar128{10000, precision, 6}));
+    EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{100, precision, 4},
+                                      DecimalScalar128{20000, precision, 6}));
+
+    // x.scale < y.scale, with -ve
+    EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-100, precision, 4},
+                                     DecimalScalar128{-10000, precision, 6}));
+    EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-200, precision, 4},
+                                      DecimalScalar128{-10000, precision, 6}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-100, precision, 4},
+                                     DecimalScalar128{-20000, precision, 6}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{100, precision, 4},
+                                     DecimalScalar128{-200, precision, 6}));
+  }
+
+  // large cases.
+  EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6},
+                                   DecimalScalar128{kThirtyEight9s, 38, 6}));
+
+  EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6},
+                                   DecimalScalar128{kThirtySix9s, 38, 4}));
+
+  EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6},
+                                    DecimalScalar128{kThirtyEight9s, 38, 4}));
+}
+
 }  // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_wrapper.cc b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
index d5c919e..69a3b70 100644
--- a/cpp/src/gandiva/precompiled/decimal_wrapper.cc
+++ b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
@@ -86,4 +86,15 @@ void mod_internal_decimal128_decimal128(int64_t context, int64_t x_high, uint64_
   *out_low = out.low_bits();
 }
 
+FORCE_INLINE
+int32_t compare_internal_decimal128_decimal128(int64_t x_high, uint64_t x_low,
+                                               int32_t x_precision, int32_t x_scale,
+                                               int64_t y_high, uint64_t y_low,
+                                               int32_t y_precision, int32_t y_scale) {
+  gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale);
+  gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale);
+
+  return gandiva::decimalops::Compare(x, y);
+}
+
 }  // extern "C"
diff --git a/cpp/src/gandiva/tests/decimal_test.cc b/cpp/src/gandiva/tests/decimal_test.cc
index da93b0e..08435e4 100644
--- a/cpp/src/gandiva/tests/decimal_test.cc
+++ b/cpp/src/gandiva/tests/decimal_test.cc
@@ -27,6 +27,7 @@
 #include "gandiva/tests/test_util.h"
 #include "gandiva/tree_expr_builder.h"
 
+using arrow::boolean;
 using arrow::Decimal128;
 
 namespace gandiva {
@@ -234,4 +235,66 @@ TEST_F(TestDecimal, TestIfElse) {
   EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
 }
 
+TEST_F(TestDecimal, TestCompare) {
+  // schema for input fields
+  constexpr int32_t precision = 36;
+  constexpr int32_t scale = 18;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+  auto field_a = field("a", decimal_type);
+  auto field_b = field("b", decimal_type);
+  auto schema = arrow::schema({field_a, field_b});
+
+  // build expressions
+  auto exprs = std::vector<ExpressionPtr>{
+      TreeExprBuilder::MakeExpression("equal", {field_a, field_b},
+                                      field("res_eq", boolean())),
+      TreeExprBuilder::MakeExpression("not_equal", {field_a, field_b},
+                                      field("res_ne", boolean())),
+      TreeExprBuilder::MakeExpression("less_than", {field_a, field_b},
+                                      field("res_lt", boolean())),
+      TreeExprBuilder::MakeExpression("less_than_or_equal_to", {field_a, field_b},
+                                      field("res_le", boolean())),
+      TreeExprBuilder::MakeExpression("greater_than", {field_a, field_b},
+                                      field("res_gt", boolean())),
+      TreeExprBuilder::MakeExpression("greater_than_or_equal_to", {field_a, field_b},
+                                      field("res_ge", boolean())),
+  };
+
+  // Build a projector for the expression.
+  std::shared_ptr<Projector> projector;
+  auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+  DCHECK_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 4;
+  auto array_a =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "-4"}, scale),
+                            {true, true, true, true});
+  auto array_b =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "3", "2", "-3"}, scale),
+                            {true, true, true, true});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  DCHECK_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, false, false, false}),
+                            outputs[0]);  // equal
+  EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, true, true, true}),
+                            outputs[1]);  // not_equal
+  EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, true, false, true}),
+                            outputs[2]);  // less_than
+  EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, true, false, true}),
+                            outputs[3]);  // less_than_or_equal_to
+  EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, false, true, false}),
+                            outputs[4]);  // greater_than
+  EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, false, true, false}),
+                            outputs[5]);  // greater_than_or_equal_to
+}
+
 }  // namespace gandiva