You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ra...@apache.org on 2019/04/29 07:29:23 UTC
[arrow] branch master updated: ARROW-5226: [Gandiva] Add cmp
functions for decimals
This is an automated email from the ASF dual-hosted git repository.
ravindra pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 6106ac9 ARROW-5226: [Gandiva] Add cmp functions for decimals
6106ac9 is described below
commit 6106ac9d6c4db0a1747e2caa79a4e4974b455e96
Author: Pindikura Ravindra <ra...@dremio.com>
AuthorDate: Mon Apr 29 12:58:49 2019 +0530
ARROW-5226: [Gandiva] Add cmp functions for decimals
Author: Pindikura Ravindra <ra...@dremio.com>
Closes #4219 from pravindra/ARROW-5226 and squashes the following commits:
bb1cb3dc <Pindikura Ravindra> ARROW-5226: Add cmp functions for decimals
---
cpp/src/gandiva/decimal_ir.cc | 55 ++++++++++++++++++++
cpp/src/gandiva/decimal_ir.h | 3 ++
cpp/src/gandiva/decimal_xlarge.cc | 33 ++++++++++++
cpp/src/gandiva/decimal_xlarge.h | 3 ++
cpp/src/gandiva/function_registry_arithmetic.cc | 6 +++
cpp/src/gandiva/precompiled/decimal_ops.cc | 40 +++++++++++++++
cpp/src/gandiva/precompiled/decimal_ops.h | 6 +++
cpp/src/gandiva/precompiled/decimal_ops_test.cc | 68 +++++++++++++++++++++++++
cpp/src/gandiva/precompiled/decimal_wrapper.cc | 11 ++++
cpp/src/gandiva/tests/decimal_test.cc | 63 +++++++++++++++++++++++
10 files changed, 288 insertions(+)
diff --git a/cpp/src/gandiva/decimal_ir.cc b/cpp/src/gandiva/decimal_ir.cc
index 47e60cf..6344332 100644
--- a/cpp/src/gandiva/decimal_ir.cc
+++ b/cpp/src/gandiva/decimal_ir.cc
@@ -460,6 +460,48 @@ Status DecimalIR::BuildDivideOrMod(const std::string& function_name,
return Status::OK();
}
+// Builds an IR function named `function_name` that compares two decimal128
+// operands (each passed as value/precision/scale) and returns an i1.
+// The actual three-way comparison is delegated to the pre-compiled
+// compare_internal_decimal128_decimal128(); its int32 result is then tested
+// against 0 with `cmp_instruction` (e.g. ICMP_SLT yields "x < y").
+Status DecimalIR::BuildCompare(const std::string& function_name,
+ llvm::ICmpInst::Predicate cmp_instruction) {
+ // Create fn prototype :
+ // bool
+ // function_name(int128_t x_value, int32_t x_precision, int32_t x_scale,
+ // int128_t y_value, int32_t y_precision, int32_t y_scale)
+
+ auto i32 = types()->i32_type();
+ auto i128 = types()->i128_type();
+ auto function = BuildFunction(function_name, types()->i1_type(),
+ {
+ {"x_value", i128},
+ {"x_precision", i32},
+ {"x_scale", i32},
+ {"y_value", i128},
+ {"y_precision", i32},
+ {"y_scale", i32},
+ });
+
+ // Regroup the six flat arguments into the two decimal operands.
+ auto arg_iter = function->arg_begin();
+ ValueFull x(&arg_iter[0], &arg_iter[1], &arg_iter[2]);
+ ValueFull y(&arg_iter[3], &arg_iter[4], &arg_iter[5]);
+
+ auto entry = llvm::BasicBlock::Create(*context(), "entry", function);
+ ir_builder()->SetInsertPoint(entry);
+
+ // Make call to pre-compiled IR function.
+ // The helper takes each int128 value split into (high, low) 64-bit halves.
+ auto x_split = ValueSplit::MakeFromInt128(this, x.value());
+ auto y_split = ValueSplit::MakeFromInt128(this, y.value());
+
+ std::vector<llvm::Value*> args = {
+ x_split.high(), x_split.low(), x.precision(), x.scale(),
+ y_split.high(), y_split.low(), y.precision(), y.scale(),
+ };
+ // NOTE(review): Module::getFunction returns null if the pre-compiled symbol is
+ // missing from the module; no check is done here.
+ auto cmp_value = ir_builder()->CreateCall(
+ module()->getFunction("compare_internal_decimal128_decimal128"), args);
+ // Map the three-way result (<0, 0, >0) to a boolean via the given predicate.
+ auto result =
+ ir_builder()->CreateICmp(cmp_instruction, cmp_value, types()->i32_constant(0));
+ ir_builder()->CreateRet(result);
+ return Status::OK();
+}
+
Status DecimalIR::AddFunctions(Engine* engine) {
auto decimal_ir = std::make_shared<DecimalIR>(engine);
@@ -476,6 +518,19 @@ Status DecimalIR::AddFunctions(Engine* engine) {
"divide_decimal128_decimal128", "divide_internal_decimal128_decimal128"));
ARROW_RETURN_NOT_OK(decimal_ir->BuildDivideOrMod("mod_decimal128_decimal128",
"mod_internal_decimal128_decimal128"));
+
+ ARROW_RETURN_NOT_OK(
+ decimal_ir->BuildCompare("equal_decimal128_decimal128", llvm::ICmpInst::ICMP_EQ));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare("not_equal_decimal128_decimal128",
+ llvm::ICmpInst::ICMP_NE));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare("less_than_decimal128_decimal128",
+ llvm::ICmpInst::ICMP_SLT));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(
+ "less_than_or_equal_to_decimal128_decimal128", llvm::ICmpInst::ICMP_SLE));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare("greater_than_decimal128_decimal128",
+ llvm::ICmpInst::ICMP_SGT));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(
+ "greater_than_or_equal_to_decimal128_decimal128", llvm::ICmpInst::ICMP_SGE));
return Status::OK();
}
diff --git a/cpp/src/gandiva/decimal_ir.h b/cpp/src/gandiva/decimal_ir.h
index 048b9d3..b1bf38d 100644
--- a/cpp/src/gandiva/decimal_ir.h
+++ b/cpp/src/gandiva/decimal_ir.h
@@ -153,6 +153,9 @@ class DecimalIR : public FunctionIRBuilder {
Status BuildDivideOrMod(const std::string& function_name,
const std::string& internal_name);
+ // Build the decimal comparison function `function_name`, which returns
+ // (three-way compare of its two decimal args) <cmp_instruction> 0.
+ Status BuildCompare(const std::string& function_name,
+ llvm::ICmpInst::Predicate cmp_instruction);
+
// Add a trace in IR code.
void AddTrace(const std::string& fmt, std::vector<llvm::Value*> args);
diff --git a/cpp/src/gandiva/decimal_xlarge.cc b/cpp/src/gandiva/decimal_xlarge.cc
index 4a8f3e5..60917ed 100644
--- a/cpp/src/gandiva/decimal_xlarge.cc
+++ b/cpp/src/gandiva/decimal_xlarge.cc
@@ -82,6 +82,17 @@ void ExportedDecimalFunctions::AddMappings(Engine* engine) const {
engine->AddGlobalMappingForFunc("gdv_xlarge_mod", types->void_type() /*return_type*/,
args, reinterpret_cast<void*>(gdv_xlarge_mod));
+
+  // gdv_xlarge_compare : 256-bit fallback used by decimalops::Compare when
+  // rescaling an operand would not fit in a 128-bit decimal.
+  args = {types->i64_type(),   // int64_t x_high
+          types->i64_type(),   // uint64_t x_low
+          types->i32_type(),   // int32_t x_scale
+          types->i64_type(),   // int64_t y_high
+          types->i64_type(),   // uint64_t y_low
+          types->i32_type()};  // int32_t y_scale
+
+  // The mapped address must be gdv_xlarge_compare itself. Binding the symbol to
+  // gdv_xlarge_mod (which returns void and writes through two extra out-pointer
+  // args) would make JIT-ed code call the wrong routine.
+  engine->AddGlobalMappingForFunc("gdv_xlarge_compare", types->i32_type() /*return_type*/,
+                                  args, reinterpret_cast<void*>(gdv_xlarge_compare));
}
} // namespace gandiva
@@ -248,4 +259,26 @@ void gdv_xlarge_mod(int64_t x_high, uint64_t x_low, int32_t x_scale, int64_t y_h
*out_low = result.low_bits();
}
+// Three-way comparison of two decimal128 values whose scales may differ.
+// Both operands are widened to int256_t, and the one with the smaller scale is
+// scaled up to the larger scale before comparing.
+// Returns -1 if x < y, 0 if x == y and 1 if x > y.
+int32_t gdv_xlarge_compare(int64_t x_high, uint64_t x_low, int32_t x_scale,
+                           int64_t y_high, uint64_t y_low, int32_t y_scale) {
+  BasicDecimal128 lhs_narrow{x_high, x_low};
+  BasicDecimal128 rhs_narrow{y_high, y_low};
+  int256_t lhs = gandiva::internal::ConvertToInt256(lhs_narrow);
+  int256_t rhs = gandiva::internal::ConvertToInt256(rhs_narrow);
+
+  // Align both values to the larger of the two scales.
+  if (x_scale >= y_scale) {
+    rhs = gandiva::internal::IncreaseScaleBy(rhs, x_scale - y_scale);
+  } else {
+    lhs = gandiva::internal::IncreaseScaleBy(lhs, y_scale - x_scale);
+  }
+
+  if (lhs < rhs) {
+    return -1;
+  }
+  return lhs == rhs ? 0 : 1;
+}
+
} // extern "C"
diff --git a/cpp/src/gandiva/decimal_xlarge.h b/cpp/src/gandiva/decimal_xlarge.h
index c2e2dd8..2643297 100644
--- a/cpp/src/gandiva/decimal_xlarge.h
+++ b/cpp/src/gandiva/decimal_xlarge.h
@@ -35,4 +35,7 @@ void gdv_xlarge_scale_up_and_divide(int64_t x_high, uint64_t x_low, int64_t y_hi
void gdv_xlarge_mod(int64_t x_high, uint64_t x_low, int32_t x_scale, int64_t y_high,
uint64_t y_low, int32_t y_scale, int64_t* out_high,
uint64_t* out_low);
+
+// Compares two decimal128 values, each given as (high, low) 64-bit halves plus a
+// scale, after widening to 256 bits and aligning scales. Returns -1, 0 or 1.
+int32_t gdv_xlarge_compare(int64_t x_high, uint64_t x_low, int32_t x_scale,
+ int64_t y_high, uint64_t y_low, int32_t y_scale);
}
diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc
index ad8445b..04e9113 100644
--- a/cpp/src/gandiva/function_registry_arithmetic.cc
+++ b/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -59,6 +59,12 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(multiply, decimal128),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(divide, decimal128),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, decimal128),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL(equal, decimal128),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL(not_equal, decimal128),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL(less_than, decimal128),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL(less_than_or_equal_to, decimal128),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL(greater_than, decimal128),
+ BINARY_RELATIONAL_SAFE_NULL_IF_NULL(greater_than_or_equal_to, decimal128),
BINARY_RELATIONAL_BOOL_FN(equal),
BINARY_RELATIONAL_BOOL_FN(not_equal),
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.cc b/cpp/src/gandiva/precompiled/decimal_ops.cc
index e13a5d8..9bf643f 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops.cc
@@ -423,5 +423,45 @@ BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
return result;
}
+// Three-way comparison of two unscaled values that already share a scale.
+// Returns -1 if x < y, 0 if x == y and 1 if x > y.
+int32_t CompareSameScale(const BasicDecimal128& x, const BasicDecimal128& y) {
+  if (x < y) {
+    return -1;
+  }
+  return (x == y) ? 0 : 1;
+}
+
+// Three-way comparison of two decimals by numeric value; the scales may differ.
+// Returns 0 if x == y, 1 if x > y and -1 if x < y.
+int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y) {
+  int32_t delta_scale = x.scale() - y.scale();
+
+  // fast-path : both are of the same scale.
+  if (delta_scale == 0) {
+    return CompareSameScale(x.value(), y.value());
+  }
+
+  // Only the operand with the smaller scale is scaled up, by |delta_scale|
+  // digits. If its precision would then exceed kMaxPrecision, the rescaled value
+  // may not be representable in 128 bits, so compare in 256 bits instead.
+  bool need256 =
+      (delta_scale < 0 && x.precision() - delta_scale > DecimalTypeUtil::kMaxPrecision) ||
+      (delta_scale > 0 && y.precision() + delta_scale > DecimalTypeUtil::kMaxPrecision);
+  if (need256) {
+    return gdv_xlarge_compare(x.value().high_bits(), x.value().low_bits(), x.scale(),
+                              y.value().high_bits(), y.value().low_bits(), y.scale());
+  }
+
+  // Rescaling fits in 128 bits : align both values to the larger scale.
+  BasicDecimal128 x_scaled = x.value();
+  BasicDecimal128 y_scaled = y.value();
+  if (delta_scale < 0) {
+    x_scaled = x.value().IncreaseScaleBy(-delta_scale);
+  } else {
+    y_scaled = y.value().IncreaseScaleBy(delta_scale);
+  }
+  return CompareSameScale(x_scaled, y_scaled);
+}
+
} // namespace decimalops
} // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.h b/cpp/src/gandiva/precompiled/decimal_ops.h
index e0aea7e..19417a2 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.h
+++ b/cpp/src/gandiva/precompiled/decimal_ops.h
@@ -50,5 +50,11 @@ arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
int32_t out_scale, bool* overflow);
+/// Compare two decimals by numeric value. The scales of x and y may differ :
+/// the operand with the smaller scale is scaled up before comparing (using
+/// 256-bit arithmetic when the rescaled value would exceed the max precision).
+///
+/// Returns :
+/// 0 if x == y
+/// 1 if x > y
+/// -1 if x < y
+int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y);
+
} // namespace decimalops
} // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_ops_test.cc b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
index 05289e2..f3b22ff 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
@@ -479,4 +479,72 @@ TEST_F(TestDecimalSql, DivideByZero) {
EXPECT_FALSE(context.has_error());
}
+TEST_F(TestDecimalSql, Compare) {
+  // x.scale == y.scale
+  EXPECT_EQ(
+      0, decimalops::Compare(DecimalScalar128{100, 38, 6}, DecimalScalar128{100, 38, 6}));
+  EXPECT_EQ(
+      1, decimalops::Compare(DecimalScalar128{200, 38, 6}, DecimalScalar128{100, 38, 6}));
+  EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{100, 38, 6},
+                                    DecimalScalar128{200, 38, 6}));
+
+  // x.scale == y.scale, with -ve.
+  EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-100, 38, 6},
+                                   DecimalScalar128{-100, 38, 6}));
+  EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-200, 38, 6},
+                                    DecimalScalar128{-100, 38, 6}));
+  EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-100, 38, 6},
+                                   DecimalScalar128{-200, 38, 6}));
+  EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{100, 38, 6},
+                                   DecimalScalar128{-200, 38, 6}));
+
+  // Different scales. For precision 38, scaling the lower-scale operand up by 2
+  // digits goes past the max precision (assuming kMaxPrecision is 38), so those
+  // iterations exercise the 256-bit path.
+  for (int32_t precision : {16, 36, 38}) {
+    // x_scale > y_scale
+    EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{10000, precision, 6},
+                                     DecimalScalar128{100, precision, 4}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{20000, precision, 6},
+                                     DecimalScalar128{100, precision, 4}));
+    EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{10000, precision, 6},
+                                      DecimalScalar128{200, precision, 4}));
+
+    // x.scale > y.scale, with -ve
+    EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-10000, precision, 6},
+                                     DecimalScalar128{-100, precision, 4}));
+    EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-20000, precision, 6},
+                                      DecimalScalar128{-100, precision, 4}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-10000, precision, 6},
+                                     DecimalScalar128{-200, precision, 4}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{10000, precision, 6},
+                                     DecimalScalar128{-200, precision, 4}));
+
+    // x.scale < y.scale
+    EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{100, precision, 4},
+                                     DecimalScalar128{10000, precision, 6}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{200, precision, 4},
+                                     DecimalScalar128{10000, precision, 6}));
+    EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{100, precision, 4},
+                                      DecimalScalar128{20000, precision, 6}));
+
+    // x.scale < y.scale, with -ve
+    EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-100, precision, 4},
+                                     DecimalScalar128{-10000, precision, 6}));
+    EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-200, precision, 4},
+                                      DecimalScalar128{-10000, precision, 6}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-100, precision, 4},
+                                     DecimalScalar128{-20000, precision, 6}));
+    EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{100, precision, 4},
+                                     DecimalScalar128{-200, precision, 6}));
+  }
+
+  // large cases, same scale (fast path).
+  EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6},
+                                   DecimalScalar128{kThirtyEight9s, 38, 6}));
+
+  // large cases, x.scale > y.scale : y is scaled up past the max precision.
+  EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6},
+                                   DecimalScalar128{kThirtySix9s, 38, 4}));
+  EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6},
+                                    DecimalScalar128{kThirtyEight9s, 38, 4}));
+
+  // large cases, x.scale < y.scale : x is scaled up past the max precision.
+  EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{kThirtySix9s, 38, 4},
+                                    DecimalScalar128{kThirtyEight9s, 38, 6}));
+  EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 4},
+                                   DecimalScalar128{kThirtyEight9s, 38, 6}));
+}
+
} // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_wrapper.cc b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
index d5c919e..69a3b70 100644
--- a/cpp/src/gandiva/precompiled/decimal_wrapper.cc
+++ b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
@@ -86,4 +86,15 @@ void mod_internal_decimal128_decimal128(int64_t context, int64_t x_high, uint64_
*out_low = out.low_bits();
}
+FORCE_INLINE
+int32_t compare_internal_decimal128_decimal128(int64_t x_high, uint64_t x_low,
+                                               int32_t x_precision, int32_t x_scale,
+                                               int64_t y_high, uint64_t y_low,
+                                               int32_t y_precision, int32_t y_scale) {
+  // Reassemble each operand from its (high, low, precision, scale) parts and
+  // delegate the three-way comparison to decimalops::Compare.
+  return gandiva::decimalops::Compare(
+      gandiva::BasicDecimalScalar128(x_high, x_low, x_precision, x_scale),
+      gandiva::BasicDecimalScalar128(y_high, y_low, y_precision, y_scale));
+}
+
} // extern "C"
diff --git a/cpp/src/gandiva/tests/decimal_test.cc b/cpp/src/gandiva/tests/decimal_test.cc
index da93b0e..08435e4 100644
--- a/cpp/src/gandiva/tests/decimal_test.cc
+++ b/cpp/src/gandiva/tests/decimal_test.cc
@@ -27,6 +27,7 @@
#include "gandiva/tests/test_util.h"
#include "gandiva/tree_expr_builder.h"
+using arrow::boolean;
using arrow::Decimal128;
namespace gandiva {
@@ -234,4 +235,66 @@ TEST_F(TestDecimal, TestIfElse) {
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
+TEST_F(TestDecimal, TestCompare) {
+  // Both inputs are decimal(36, 18), so every comparison sees a common scale.
+  constexpr int32_t precision = 36;
+  constexpr int32_t scale = 18;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+  auto field_a = field("a", decimal_type);
+  auto field_b = field("b", decimal_type);
+  auto schema = arrow::schema({field_a, field_b});
+
+  // One expression per relational operator; output i corresponds to exprs[i].
+  std::vector<ExpressionPtr> exprs;
+  exprs.push_back(TreeExprBuilder::MakeExpression("equal", {field_a, field_b},
+                                                  field("res_eq", boolean())));
+  exprs.push_back(TreeExprBuilder::MakeExpression("not_equal", {field_a, field_b},
+                                                  field("res_ne", boolean())));
+  exprs.push_back(TreeExprBuilder::MakeExpression("less_than", {field_a, field_b},
+                                                  field("res_lt", boolean())));
+  exprs.push_back(TreeExprBuilder::MakeExpression(
+      "less_than_or_equal_to", {field_a, field_b}, field("res_le", boolean())));
+  exprs.push_back(TreeExprBuilder::MakeExpression("greater_than", {field_a, field_b},
+                                                  field("res_gt", boolean())));
+  exprs.push_back(TreeExprBuilder::MakeExpression(
+      "greater_than_or_equal_to", {field_a, field_b}, field("res_ge", boolean())));
+
+  std::shared_ptr<Projector> projector;
+  auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector);
+  DCHECK_OK(status);
+
+  // Rows cover a == b, a < b, a > b and a negative pair (-4 < -3).
+  int num_records = 4;
+  auto array_a =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "-4"}, scale),
+                            {true, true, true, true});
+  auto array_b =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "3", "2", "-3"}, scale),
+                            {true, true, true, true});
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
+
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  DCHECK_OK(status);
+
+  // Expected value of each output column, in the same order as `exprs`.
+  const std::vector<std::vector<bool>> expected = {
+      {true, false, false, false},  // equal
+      {false, true, true, true},    // not_equal
+      {false, true, false, true},   // less_than
+      {true, true, false, true},    // less_than_or_equal_to
+      {false, false, true, false},  // greater_than
+      {true, false, true, false},   // greater_than_or_equal_to
+  };
+  for (size_t i = 0; i < expected.size(); ++i) {
+    EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool(expected[i]), outputs[i]);
+  }
+}
+
} // namespace gandiva