You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/01/08 15:32:50 UTC
[arrow] branch master updated: ARROW-3701: [Gandiva] add op for
decimal 128
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new d6ddcbf ARROW-3701: [Gandiva] add op for decimal 128
d6ddcbf is described below
commit d6ddcbf1566be6afb0e123589adfb5e5d60e3a4c
Author: Pindikura Ravindra <ra...@dremio.com>
AuthorDate: Tue Jan 8 09:32:38 2019 -0600
ARROW-3701: [Gandiva] add op for decimal 128
The code changes are complete. However, performance of the non-fast code path is slow; I'll debug and fix that.
Author: Pindikura Ravindra <ra...@dremio.com>
Author: praveenbingo <pr...@dremio.com>
Closes #2942 from pravindra/decimal2 and squashes the following commits:
0f7e78a76 <Pindikura Ravindra> ARROW-3701: off gandiva tests in py 2.7
613524602 <Pindikura Ravindra> ARROW-3701: fix format error
c0fddfbc6 <Pindikura Ravindra> ARROW-3701: fix python unresolved symbol
db8581162 <Pindikura Ravindra> ARROW-3701: added a comment regarding structs.
194c4377a <Pindikura Ravindra> ARROW-3701: revert surefire version
5d07b79e2 <Pindikura Ravindra> ARROW-3701: Address review comments
36691c1c7 <Pindikura Ravindra> ARROW-3701: add benchmark for large decimals
75f7ac9d4 <Pindikura Ravindra> ARROW-3701: misc cleanups
59db4603d <Pindikura Ravindra> ARROW-3701: Fix java checkstyle issue
8a227ec9c <Pindikura Ravindra> ARROW-3701: Workaround for jni JIT issue
9cbd4ab59 <Pindikura Ravindra> ARROW-3701: switch to surefire 2.19 for dbg
ecaff4631 <Pindikura Ravindra> ARROW-3701: Enable decimal tests
54a210511 <praveenbingo> ARROW-3701: Support for decimal literal and null
b76a3ec1b <Pindikura Ravindra> ARROW-3701: First decimal function
---
.travis.yml | 3 +-
cpp/src/arrow/util/decimal-test.cc | 104 ++++++
cpp/src/arrow/util/decimal.cc | 97 ++++-
cpp/src/arrow/util/decimal.h | 19 +
cpp/src/gandiva/CMakeLists.txt | 4 +
cpp/src/gandiva/arrow.h | 11 +
cpp/src/gandiva/decimal_full.h | 75 ++++
cpp/src/gandiva/decimal_ir.cc | 405 +++++++++++++++++++++
cpp/src/gandiva/decimal_ir.h | 171 +++++++++
cpp/src/gandiva/decimal_type_util.cc | 80 ++++
cpp/src/gandiva/decimal_type_util.h | 90 +++++
cpp/src/gandiva/decimal_type_util_test.cc | 58 +++
cpp/src/gandiva/engine.cc | 9 +-
cpp/src/gandiva/engine.h | 2 +
cpp/src/gandiva/expression_registry.cc | 4 +-
cpp/src/gandiva/function_ir_builder.cc | 81 +++++
cpp/src/gandiva/function_ir_builder.h | 64 ++++
cpp/src/gandiva/function_registry.cc | 19 +-
cpp/src/gandiva/function_registry_arithmetic.cc | 2 +
cpp/src/gandiva/function_registry_common.h | 1 +
cpp/src/gandiva/function_signature.h | 18 +-
cpp/src/gandiva/jni/CMakeLists.txt | 2 +-
cpp/src/gandiva/jni/expression_registry_helper.cc | 7 +-
cpp/src/gandiva/jni/jni_common.cc | 6 +
cpp/src/gandiva/literal_holder.h | 5 +-
cpp/src/gandiva/llvm_generator.cc | 168 ++++++---
cpp/src/gandiva/llvm_generator.h | 9 +-
cpp/src/gandiva/llvm_types.cc | 1 +
cpp/src/gandiva/llvm_types.h | 25 +-
cpp/src/gandiva/lvalue.h | 35 +-
cpp/src/gandiva/precompiled/CMakeLists.txt | 12 +-
cpp/src/gandiva/precompiled/decimal_ops.cc | 219 +++++++++++
.../decimal_ops.h} | 20 +-
cpp/src/gandiva/precompiled/decimal_ops_test.cc | 75 ++++
cpp/src/gandiva/precompiled/decimal_wrapper.cc | 43 +++
cpp/src/gandiva/projector.cc | 6 +-
cpp/src/gandiva/proto/Types.proto | 8 +
cpp/src/gandiva/tests/CMakeLists.txt | 8 +-
cpp/src/gandiva/tests/decimal_single_test.cc | 224 ++++++++++++
cpp/src/gandiva/tests/decimal_test.cc | 237 ++++++++++++
cpp/src/gandiva/tests/generate_data.h | 20 +
cpp/src/gandiva/tests/micro_benchmarks.cc | 126 ++++++-
cpp/src/gandiva/tests/test_util.h | 14 +
cpp/src/gandiva/tests/timed_evaluate.h | 4 +-
cpp/src/gandiva/tree_expr_builder.cc | 10 +
cpp/src/gandiva/tree_expr_builder.h | 3 +
cpp/valgrind.supp | 6 +
java/gandiva/pom.xml | 7 +-
.../gandiva/evaluator/ConfigurationBuilder.java | 32 --
.../arrow/gandiva/evaluator/DecimalTypeUtil.java | 86 +++++
.../gandiva/evaluator/ExpressionRegistry.java | 5 +-
.../org/apache/arrow/gandiva/evaluator/Filter.java | 16 +-
.../apache/arrow/gandiva/evaluator/JniLoader.java | 148 ++++++++
.../apache/arrow/gandiva/evaluator/JniWrapper.java | 93 +----
.../apache/arrow/gandiva/evaluator/Projector.java | 20 +-
.../arrow/gandiva/expression/DecimalNode.java | 54 +++
.../arrow/gandiva/expression/TreeBuilder.java | 4 +
.../arrow/gandiva/evaluator/BaseEvaluatorTest.java | 15 +
.../gandiva/evaluator/DecimalTypeUtilTest.java | 89 +++++
.../gandiva/evaluator/ProjectorDecimalTest.java | 157 ++++++++
python/pyarrow/gandiva.pyx | 10 +
61 files changed, 3103 insertions(+), 243 deletions(-)
diff --git a/.travis.yml b/.travis.yml
index ffbb691..8532cc7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -121,7 +121,6 @@ matrix:
- ARROW_TRAVIS_COVERAGE=1
- ARROW_TRAVIS_PYTHON_DOCS=1
- ARROW_TRAVIS_PYTHON_JVM=1
- - ARROW_TRAVIS_PYTHON_GANDIVA=1
- ARROW_TRAVIS_OPTIONAL_INSTALL=1
- ARROW_BUILD_WARNING_LEVEL=CHECKIN
# TODO(wesm): Run the benchmarks outside of Travis
@@ -138,6 +137,8 @@ matrix:
- export PLASMA_VALGRIND=0
- $TRAVIS_BUILD_DIR/ci/travis_script_python.sh 2.7 || travis_terminate 1
- export PLASMA_VALGRIND=1
+ # Gandiva tests are not enabled with python 2.7 (exported so the script sees it)
+ - export ARROW_TRAVIS_PYTHON_GANDIVA=1
- $TRAVIS_BUILD_DIR/ci/travis_script_python.sh 3.6 || travis_terminate 1
- $TRAVIS_BUILD_DIR/ci/travis_upload_cpp_coverage.sh
- name: "[OS X] C++ w/ XCode 8.3"
diff --git a/cpp/src/arrow/util/decimal-test.cc b/cpp/src/arrow/util/decimal-test.cc
index 5925d98..73ac48c 100644
--- a/cpp/src/arrow/util/decimal-test.cc
+++ b/cpp/src/arrow/util/decimal-test.cc
@@ -466,4 +466,108 @@ TEST(Decimal128Test, TestToInteger) {
ASSERT_RAISES(Invalid, invalid_int64.ToInteger(&out2));
}
+TEST(Decimal128Test, GetWholeAndFraction) {  // split 123456 at scales 0, 1, 5 and 7
+ Decimal128 value("123456");
+ Decimal128 whole;
+ Decimal128 fraction;
+ int32_t out;
+
+ value.GetWholeAndFraction(0, &whole, &fraction);  // scale 0: everything is whole
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(123456, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(0, out);
+
+ value.GetWholeAndFraction(1, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(12345, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(6, out);
+
+ value.GetWholeAndFraction(5, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(1, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(23456, out);
+
+ value.GetWholeAndFraction(7, &whole, &fraction);  // scale exceeds digit count: all fraction
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(0, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(123456, out);
+}
+
+TEST(Decimal128Test, GetWholeAndFractionNegative) {  // both parts keep the input's sign
+ Decimal128 value("-123456");
+ Decimal128 whole;
+ Decimal128 fraction;
+ int32_t out;
+
+ value.GetWholeAndFraction(0, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(-123456, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(0, out);
+
+ value.GetWholeAndFraction(1, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(-12345, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(-6, out);
+
+ value.GetWholeAndFraction(5, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(-1, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(-23456, out);
+
+ value.GetWholeAndFraction(7, &whole, &fraction);
+ ASSERT_OK(whole.ToInteger(&out));
+ ASSERT_EQ(0, out);
+ ASSERT_OK(fraction.ToInteger(&out));
+ ASSERT_EQ(-123456, out);
+}
+
+TEST(Decimal128Test, IncreaseScale) {  // scaling up multiplies by 10^3, sign preserved
+ Decimal128 result;
+ int32_t out;
+
+ result = Decimal128("1234").IncreaseScaleBy(3);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(1234000, out);
+
+ result = Decimal128("-1234").IncreaseScaleBy(3);
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(-1234000, out);
+}
+
+TEST(Decimal128Test, ReduceScaleAndRound) {  // truncation vs. rounding away from zero
+ Decimal128 result;
+ int32_t out;
+
+ result = Decimal128("123456").ReduceScaleBy(1, false);  // no rounding: "6" dropped
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(12345, out);
+
+ result = Decimal128("123456").ReduceScaleBy(1, true);  // dropped 6 >= 5: round up
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(12346, out);
+
+ result = Decimal128("123451").ReduceScaleBy(1, true);  // dropped 1 < 5: unchanged
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(12345, out);
+
+ result = Decimal128("-123789").ReduceScaleBy(2, true);  // dropped 89 >= 50: -1237 -> -1238
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(-1238, out);
+
+ result = Decimal128("-123749").ReduceScaleBy(2, true);  // dropped 49 < 50: stays -1237
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(-1237, out);
+
+ result = Decimal128("-123750").ReduceScaleBy(2, true);  // exactly half: away from zero
+ ASSERT_OK(result.ToInteger(&out));
+ ASSERT_EQ(-1238, out);
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/util/decimal.cc b/cpp/src/arrow/util/decimal.cc
index c980e2a..8d6c069 100644
--- a/cpp/src/arrow/util/decimal.cc
+++ b/cpp/src/arrow/util/decimal.cc
@@ -39,7 +39,7 @@ using internal::SafeLeftShift;
using internal::SafeSignedAdd;
static const Decimal128 ScaleMultipliers[] = {
- Decimal128(0LL),
+ Decimal128(1LL),
Decimal128(10LL),
Decimal128(100LL),
Decimal128(1000LL),
@@ -79,6 +79,47 @@ static const Decimal128 ScaleMultipliers[] = {
Decimal128(542101086242752217LL, 68739955140067328ULL),
Decimal128(5421010862427522170LL, 687399551400673280ULL)};
+static const Decimal128 ScaleMultipliersHalf[] = {  // [i] == 10^i / 2, except [0] == 0
+ Decimal128(0ULL),
+ Decimal128(5ULL),
+ Decimal128(50ULL),
+ Decimal128(500ULL),
+ Decimal128(5000ULL),
+ Decimal128(50000ULL),
+ Decimal128(500000ULL),
+ Decimal128(5000000ULL),
+ Decimal128(50000000ULL),
+ Decimal128(500000000ULL),
+ Decimal128(5000000000ULL),
+ Decimal128(50000000000ULL),
+ Decimal128(500000000000ULL),
+ Decimal128(5000000000000ULL),
+ Decimal128(50000000000000ULL),
+ Decimal128(500000000000000ULL),
+ Decimal128(5000000000000000ULL),
+ Decimal128(50000000000000000ULL),
+ Decimal128(500000000000000000ULL),
+ Decimal128(5000000000000000000ULL),
+ Decimal128(2LL, 13106511852580896768ULL),  // from here on: (high, low) 64-bit halves
+ Decimal128(27LL, 1937910009842106368ULL),
+ Decimal128(271LL, 932356024711512064ULL),
+ Decimal128(2710LL, 9323560247115120640ULL),
+ Decimal128(27105LL, 1001882102603448320ULL),
+ Decimal128(271050LL, 10018821026034483200ULL),
+ Decimal128(2710505LL, 7954489891797073920ULL),
+ Decimal128(27105054LL, 5757922623132532736ULL),
+ Decimal128(271050543LL, 2238994010196672512ULL),
+ Decimal128(2710505431LL, 3943196028257173504ULL),
+ Decimal128(27105054312LL, 2538472135152631808ULL),
+ Decimal128(271050543121LL, 6937977277816766464ULL),
+ Decimal128(2710505431213LL, 14039540557039009792ULL),
+ Decimal128(27105054312137LL, 11268197054423236608ULL),
+ Decimal128(271050543121376LL, 2001506101975056384ULL),
+ Decimal128(2710505431213761LL, 1568316946041012224ULL),
+ Decimal128(27105054312137610LL, 15683169460410122240ULL),
+ Decimal128(271050543121376108LL, 9257742014424809472ULL),
+ Decimal128(2710505431213761085LL, 343699775700336640ULL)};
+
static constexpr uint64_t kIntMask = 0xFFFFFFFF;
static constexpr auto kCarryBit = static_cast<uint64_t>(1) << static_cast<uint64_t>(32);
@@ -888,6 +929,60 @@ Status Decimal128::Rescale(int32_t original_scale, int32_t new_scale,
return Status::OK();
}
+void Decimal128::GetWholeAndFraction(int32_t scale, Decimal128* whole,
+                                     Decimal128* fraction) const {  // matches decimal.h
+  DCHECK_GE(scale, 0);
+  DCHECK_LE(scale, 38);
+
+  Decimal128 multiplier(ScaleMultipliers[scale]);  // 10^scale
+  DCHECK_OK(Divide(multiplier, whole, fraction));  // whole = quotient, fraction = remainder
+}
+
+const Decimal128& Decimal128::GetScaleMultiplier(int32_t scale) {
+  DCHECK_GE(scale, 0);
+  DCHECK_LE(scale, 38);
+  const Decimal128& power_of_ten = ScaleMultipliers[scale];  // 10^scale
+  return power_of_ten;
+}
+
+Decimal128 Decimal128::IncreaseScaleBy(int32_t increase_by) const {
+  DCHECK_GE(increase_by, 0);
+  DCHECK_LE(increase_by, 38);
+  const Decimal128& multiplier = GetScaleMultiplier(increase_by);  // 10^increase_by
+  return *this * multiplier;
+}
+
+Decimal128 Decimal128::ReduceScaleBy(int32_t reduce_by, bool round) const {
+  DCHECK_GE(reduce_by, 0);
+  DCHECK_LE(reduce_by, 38);
+
+  Decimal128 divisor(ScaleMultipliers[reduce_by]);
+  Decimal128 result;
+  Decimal128 remainder;
+  DCHECK_OK(Divide(divisor, &result, &remainder));  // remainder = dropped digits
+  if (round && reduce_by > 0) {  // reduce_by == 0 drops nothing (half-table [0] is 0)
+    auto divisor_half = ScaleMultipliersHalf[reduce_by];
+    if (remainder.Abs() >= divisor_half) {
+      if (*this > 0) {  // input's sign, not result's: result may be 0 (5 -> 1, not -1)
+        result += 1;
+      } else {
+        result -= 1;
+      }
+    }
+  }
+  return result;
+}
+
+int32_t Decimal128::CountLeadingBinaryZeros() const {
+  DCHECK_GE(*this, Decimal128(0));  // only meaningful for non-negative values
+
+  if (high_bits_ != 0) {
+    return BitUtil::CountLeadingZeros(static_cast<uint64_t>(high_bits_));
+  }
+  // High word is all zeros: 64 of them, plus the low word's leading zeros.
+  return 64 + BitUtil::CountLeadingZeros(low_bits_);
+}
+
// Helper function used by Decimal128::FromBigEndian
static inline uint64_t UInt64FromBigEndian(const uint8_t* bytes, int32_t length) {
// We don't bounds check the length here because this is called by
diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h
index f59a4a4..5734fa0 100644
--- a/cpp/src/arrow/util/decimal.h
+++ b/cpp/src/arrow/util/decimal.h
@@ -139,9 +139,28 @@ class ARROW_EXPORT Decimal128 {
/// \return error status if the length is an invalid value
static Status FromBigEndian(const uint8_t* data, int32_t length, Decimal128* out);
+ /// \brief Separate the integer (whole) and fractional parts for the given scale.
+ void GetWholeAndFraction(int32_t scale, Decimal128* whole, Decimal128* fraction) const;
+
+ /// \brief Scale multiplier (10^scale) for the given scale value.
+ static const Decimal128& GetScaleMultiplier(int32_t scale);
+
/// \brief Convert Decimal128 from one scale to another
Status Rescale(int32_t original_scale, int32_t new_scale, Decimal128* out) const;
+ /// \brief Scale up by multiplying with 10^increase_by.
+ Decimal128 IncreaseScaleBy(int32_t increase_by) const;
+
+ /// \brief Scale down by dividing by 10^reduce_by.
+ /// - If 'round' is true, the right-most digits are dropped and the result is rounded
+ /// away from zero (+1 for positive, -1 for negative values) when the dropped digits
+ /// are >= 10^reduce_by / 2.
+ /// - If 'round' is false, the right-most digits are simply dropped (truncation).
+ Decimal128 ReduceScaleBy(int32_t reduce_by, bool round = true) const;
+
+ /// \brief Count the number of leading binary zeros (value must be non-negative).
+ int32_t CountLeadingBinaryZeros() const;
+
/// \brief Convert to a signed integer
template <typename T, typename = internal::EnableIfIsOneOf<T, int32_t, int64_t>>
Status ToInteger(T* out) const {
diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt
index 90fe7cf..e743b0e 100644
--- a/cpp/src/gandiva/CMakeLists.txt
+++ b/cpp/src/gandiva/CMakeLists.txt
@@ -46,6 +46,8 @@ set(SRC_FILES annotator.cc
bitmap_accumulator.cc
configuration.cc
context_helper.cc
+ decimal_ir.cc
+ decimal_type_util.cc
engine.cc
date_utils.cc
expr_decomposer.cc
@@ -54,6 +56,7 @@ set(SRC_FILES annotator.cc
expression_registry.cc
exported_funcs_registry.cc
filter.cc
+ function_ir_builder.cc
function_registry.cc
function_registry_arithmetic.cc
function_registry_datetime.cc
@@ -175,6 +178,7 @@ ADD_GANDIVA_TEST(lru_cache_test)
ADD_GANDIVA_TEST(to_date_holder_test)
ADD_GANDIVA_TEST(simple_arena_test)
ADD_GANDIVA_TEST(like_holder_test)
+ADD_GANDIVA_TEST(decimal_type_util_test)
if (ARROW_GANDIVA_JAVA)
add_subdirectory(jni)
diff --git a/cpp/src/gandiva/arrow.h b/cpp/src/gandiva/arrow.h
index ea28352..cc2bd9a 100644
--- a/cpp/src/gandiva/arrow.h
+++ b/cpp/src/gandiva/arrow.h
@@ -35,6 +35,9 @@ using ArrayPtr = std::shared_ptr<arrow::Array>;
using DataTypePtr = std::shared_ptr<arrow::DataType>;
using DataTypeVector = std::vector<DataTypePtr>;
+using Decimal128TypePtr = std::shared_ptr<arrow::Decimal128Type>;  // shared 128-bit decimal type
+using Decimal128TypeVector = std::vector<Decimal128TypePtr>;  // list of such types
+
using FieldPtr = std::shared_ptr<arrow::Field>;
using FieldVector = std::vector<FieldPtr>;
@@ -48,6 +51,14 @@ using ArrayDataVector = std::vector<ArrayDataPtr>;
using Status = arrow::Status;
using StatusCode = arrow::StatusCode;
+static inline bool is_decimal_128(const DataTypePtr& type) {  // no shared_ptr copy
+  if (type->id() != arrow::Type::DECIMAL) {
+    return false;
+  }
+  // A DECIMAL type is the 128-bit variant when its values are 16 bytes wide.
+  auto decimal_type = arrow::internal::checked_cast<arrow::DecimalType*>(type.get());
+  return decimal_type->byte_width() == 16;
+}
} // namespace gandiva
#endif // GANDIVA_EXPR_ARROW_H
diff --git a/cpp/src/gandiva/decimal_full.h b/cpp/src/gandiva/decimal_full.h
new file mode 100644
index 0000000..3b84da1
--- /dev/null
+++ b/cpp/src/gandiva/decimal_full.h
@@ -0,0 +1,75 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef DECIMAL_FULL_H
+#define DECIMAL_FULL_H
+
+#include <cstdint>
+#include <iostream>
+#include <string>
+#include "arrow/util/decimal.h"
+
+namespace gandiva {
+
+using Decimal128 = arrow::Decimal128;
+
+/// Represents a 128-bit decimal value along with its precision and scale.
+class Decimal128Full {
+ public:
+  Decimal128Full(int64_t high_bits, uint64_t low_bits, int32_t precision, int32_t scale)
+      : value_(high_bits, low_bits), precision_(precision), scale_(scale) {}
+
+  Decimal128Full(const std::string& value, int32_t precision, int32_t scale)  // no copy
+      : value_(value), precision_(precision), scale_(scale) {}
+
+  Decimal128Full(const Decimal128& value, int32_t precision, int32_t scale)
+      : value_(value), precision_(precision), scale_(scale) {}
+
+  Decimal128Full(int32_t precision, int32_t scale)  // value is zero
+      : value_(0), precision_(precision), scale_(scale) {}
+
+  uint32_t scale() const { return scale_; }  // NOTE(review): field is int32_t; negatives wrap
+
+  uint32_t precision() const { return precision_; }  // NOTE(review): same as scale()
+
+  const arrow::Decimal128& value() const { return value_; }
+
+  inline std::string ToString() const {  // "<value printed with scale 0>,<precision>,<scale>"
+    return value_.ToString(0) + "," + std::to_string(precision_) + "," +
+           std::to_string(scale_);
+  }
+
+  friend std::ostream& operator<<(std::ostream& os, const Decimal128Full& dec) {
+    os << dec.ToString();
+    return os;
+  }
+
+ private:
+  Decimal128 value_;
+
+  int32_t precision_;
+  int32_t scale_;
+};
+
+inline bool operator==(const Decimal128Full& left, const Decimal128Full& right) {
+  if (!(left.value() == right.value())) return false;  // value first, then type info
+  return left.precision() == right.precision() && left.scale() == right.scale();
+}
+
+} // namespace gandiva
+
+#endif // DECIMAL_FULL_H
diff --git a/cpp/src/gandiva/decimal_ir.cc b/cpp/src/gandiva/decimal_ir.cc
new file mode 100644
index 0000000..38b35a6
--- /dev/null
+++ b/cpp/src/gandiva/decimal_ir.cc
@@ -0,0 +1,405 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+#include <utility>
+
+#include "arrow/status.h"
+#include "gandiva/decimal_ir.h"
+#include "gandiva/decimal_type_util.h"
+
+// Algorithms adapted from Apache Impala
+
+namespace gandiva {
+
+#define ADD_TRACE_32(msg, value)                    \
+  do { /* one statement: safe inside if/else */   \
+    if (enable_ir_traces_) AddTrace32(msg, value);  \
+  } while (0)
+#define ADD_TRACE_128(msg, value)                   \
+  do { /* one statement: safe inside if/else */   \
+    if (enable_ir_traces_) AddTrace128(msg, value); \
+  } while (0)
+
+const char* DecimalIR::kScaleMultipliersName = "gandivaScaleMultipliers";
+
+/// Populate globals required by decimal IR: a constant i128 array of powers of ten
+/// named kScaleMultipliersName. TODO: can this be done just once ?
+void DecimalIR::AddGlobals(Engine* engine) {
+ auto types = engine->types();
+
+ // populate vector : [ 1, 10, 100, 1000, ..] i.e. 10^0 .. 10^kMaxPrecision
+ std::string value = "1";  // decimal digits of the current power of ten
+ std::vector<llvm::Constant*> scale_multipliers;
+ for (int i = 0; i < DecimalTypeUtil::kMaxPrecision + 1; ++i) {
+ auto multiplier =
+ llvm::ConstantInt::get(llvm::Type::getInt128Ty(*engine->context()), value, 10);
+ scale_multipliers.push_back(multiplier);
+ value.append("0");
+ }
+
+ auto array_type =
+ llvm::ArrayType::get(types->i128_type(), DecimalTypeUtil::kMaxPrecision + 1);
+ auto initializer = llvm::ConstantArray::get(
+ array_type, llvm::ArrayRef<llvm::Constant*>(scale_multipliers));
+
+ auto globalScaleMultipliers = new llvm::GlobalVariable(
+ *engine->module(), array_type, true /*constant*/,
+ llvm::GlobalValue::LinkOnceAnyLinkage, initializer, kScaleMultipliersName);
+ globalScaleMultipliers->setAlignment(16);  // 16-byte alignment for the i128 elements
+}
+
+// Look up the i128 signed add/mul-with-overflow intrinsics for the overflow-checked paths.
+void DecimalIR::InitializeIntrinsics() {
+ sadd_with_overflow_fn_ = llvm::Intrinsic::getDeclaration(
+ module(), llvm::Intrinsic::sadd_with_overflow, types()->i128_type());
+ DCHECK_NE(sadd_with_overflow_fn_, nullptr);
+
+ smul_with_overflow_fn_ = llvm::Intrinsic::getDeclaration(
+ module(), llvm::Intrinsic::smul_with_overflow, types()->i128_type());
+ DCHECK_NE(smul_with_overflow_fn_, nullptr);
+
+ i128_with_overflow_struct_type_ =  // the {i128, i1} struct returned by the intrinsics
+ sadd_with_overflow_fn_->getFunctionType()->getReturnType();
+}
+
+// CPP: return kScaleMultipliers[scale]  (10^scale, loaded from the global built in AddGlobals)
+llvm::Value* DecimalIR::GetScaleMultiplier(llvm::Value* scale) {
+ auto const_array = module()->getGlobalVariable(kScaleMultipliersName);
+ auto ptr = ir_builder()->CreateGEP(const_array, {types()->i32_constant(0), scale});
+ return ir_builder()->CreateLoad(ptr);
+}
+
+// CPP: x <= y ? y : x  (signed compare, i.e. the larger of the two scales)
+llvm::Value* DecimalIR::GetHigherScale(llvm::Value* x_scale, llvm::Value* y_scale) {
+ llvm::Value* le = ir_builder()->CreateICmpSLE(x_scale, y_scale);
+ return ir_builder()->CreateSelect(le, y_scale, x_scale);
+}
+
+// CPP: return (increase_scale_by <= 0) ?
+// in_value : in_value * GetScaleMultiplier(increase_scale_by)  (no overflow check)
+llvm::Value* DecimalIR::IncreaseScale(llvm::Value* in_value,
+ llvm::Value* increase_scale_by) {
+ llvm::Value* le_zero =
+ ir_builder()->CreateICmpSLE(increase_scale_by, types()->i32_constant(0));
+ // then block: scale unchanged
+ auto then_lambda = [&] { return in_value; };
+
+ // else block: plain multiply by 10^increase_scale_by; overflow is not detected
+ auto else_lambda = [&] {
+ llvm::Value* multiplier = GetScaleMultiplier(increase_scale_by);
+ return ir_builder()->CreateMul(in_value, multiplier);
+ };
+
+ return BuildIfElse(le_zero, types()->i128_type(), then_lambda, else_lambda);
+}
+
+// CPP: return (increase_scale_by <= 0) ?
+// {in_value, false} : smul_with_overflow(in_value, GetScaleMultiplier(increase_scale_by))
+//
+// The overflow flag in the result is set only if that multiplication overflowed.
+DecimalIR::ValueWithOverflow DecimalIR::IncreaseScaleWithOverflowCheck(
+ llvm::Value* in_value, llvm::Value* increase_scale_by) {
+ llvm::Value* le_zero =
+ ir_builder()->CreateICmpSLE(increase_scale_by, types()->i32_constant(0));
+
+ // then block: scale unchanged, no overflow
+ auto then_lambda = [&] {
+ ValueWithOverflow ret{in_value, types()->false_constant()};
+ return ret.AsStruct(this);
+ };
+
+ // else block: checked multiply, which yields the {i128, i1} struct directly
+ auto else_lambda = [&] {
+ llvm::Value* multiplier = GetScaleMultiplier(increase_scale_by);
+ return ir_builder()->CreateCall(smul_with_overflow_fn_, {in_value, multiplier});
+ };
+
+ auto ir_struct =
+ BuildIfElse(le_zero, i128_with_overflow_struct_type_, then_lambda, else_lambda);
+ return ValueWithOverflow::MakeFromStruct(this, ir_struct);
+}
+
+// CPP: return (reduce_scale_by <= 0) ?
+// in_value : in_value / GetScaleMultiplier(reduce_scale_by)
+//
+// ReduceScale cannot cause an overflow (it only divides).
+llvm::Value* DecimalIR::ReduceScale(llvm::Value* in_value, llvm::Value* reduce_scale_by) {
+ auto le_zero = ir_builder()->CreateICmpSLE(reduce_scale_by, types()->i32_constant(0));
+ // then block: scale unchanged
+ auto then_lambda = [&] { return in_value; };
+
+ // else block: signed divide by 10^reduce_scale_by
+ auto else_lambda = [&] {
+ // TODO : handle rounding. The signed division simply drops the remainder.
+ llvm::Value* multiplier = GetScaleMultiplier(reduce_scale_by);
+ return ir_builder()->CreateSDiv(in_value, multiplier);
+ };
+
+ return BuildIfElse(le_zero, types()->i128_type(), then_lambda, else_lambda);
+}
+
+/// @brief Fast-path for add (no overflow checks; BuildAdd uses it when out_precision < max).
+/// Adjust x and y to the higher of their scales and add them; the sum stays at that scale.
+llvm::Value* DecimalIR::AddFastPath(const ValueFull& x, const ValueFull& y) {
+ auto higher_scale = GetHigherScale(x.scale(), y.scale());
+ ADD_TRACE_32("AddFastPath : higher_scale", higher_scale);
+
+ // CPP : x_scaled = IncreaseScale(x_value, higher_scale - x_scale)
+ auto x_delta = ir_builder()->CreateSub(higher_scale, x.scale());
+ auto x_scaled = IncreaseScale(x.value(), x_delta);
+ ADD_TRACE_128("AddFastPath : x_scaled", x_scaled);
+
+ // CPP : y_scaled = IncreaseScale(y_value, higher_scale - y_scale)
+ auto y_delta = ir_builder()->CreateSub(higher_scale, y.scale());
+ auto y_scaled = IncreaseScale(y.value(), y_delta);
+ ADD_TRACE_128("AddFastPath : y_scaled", y_scaled);
+
+ auto sum = ir_builder()->CreateAdd(x_scaled, y_scaled);
+ ADD_TRACE_128("AddFastPath : sum", sum);
+ return sum;
+}
+
+/// @brief Add with overflow check.
+/// Adjust x and y to the same scale, add them, and reduce sum to output scale.
+/// If there is an overflow, the sum is set to 0.
+DecimalIR::ValueWithOverflow DecimalIR::AddWithOverflowCheck(const ValueFull& x,
+ const ValueFull& y,
+ const ValueFull& out) {
+ auto higher_scale = GetHigherScale(x.scale(), y.scale());
+ ADD_TRACE_32("AddWithOverflowCheck : higher_scale", higher_scale);
+
+ // CPP : x_scaled = IncreaseScale(x_value, higher_scale - x.scale())
+ auto x_delta = ir_builder()->CreateSub(higher_scale, x.scale());
+ auto x_scaled = IncreaseScaleWithOverflowCheck(x.value(), x_delta);
+ ADD_TRACE_128("AddWithOverflowCheck : x_scaled", x_scaled.value());
+
+ // CPP : y_scaled = IncreaseScale(y_value, higher_scale - y_scale)
+ auto y_delta = ir_builder()->CreateSub(higher_scale, y.scale());
+ auto y_scaled = IncreaseScaleWithOverflowCheck(y.value(), y_delta);
+ ADD_TRACE_128("AddWithOverflowCheck : y_scaled", y_scaled.value());
+
+ // CPP : sum = x_scaled + y_scaled
+ auto sum_ir_struct = ir_builder()->CreateCall(sadd_with_overflow_fn_,
+ {x_scaled.value(), y_scaled.value()});
+ auto sum = ValueWithOverflow::MakeFromStruct(this, sum_ir_struct);
+ ADD_TRACE_128("AddWithOverflowCheck : sum", sum.value());
+
+ // CPP : overflow ? 0 : sum / GetScaleMultiplier(higher_scale - out_scale)
+ auto overflow = GetCombinedOverflow({x_scaled, y_scaled, sum});
+ ADD_TRACE_32("AddWithOverflowCheck : overflow", overflow);
+ auto then_lambda = [&] {
+ // if there is an overflow, the value returned won't be used. so, save the division.
+ return types()->i128_constant(0);
+ };
+ auto else_lambda = [&] {
+ auto reduce_scale_by = ir_builder()->CreateSub(higher_scale, out.scale());
+ return ReduceScale(sum.value(), reduce_scale_by);
+ };
+ auto sum_descaled =
+ BuildIfElse(overflow, types()->i128_type(), then_lambda, else_lambda);
+ return ValueWithOverflow(sum_descaled, overflow);
+}
+
+// Too complex for IR: call the C++ fn add_large_decimal128_decimal128 with split args.
+llvm::Value* DecimalIR::AddLarge(const ValueFull& x, const ValueFull& y,
+ const ValueFull& out) {
+ std::vector<llvm::Value*> args;
+
+ auto x_split = ValueSplit::MakeFromInt128(this, x.value());
+ args.push_back(x_split.high());
+ args.push_back(x_split.low());
+ args.push_back(x.precision());
+ args.push_back(x.scale());
+
+ auto y_split = ValueSplit::MakeFromInt128(this, y.value());
+ args.push_back(y_split.high());
+ args.push_back(y_split.low());
+ args.push_back(y.precision());
+ args.push_back(y.scale());
+
+ args.push_back(out.precision());
+ args.push_back(out.scale());
+
+ auto split = ir_builder()->CreateCall(
+ module()->getFunction("add_large_decimal128_decimal128"), args);  // NOTE(review): fn not null-checked
+
+ auto sum = ValueSplit::MakeFromStruct(this, split).AsInt128(this);
+ ADD_TRACE_128("AddLarge : sum", sum);
+ return sum;
+}
+
+/// The output scale/precision cannot be arbitrary values. The algo here depends on them
+/// to match the add result-type rules (note said DecimalTypeSql; presumably DecimalTypeUtil).
+/// TODO: enforce this.
+Status DecimalIR::BuildAdd() {
+ // Create fn prototype :
+ // int128_t
+ // add_decimal128_decimal128(int128_t x_value, int32_t x_precision, int32_t x_scale,
+ // int128_t y_value, int32_t y_precision, int32_t y_scale,
+ // int32_t out_precision, int32_t out_scale)
+ auto i32 = types()->i32_type();
+ auto i128 = types()->i128_type();
+ auto function = BuildFunction("add_decimal128_decimal128", i128,
+ {
+ {"x_value", i128},
+ {"x_precision", i32},
+ {"x_scale", i32},
+ {"y_value", i128},
+ {"y_precision", i32},
+ {"y_scale", i32},
+ {"out_precision", i32},
+ {"out_scale", i32},
+ });
+
+ auto arg_iter = function->arg_begin();
+ ValueFull x(&arg_iter[0], &arg_iter[1], &arg_iter[2]);
+ ValueFull y(&arg_iter[3], &arg_iter[4], &arg_iter[5]);
+ ValueFull out(nullptr, &arg_iter[6], &arg_iter[7]);  // out has only precision/scale
+
+ auto entry = llvm::BasicBlock::Create(*context(), "entry", function);
+ ir_builder()->SetInsertPoint(entry);
+
+ // CPP :
+ // if (out_precision < kMaxPrecision) {
+ // return AddFastPath(x, y)
+ // } else {
+ // ret = AddWithOverflowCheck(x, y, out)  [only if kUseOverflowIntrinsics, else AddLarge]
+ // if (ret.overflow)
+ // return AddLarge(x, y, out)
+ // else
+ // return ret.value;
+ // }
+ llvm::Value* lt_max_precision = ir_builder()->CreateICmpSLT(
+ out.precision(), types()->i32_constant(DecimalTypeUtil::kMaxPrecision));
+ auto then_lambda = [&] {
+ // fast-path add
+ return AddFastPath(x, y);
+ };
+ auto else_lambda = [&] {
+ if (kUseOverflowIntrinsics) {
+ // do the add and check if there was overflow
+ auto ret = AddWithOverflowCheck(x, y, out);
+
+ // if there is an overflow, switch to the AddLarge codepath.
+ return BuildIfElse(ret.overflow(), types()->i128_type(),
+ [&] { return AddLarge(x, y, out); },
+ [&] { return ret.value(); });
+ } else {
+ return AddLarge(x, y, out);
+ }
+ };
+ auto value =
+ BuildIfElse(lt_max_precision, types()->i128_type(), then_lambda, else_lambda);
+
+ // return the result
+ ir_builder()->CreateRet(value);
+ return Status::OK();
+}
+
+Status DecimalIR::AddFunctions(Engine* engine) {
+ auto decimal_ir = std::make_shared<DecimalIR>(engine);
+
+ // Populate global variables used by decimal operations (IR built below loads them).
+ decimal_ir->AddGlobals(engine);
+
+ // Look up intrinsic functions (used by the overflow-checked paths).
+ decimal_ir->InitializeIntrinsics();
+
+ // build "add_decimal128_decimal128"
+ return decimal_ir->BuildAdd();
+}
+
+// Do a bitwise-or of all the overflow bits (false for an empty vector).
+llvm::Value* DecimalIR::GetCombinedOverflow(
+ std::vector<DecimalIR::ValueWithOverflow> vec) {
+ llvm::Value* res = types()->false_constant();
+ for (auto& val : vec) {
+ res = ir_builder()->CreateOr(res, val.overflow());
+ }
+ return res;
+}
+
+DecimalIR::ValueSplit DecimalIR::ValueSplit::MakeFromInt128(DecimalIR* decimal_ir,
+ llvm::Value* in) {  // split an i128 into two i64 halves
+ auto builder = decimal_ir->ir_builder();
+ auto types = decimal_ir->types();
+
+ auto high = builder->CreateLShr(in, types->i128_constant(64));  // upper 64 bits
+ high = builder->CreateTrunc(high, types->i64_type());
+ auto low = builder->CreateTrunc(in, types->i64_type());  // lower 64 bits
+ return ValueSplit(high, low);
+}
+
+/// Convert IR struct {%i64, %i64} to cpp class ValueSplit
+DecimalIR::ValueSplit DecimalIR::ValueSplit::MakeFromStruct(DecimalIR* decimal_ir,
+ llvm::Value* dstruct) {
+ auto builder = decimal_ir->ir_builder();
+ auto high = builder->CreateExtractValue(dstruct, 0);
+ auto low = builder->CreateExtractValue(dstruct, 1);
+ return DecimalIR::ValueSplit(high, low);
+}
+
+llvm::Value* DecimalIR::ValueSplit::AsInt128(DecimalIR* decimal_ir) const {  // rejoin halves
+ auto builder = decimal_ir->ir_builder();
+ auto types = decimal_ir->types();
+
+ auto value = builder->CreateSExt(high_, types->i128_type());  // high half carries the sign
+ value = builder->CreateShl(value, types->i128_constant(64));
+ value = builder->CreateAdd(value, builder->CreateZExt(low_, types->i128_type()));  // low unsigned
+ return value;
+}
+
+/// Convert IR struct {%i128, %i1} to cpp class ValueWithOverflow
+DecimalIR::ValueWithOverflow DecimalIR::ValueWithOverflow::MakeFromStruct(
+ DecimalIR* decimal_ir, llvm::Value* dstruct) {
+ auto builder = decimal_ir->ir_builder();
+ auto value = builder->CreateExtractValue(dstruct, 0);
+ auto overflow = builder->CreateExtractValue(dstruct, 1);
+ return DecimalIR::ValueWithOverflow(value, overflow);
+}
+
+/// Convert to IR struct {%i128, %i1}
+llvm::Value* DecimalIR::ValueWithOverflow::AsStruct(DecimalIR* decimal_ir) const {
+ auto builder = decimal_ir->ir_builder();
+
+ auto undef = llvm::UndefValue::get(decimal_ir->i128_with_overflow_struct_type_);
+ auto struct_val = builder->CreateInsertValue(undef, value(), 0);
+ return builder->CreateInsertValue(struct_val, overflow(), 1);
+}
+
+/// debug traces: emit a printf call into the generated code (traces must be enabled).
+void DecimalIR::AddTrace(const std::string& fmt, std::vector<llvm::Value*> args) {
+ DCHECK(enable_ir_traces_);
+
+ auto ir_str = ir_builder()->CreateGlobalStringPtr(fmt);
+ args.insert(args.begin(), ir_str);
+ ir_builder()->CreateCall(module()->getFunction("printf"), args, "trace");  // NOTE(review): assumes printf is declared in the module
+}
+
+void DecimalIR::AddTrace32(const std::string& msg, llvm::Value* value) {
+ AddTrace("DECIMAL_IR_TRACE:: " + msg + " %d\n", {value});
+}
+
+void DecimalIR::AddTrace128(const std::string& msg, llvm::Value* value) {
+ // convert i128 into two i64s for printing: hex high:low, then signed high / unsigned low
+ auto split = ValueSplit::MakeFromInt128(this, value);
+ AddTrace("DECIMAL_IR_TRACE:: " + msg + " %llx:%llx (%lld:%llu)\n",
+ {split.high(), split.low(), split.high(), split.low()});
+}
+
+} // namespace gandiva
diff --git a/cpp/src/gandiva/decimal_ir.h b/cpp/src/gandiva/decimal_ir.h
new file mode 100644
index 0000000..fae762c
--- /dev/null
+++ b/cpp/src/gandiva/decimal_ir.h
@@ -0,0 +1,171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef GANDIVA_DECIMAL_IR_H
+#define GANDIVA_DECIMAL_IR_H
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gandiva/function_ir_builder.h"
+
+namespace gandiva {
+
+/// @brief Decimal IR functions
+class DecimalIR : public FunctionIRBuilder {
+ public:
+  explicit DecimalIR(Engine* engine)
+      : FunctionIRBuilder(engine), enable_ir_traces_(false) {}
+
+  /// Build decimal IR functions and add them to the engine.
+  static Status AddFunctions(Engine* engine);
+
+  void EnableTraces() { enable_ir_traces_ = true; }
+
+ private:
+  /// The intrinsic fn for divide with small divisors is about 10x slower, so not
+  /// using these.
+  static const bool kUseOverflowIntrinsics = false;
+
+  // Holder for an i128 value, along with its scale and precision.
+  class ValueFull {
+   public:
+    ValueFull(llvm::Value* value, llvm::Value* precision, llvm::Value* scale)
+        : value_(value), precision_(precision), scale_(scale) {}
+
+    llvm::Value* value() const { return value_; }
+    llvm::Value* precision() const { return precision_; }
+    llvm::Value* scale() const { return scale_; }
+
+   private:
+    llvm::Value* value_;
+    llvm::Value* precision_;
+    llvm::Value* scale_;
+  };
+
+  // Holder for an i128 value, and a boolean indicating overflow.
+  class ValueWithOverflow {
+   public:
+    ValueWithOverflow(llvm::Value* value, llvm::Value* overflow)
+        : value_(value), overflow_(overflow) {}
+
+    // Make from IR struct {%i128, %i1}
+    static ValueWithOverflow MakeFromStruct(DecimalIR* decimal_ir, llvm::Value* dstruct);
+
+    // Build a corresponding IR struct {%i128, %i1}
+    llvm::Value* AsStruct(DecimalIR* decimal_ir) const;
+
+    llvm::Value* value() const { return value_; }
+    llvm::Value* overflow() const { return overflow_; }
+
+   private:
+    llvm::Value* value_;
+    llvm::Value* overflow_;
+  };
+
+  // Holder for an i128 value that is split into two i64s
+  class ValueSplit {
+   public:
+    ValueSplit(llvm::Value* high, llvm::Value* low) : high_(high), low_(low) {}
+
+    // Make from i128 value
+    static ValueSplit MakeFromInt128(DecimalIR* decimal_ir, llvm::Value* in);
+
+    // Make from IR struct
+    static ValueSplit MakeFromStruct(DecimalIR* decimal_ir, llvm::Value* dstruct);
+
+    // Combine the two parts into an i128
+    llvm::Value* AsInt128(DecimalIR* decimal_ir) const;
+
+    llvm::Value* high() const { return high_; }
+    llvm::Value* low() const { return low_; }
+
+   private:
+    llvm::Value* high_;
+    llvm::Value* low_;
+  };
+
+  // Add global variables to the module.
+  static void AddGlobals(Engine* engine);
+
+  // Initialize intrinsic functions that are used by decimal operations.
+  void InitializeIntrinsics();
+
+  // Create IR builder for decimal add function.
+  static Status MakeAdd(Engine* engine, std::shared_ptr<FunctionIRBuilder>* out);
+
+  // Get the multiplier for specified scale (i.e 10^scale)
+  llvm::Value* GetScaleMultiplier(llvm::Value* scale);
+
+  // Get the higher of the two scales
+  llvm::Value* GetHigherScale(llvm::Value* x_scale, llvm::Value* y_scale);
+
+  // Increase scale of 'in_value' by 'increase_scale_by'.
+  // - If 'increase_scale_by' is <= 0, does nothing.
+  llvm::Value* IncreaseScale(llvm::Value* in_value, llvm::Value* increase_scale_by);
+
+  // Similar to IncreaseScale, but also checks for overflow.
+  ValueWithOverflow IncreaseScaleWithOverflowCheck(llvm::Value* in_value,
+                                                   llvm::Value* increase_scale_by);
+
+  // Reduce scale of 'in_value' by 'reduce_scale_by'.
+  // - If 'reduce_scale_by' is <= 0, does nothing.
+  llvm::Value* ReduceScale(llvm::Value* in_value, llvm::Value* reduce_scale_by);
+
+  // Fast path of add: guaranteed no overflow
+  llvm::Value* AddFastPath(const ValueFull& x, const ValueFull& y);
+
+  // Similar to AddFastPath, but check if there's an overflow.
+  ValueWithOverflow AddWithOverflowCheck(const ValueFull& x, const ValueFull& y,
+                                         const ValueFull& out);
+
+  // Do addition of large integers (both positive and negative).
+  llvm::Value* AddLarge(const ValueFull& x, const ValueFull& y, const ValueFull& out);
+
+  // Get the combined overflow (logical or).
+  llvm::Value* GetCombinedOverflow(std::vector<ValueWithOverflow> values);
+
+  // Build the function for adding decimals.
+  Status BuildAdd();
+
+  // Add a trace in IR code.
+  void AddTrace(const std::string& fmt, std::vector<llvm::Value*> args);
+
+  // Add a trace msg along with a 32-bit integer.
+  void AddTrace32(const std::string& msg, llvm::Value* value);
+
+  // Add a trace msg along with a 128-bit integer.
+  void AddTrace128(const std::string& msg, llvm::Value* value);
+
+  // name of the global variable having the array of scale multipliers.
+  static const char* kScaleMultipliersName;
+
+  // Intrinsic functions. Default to nullptr so they are never read
+  // uninitialized; presumably assigned by InitializeIntrinsics() (TODO confirm).
+  llvm::Function* sadd_with_overflow_fn_ = nullptr;
+  llvm::Function* smul_with_overflow_fn_ = nullptr;
+
+  // struct { i128: value, i1: overflow}
+  llvm::Type* i128_with_overflow_struct_type_ = nullptr;
+
+  // if set to true, ir traces are enabled. Useful for debugging.
+  bool enable_ir_traces_;
+};
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_DECIMAL_IR_H
diff --git a/cpp/src/gandiva/decimal_type_util.cc b/cpp/src/gandiva/decimal_type_util.cc
new file mode 100644
index 0000000..0ebfe66
--- /dev/null
+++ b/cpp/src/gandiva/decimal_type_util.cc
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/logging.h"
+
+namespace gandiva {
+
+// Out-of-line definitions for the static constexpr members declared in the
+// header. Before C++17 these are required whenever a member is ODR-used, e.g.
+// when bound to a const reference as std::max/std::min do with
+// kMinAdjustedScale.
+constexpr int32_t DecimalTypeUtil::kMaxDecimal32Precision;
+constexpr int32_t DecimalTypeUtil::kMaxDecimal64Precision;
+constexpr int32_t DecimalTypeUtil::kMaxPrecision;
+
+constexpr int32_t DecimalTypeUtil::kMaxScale;
+constexpr int32_t DecimalTypeUtil::kMinAdjustedScale;
+
+#define DCHECK_TYPE(type) \
+ { \
+ DCHECK_GE(type->scale(), 0); \
+ DCHECK_LE(type->precision(), kMaxPrecision); \
+ }
+
+// Implementation of decimal rules.
+//
+// Computes the result type of a binary decimal operation.
+// - 'in_types' must hold exactly two decimal types (checked in debug builds).
+// - On success, *out_type is set to the result type, with the scale reduced
+//   by MakeAdjustedType() if the precision would exceed kMaxPrecision.
+// - Returns Status::Invalid for an unrecognized 'op'; *out_type stays nullptr.
+Status DecimalTypeUtil::GetResultType(Op op, const Decimal128TypeVector& in_types,
+                                      Decimal128TypePtr* out_type) {
+  DCHECK_EQ(in_types.size(), 2);
+
+  *out_type = nullptr;
+  auto t1 = in_types[0];
+  auto t2 = in_types[1];
+  DCHECK_TYPE(t1);
+  DCHECK_TYPE(t2);
+
+  int32_t s1 = t1->scale();
+  int32_t s2 = t2->scale();
+  int32_t p1 = t1->precision();
+  int32_t p2 = t2->precision();
+  int32_t result_scale;
+  int32_t result_precision;
+
+  switch (op) {
+    case kOpAdd:
+    case kOpSubtract:
+      // Keep the larger scale; the extra digit absorbs a possible carry.
+      result_scale = std::max(s1, s2);
+      result_precision = std::max(p1 - s1, p2 - s2) + result_scale + 1;
+      break;
+
+    case kOpMultiply:
+      result_scale = s1 + s2;
+      result_precision = p1 + p2 + 1;
+      break;
+
+    case kOpDivide:
+      result_scale = std::max(kMinAdjustedScale, s1 + p2 + 1);
+      result_precision = p1 - s1 + s2 + result_scale;
+      break;
+
+    case kOpMod:
+      result_scale = std::max(s1, s2);
+      result_precision = std::min(p1 - s1, p2 - s2) + result_scale;
+      break;
+
+    default:
+      // Without this, an out-of-range 'op' would leave result_scale and
+      // result_precision uninitialized below.
+      return Status::Invalid("Unsupported decimal operation");
+  }
+  *out_type = MakeAdjustedType(result_precision, result_scale);
+  return Status::OK();
+}
+
+} // namespace gandiva
diff --git a/cpp/src/gandiva/decimal_type_util.h b/cpp/src/gandiva/decimal_type_util.h
new file mode 100644
index 0000000..2c095c1
--- /dev/null
+++ b/cpp/src/gandiva/decimal_type_util.h
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Adapted from Apache Impala
+
+#ifndef GANDIVA_DECIMAL_TYPE_SQL_H
+#define GANDIVA_DECIMAL_TYPE_SQL_H
+
+#include <algorithm>
+#include <memory>
+
+#include "gandiva/arrow.h"
+
+namespace gandiva {
+
+/// @brief Handles conversion of scale/precision for operations on decimal types.
+/// TODO : do validations for all of these.
+class DecimalTypeUtil {
+ public:
+ /// Binary decimal operations whose result type GetResultType() can compute.
+ enum Op {
+ kOpAdd,
+ kOpSubtract,
+ kOpMultiply,
+ kOpDivide,
+ kOpMod,
+ };
+
+ /// The maximum precision representable by a 4-byte decimal
+ static constexpr int32_t kMaxDecimal32Precision = 9;
+
+ /// The maximum precision representable by an 8-byte decimal
+ static constexpr int32_t kMaxDecimal64Precision = 18;
+
+ /// The maximum precision representable by a 16-byte decimal
+ static constexpr int32_t kMaxPrecision = 38;
+
+ // The maximum scale representable.
+ static constexpr int32_t kMaxScale = kMaxPrecision;
+
+ // When operating on decimal inputs, the integer part of the output can exceed the
+ // max precision. In such cases, the scale is reduced, but not below
+ // kMinAdjustedScale (or below the original scale, if that is already smaller).
+ // * There is no strong reason for 6, but both SQLServer and Impala use 6 too.
+ static constexpr int32_t kMinAdjustedScale = 6;
+
+ // For specified operation and input scale/precision, determine the output
+ // scale/precision. 'in_types' is expected to hold exactly two types.
+ static Status GetResultType(Op op, const Decimal128TypeVector& in_types,
+ Decimal128TypePtr* out_type);
+
+ // Make a decimal128 type with exactly the given precision and scale.
+ static Decimal128TypePtr MakeType(int32_t precision, int32_t scale);
+
+ private:
+ // Like MakeType, but caps precision at kMaxPrecision by reducing the scale.
+ static Decimal128TypePtr MakeAdjustedType(int32_t precision, int32_t scale);
+};
+
+// Build an arrow decimal type and downcast it to the concrete Decimal128Type.
+inline Decimal128TypePtr DecimalTypeUtil::MakeType(int32_t precision, int32_t scale) {
+  auto generic_type = arrow::decimal(precision, scale);
+  return std::dynamic_pointer_cast<arrow::Decimal128Type>(generic_type);
+}
+
+// Reduce the scale if possible so that precision stays <= kMaxPrecision.
+// The scale drops by the excess precision, but never below
+// min(scale, kMinAdjustedScale).
+inline Decimal128TypePtr DecimalTypeUtil::MakeAdjustedType(int32_t precision,
+                                                           int32_t scale) {
+  if (precision <= kMaxPrecision) {
+    return MakeType(precision, scale);
+  }
+  const int32_t excess = precision - kMaxPrecision;
+  const int32_t floor_scale = std::min(scale, kMinAdjustedScale);
+  return MakeType(kMaxPrecision, std::max(scale - excess, floor_scale));
+}
+
+} // namespace gandiva
+
+#endif // GANDIVA_DECIMAL_TYPE_SQL_H
diff --git a/cpp/src/gandiva/decimal_type_util_test.cc b/cpp/src/gandiva/decimal_type_util_test.cc
new file mode 100644
index 0000000..a593990
--- /dev/null
+++ b/cpp/src/gandiva/decimal_type_util_test.cc
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Adapted from Apache Impala
+
+#include <gtest/gtest.h>
+
+#include "gandiva/decimal_type_util.h"
+#include "tests/test_util.h"
+
+namespace gandiva {
+
+// Shorthand for a decimal128 type with precision 'p' and scale 's'.
+#define DECIMAL_TYPE(p, s) DecimalTypeUtil::MakeType(p, s)
+
+// Apply the result-type rules for 'op' to (d1, d2). A failing status is
+// reported via EXPECT_OK, and the (possibly null) result type is returned.
+Decimal128TypePtr DoOp(DecimalTypeUtil::Op op, Decimal128TypePtr d1,
+                       Decimal128TypePtr d2) {
+  Decimal128TypeVector operands = {d1, d2};
+  Decimal128TypePtr result;
+  EXPECT_OK(DecimalTypeUtil::GetResultType(op, operands, &result));
+  return result;
+}
+
+TEST(DecimalResultTypes, Basic) {
+  // add: scale = max(s1, s2), precision = max(p1 - s1, p2 - s2) + scale + 1.
+  EXPECT_ARROW_TYPE_EQUALS(
+      DECIMAL_TYPE(31, 10),
+      DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(30, 10), DECIMAL_TYPE(30, 10)));
+
+  EXPECT_ARROW_TYPE_EQUALS(
+      DECIMAL_TYPE(32, 6),
+      DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(30, 6), DECIMAL_TYPE(30, 5)));
+
+  // The cases below exceed the max precision (38): the scale is reduced by the
+  // excess, but not below min(scale, 6).
+  EXPECT_ARROW_TYPE_EQUALS(
+      DECIMAL_TYPE(38, 9),
+      DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(30, 10), DECIMAL_TYPE(38, 10)));
+
+  EXPECT_ARROW_TYPE_EQUALS(
+      DECIMAL_TYPE(38, 9),
+      DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(38, 10), DECIMAL_TYPE(38, 38)));
+
+  EXPECT_ARROW_TYPE_EQUALS(
+      DECIMAL_TYPE(38, 6),
+      DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(38, 10), DECIMAL_TYPE(38, 2)));
+}
+
+TEST(DecimalResultTypes, OtherOps) {
+  // subtract follows the same rule as add: scale 3, precision max(8, 7) + 3 + 1.
+  EXPECT_ARROW_TYPE_EQUALS(
+      DECIMAL_TYPE(12, 3),
+      DoOp(DecimalTypeUtil::kOpSubtract, DECIMAL_TYPE(10, 2), DECIMAL_TYPE(10, 3)));
+
+  // multiply: scale = s1 + s2, precision = p1 + p2 + 1.
+  EXPECT_ARROW_TYPE_EQUALS(
+      DECIMAL_TYPE(21, 5),
+      DoOp(DecimalTypeUtil::kOpMultiply, DECIMAL_TYPE(10, 2), DECIMAL_TYPE(10, 3)));
+
+  // multiply overflowing the max precision: (77, 20) -> (38, 6).
+  EXPECT_ARROW_TYPE_EQUALS(
+      DECIMAL_TYPE(38, 6),
+      DoOp(DecimalTypeUtil::kOpMultiply, DECIMAL_TYPE(38, 10), DECIMAL_TYPE(38, 10)));
+
+  // divide: scale = max(6, s1 + p2 + 1), precision = p1 - s1 + s2 + scale.
+  EXPECT_ARROW_TYPE_EQUALS(
+      DECIMAL_TYPE(24, 13),
+      DoOp(DecimalTypeUtil::kOpDivide, DECIMAL_TYPE(10, 2), DECIMAL_TYPE(10, 3)));
+
+  // mod: scale = max(s1, s2), precision = min(p1 - s1, p2 - s2) + scale.
+  EXPECT_ARROW_TYPE_EQUALS(
+      DECIMAL_TYPE(10, 3),
+      DoOp(DecimalTypeUtil::kOpMod, DECIMAL_TYPE(10, 2), DECIMAL_TYPE(10, 3)));
+}
+
+} // namespace gandiva
diff --git a/cpp/src/gandiva/engine.cc b/cpp/src/gandiva/engine.cc
index da7a6d8..9aaafea 100644
--- a/cpp/src/gandiva/engine.cc
+++ b/cpp/src/gandiva/engine.cc
@@ -39,6 +39,7 @@
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Vectorize.h>
+#include "gandiva/decimal_ir.h"
#include "gandiva/exported_funcs_registry.h"
namespace gandiva {
@@ -94,6 +95,10 @@ Status Engine::Make(std::shared_ptr<Configuration> config,
auto status = engine_obj->LoadPreCompiledIRFiles(config->byte_code_file_path());
ARROW_RETURN_NOT_OK(status);
+ // Add decimal functions
+ status = DecimalIR::AddFunctions(engine_obj.get());
+ ARROW_RETURN_NOT_OK(status);
+
*engine = std::move(engine_obj);
return Status::OK();
}
@@ -183,7 +188,7 @@ Status Engine::FinalizeModule(bool optimise_ir, bool dump_ir) {
// run the optimiser
llvm::PassManagerBuilder pass_builder;
- pass_builder.OptLevel = 2;
+ pass_builder.OptLevel = 3;
pass_builder.populateModulePassManager(*pass_manager);
pass_manager->run(*module_);
@@ -222,7 +227,7 @@ void Engine::DumpIR(std::string prefix) {
std::string str;
llvm::raw_string_ostream stream(str);
- module_->print(stream, NULL);
+ module_->print(stream, nullptr);
std::cout << "====" << prefix << "===" << str << "\n";
}
diff --git a/cpp/src/gandiva/engine.h b/cpp/src/gandiva/engine.h
index f377ebc..16b5a56 100644
--- a/cpp/src/gandiva/engine.h
+++ b/cpp/src/gandiva/engine.h
@@ -37,6 +37,8 @@
namespace gandiva {
+class FunctionIRBuilder;
+
/// \brief LLVM Execution engine wrapper.
class Engine {
public:
diff --git a/cpp/src/gandiva/expression_registry.cc b/cpp/src/gandiva/expression_registry.cc
index fb5a45e..1a087c9 100644
--- a/cpp/src/gandiva/expression_registry.cc
+++ b/cpp/src/gandiva/expression_registry.cc
@@ -136,10 +136,12 @@ void ExpressionRegistry::AddArrowTypesToVector(arrow::Type::type& type,
case arrow::Type::type::NA:
vector.push_back(arrow::null());
break;
+ case arrow::Type::type::DECIMAL:
+ vector.push_back(arrow::decimal(0, 0));
+ break;
case arrow::Type::type::FIXED_SIZE_BINARY:
case arrow::Type::type::MAP:
case arrow::Type::type::INTERVAL:
- case arrow::Type::type::DECIMAL:
case arrow::Type::type::LIST:
case arrow::Type::type::STRUCT:
case arrow::Type::type::UNION:
diff --git a/cpp/src/gandiva/function_ir_builder.cc b/cpp/src/gandiva/function_ir_builder.cc
new file mode 100644
index 0000000..1942739
--- /dev/null
+++ b/cpp/src/gandiva/function_ir_builder.cc
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/function_ir_builder.h"
+
+namespace gandiva {
+
+// Emit an if/else diamond at the builder's current insertion point, i.e. the
+// IR equivalent of `condition ? then_func() : else_func()`.
+// - then_func / else_func generate the IR for each branch and return its value.
+// - The two values are merged by a PHI node of type 'return_type', which is
+//   returned.
+// - On return, the builder's insertion point is the "merge" block.
+llvm::Value* FunctionIRBuilder::BuildIfElse(llvm::Value* condition,
+ llvm::Type* return_type,
+ std::function<llvm::Value*()> then_func,
+ std::function<llvm::Value*()> else_func) {
+ llvm::IRBuilder<>* builder = ir_builder();
+ llvm::Function* function = builder->GetInsertBlock()->getParent();
+ DCHECK_NE(function, nullptr);
+
+ // Create blocks for the then, else and merge cases.
+ llvm::BasicBlock* then_bb = llvm::BasicBlock::Create(*context(), "then", function);
+ llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context(), "else", function);
+ llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context(), "merge", function);
+
+ builder->CreateCondBr(condition, then_bb, else_bb);
+
+ // Emit the then block.
+ builder->SetInsertPoint(then_bb);
+ auto then_value = then_func();
+ builder->CreateBr(merge_bb);
+
+ // refresh then_bb for phi (could have changed due to code generation of then_value).
+ then_bb = builder->GetInsertBlock();
+
+ // Emit the else block.
+ builder->SetInsertPoint(else_bb);
+ auto else_value = else_func();
+ builder->CreateBr(merge_bb);
+
+ // refresh else_bb for phi (could have changed due to code generation of else_value).
+ else_bb = builder->GetInsertBlock();
+
+ // Emit the merge block.
+ builder->SetInsertPoint(merge_bb);
+ llvm::PHINode* result_value = builder->CreatePHI(return_type, 2, "res_value");
+ result_value->addIncoming(then_value, then_bb);
+ result_value->addIncoming(else_value, else_bb);
+ return result_value;
+}
+
+// Create an externally-visible, non-variadic LLVM function 'function_name' in
+// the engine's module, with one parameter per entry of 'in_args' (typed and
+// named accordingly). No body is emitted; the caller fills it in.
+llvm::Function* FunctionIRBuilder::BuildFunction(const std::string& function_name,
+                                                 llvm::Type* return_type,
+                                                 std::vector<NamedArg> in_args) {
+  // IR parameter types, in declaration order.
+  std::vector<llvm::Type*> param_types;
+  param_types.reserve(in_args.size());
+  for (const NamedArg& named : in_args) {
+    param_types.push_back(named.type);
+  }
+
+  auto signature = llvm::FunctionType::get(return_type, param_types, /*isVarArg=*/false);
+  auto fn = llvm::Function::Create(signature, llvm::GlobalValue::ExternalLinkage,
+                                   function_name, module());
+
+  // Give each IR argument its caller-supplied name.
+  uint32_t idx = 0;
+  for (auto& ir_arg : fn->args()) {
+    DCHECK_LT(idx, in_args.size());
+    ir_arg.setName(in_args[idx].name);
+    ++idx;
+  }
+  return fn;
+}
+
+} // namespace gandiva
diff --git a/cpp/src/gandiva/function_ir_builder.h b/cpp/src/gandiva/function_ir_builder.h
new file mode 100644
index 0000000..7d6003a
--- /dev/null
+++ b/cpp/src/gandiva/function_ir_builder.h
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef GANDIVA_FUNCTION_IR_BUILDER_H
+#define GANDIVA_FUNCTION_IR_BUILDER_H
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gandiva/engine.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/llvm_types.h"
+
+namespace gandiva {
+
+/// @brief Base class for building IR functions.
+///
+/// Exposes the engine's types, module, context and IR builder to subclasses,
+/// plus helpers for common IR constructs. The engine pointer is not deleted
+/// by this class.
+class FunctionIRBuilder {
+ public:
+  explicit FunctionIRBuilder(Engine* engine) : engine_(engine) {}
+  virtual ~FunctionIRBuilder() = default;
+
+ protected:
+  LLVMTypes* types() { return engine_->types(); }
+  llvm::Module* module() { return engine_->module(); }
+  llvm::LLVMContext* context() { return engine_->context(); }
+  llvm::IRBuilder<>* ir_builder() { return engine_->ir_builder(); }
+
+  /// Build an if-else block.
+  ///
+  /// Branches on 'condition'; the values produced by 'then_func' and
+  /// 'else_func' are merged into a PHI of 'return_type', which is returned.
+  llvm::Value* BuildIfElse(llvm::Value* condition, llvm::Type* return_type,
+                           std::function<llvm::Value*()> then_func,
+                           std::function<llvm::Value*()> else_func);
+
+  /// Name and IR type of a single function argument.
+  struct NamedArg {
+    std::string name;
+    llvm::Type* type;
+  };
+
+  /// Build llvm fn with one argument per entry of 'in_args'.
+  llvm::Function* BuildFunction(const std::string& function_name, llvm::Type* return_type,
+                                std::vector<NamedArg> in_args);
+
+ private:
+  Engine* engine_;
+};
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_FUNCTION_IR_BUILDER_H
diff --git a/cpp/src/gandiva/function_registry.cc b/cpp/src/gandiva/function_registry.cc
index 83d80b4..452cb63 100644
--- a/cpp/src/gandiva/function_registry.cc
+++ b/cpp/src/gandiva/function_registry.cc
@@ -29,23 +29,6 @@
namespace gandiva {
-using arrow::binary;
-using arrow::boolean;
-using arrow::date64;
-using arrow::float32;
-using arrow::float64;
-using arrow::int16;
-using arrow::int32;
-using arrow::int64;
-using arrow::int8;
-using arrow::uint16;
-using arrow::uint32;
-using arrow::uint64;
-using arrow::uint8;
-using arrow::utf8;
-using std::iterator;
-using std::vector;
-
FunctionRegistry::iterator FunctionRegistry::begin() const {
return &(*pc_registry_.begin());
}
@@ -89,7 +72,7 @@ SignatureMap FunctionRegistry::InitPCMap() {
const NativeFunction* FunctionRegistry::LookupSignature(
const FunctionSignature& signature) const {
auto got = pc_registry_map_.find(&signature);
- return got == pc_registry_map_.end() ? NULL : got->second;
+ return got == pc_registry_map_.end() ? nullptr : got->second;
}
} // namespace gandiva
diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc
index 800bc49..c5a798c 100644
--- a/cpp/src/gandiva/function_registry_arithmetic.cc
+++ b/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -57,6 +57,8 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int32, int32),
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int64, int64),
+ BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(add, decimal128),
+
BINARY_RELATIONAL_BOOL_FN(equal),
BINARY_RELATIONAL_BOOL_FN(not_equal),
diff --git a/cpp/src/gandiva/function_registry_common.h b/cpp/src/gandiva/function_registry_common.h
index 78babce..3ae065a 100644
--- a/cpp/src/gandiva/function_registry_common.h
+++ b/cpp/src/gandiva/function_registry_common.h
@@ -53,6 +53,7 @@ inline DataTypePtr time32() { return arrow::time32(arrow::TimeUnit::MILLI); }
inline DataTypePtr time64() { return arrow::time64(arrow::TimeUnit::MICRO); }
inline DataTypePtr timestamp() { return arrow::timestamp(arrow::TimeUnit::MILLI); }
+// Decimal type used in function signatures. The (0, 0) precision/scale is a
+// placeholder: decimal signature matching compares byte width, not precision/scale.
+inline DataTypePtr decimal128() { return arrow::decimal(0, 0); }
struct KeyHash {
std::size_t operator()(const FunctionSignature* k) const { return k->Hash(); }
diff --git a/cpp/src/gandiva/function_signature.h b/cpp/src/gandiva/function_signature.h
index e5dff24..ee82abc 100644
--- a/cpp/src/gandiva/function_signature.h
+++ b/cpp/src/gandiva/function_signature.h
@@ -56,10 +56,22 @@ class FunctionSignature {
std::string ToString() const;
private:
- // TODO : for some of the types, this shouldn't match type specific data. eg. for
- // decimals, this shouldn't match precision/scale.
bool DataTypeEquals(const DataTypePtr left, const DataTypePtr right) const {
- return left->Equals(right);
+ if (left->id() == right->id()) {
+ switch (left->id()) {
+ case arrow::Type::DECIMAL: {
+ // For decimal types, the precision/scale isn't part of the signature.
+ auto dleft = arrow::internal::checked_cast<arrow::DecimalType*>(left.get());
+ auto dright = arrow::internal::checked_cast<arrow::DecimalType*>(right.get());
+ return (dleft != NULL) && (dright != NULL) &&
+ (dleft->byte_width() == dright->byte_width());
+ }
+ default:
+ return left->Equals(right);
+ }
+ } else {
+ return false;
+ }
}
std::string base_name_;
diff --git a/cpp/src/gandiva/jni/CMakeLists.txt b/cpp/src/gandiva/jni/CMakeLists.txt
index a07d390..afc7fad 100644
--- a/cpp/src/gandiva/jni/CMakeLists.txt
+++ b/cpp/src/gandiva/jni/CMakeLists.txt
@@ -78,5 +78,5 @@ add_dependencies(gandiva ${GANDIVA_JNI_LIBRARIES})
# statically linked stdc++ has conflicts with stdc++ loaded by other libraries.
if (NOT APPLE)
set_target_properties(gandiva_jni_shared PROPERTIES
- LINK_FLAGS "-Wl,--version-script=${CMAKE_SOURCE_DIR}/src/gandiva/jni/symbols.map")
+ LINK_FLAGS "-Wl,--no-as-needed -Wl,--version-script=${CMAKE_SOURCE_DIR}/src/gandiva/jni/symbols.map")
endif()
diff --git a/cpp/src/gandiva/jni/expression_registry_helper.cc b/cpp/src/gandiva/jni/expression_registry_helper.cc
index 5227329..b5c6880 100644
--- a/cpp/src/gandiva/jni/expression_registry_helper.cc
+++ b/cpp/src/gandiva/jni/expression_registry_helper.cc
@@ -121,10 +121,15 @@ void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType* gandiva_data_type)
case arrow::Type::type::NA:
gandiva_data_type->set_type(types::GandivaType::NONE);
break;
+ case arrow::Type::type::DECIMAL: {
+ gandiva_data_type->set_type(types::GandivaType::DECIMAL);
+ gandiva_data_type->set_precision(0);
+ gandiva_data_type->set_scale(0);
+ break;
+ }
case arrow::Type::type::FIXED_SIZE_BINARY:
case arrow::Type::type::MAP:
case arrow::Type::type::INTERVAL:
- case arrow::Type::type::DECIMAL:
case arrow::Type::type::LIST:
case arrow::Type::type::STRUCT:
case arrow::Type::type::UNION:
diff --git a/cpp/src/gandiva/jni/jni_common.cc b/cpp/src/gandiva/jni/jni_common.cc
index 639ad36..7ad0d6d 100644
--- a/cpp/src/gandiva/jni/jni_common.cc
+++ b/cpp/src/gandiva/jni/jni_common.cc
@@ -381,6 +381,12 @@ NodePtr ProtoTypeToNode(const types::TreeNode& node) {
return TreeExprBuilder::MakeBinaryLiteral(node.binarynode().value());
}
+ if (node.has_decimalnode()) {
+ std::string value = node.decimalnode().value();
+ gandiva::Decimal128Full literal(value, node.decimalnode().precision(),
+ node.decimalnode().scale());
+ return TreeExprBuilder::MakeDecimalLiteral(literal);
+ }
std::cerr << "Unknown node type in protobuf\n";
return nullptr;
}
diff --git a/cpp/src/gandiva/literal_holder.h b/cpp/src/gandiva/literal_holder.h
index 0a65ea2..ad6afce 100644
--- a/cpp/src/gandiva/literal_holder.h
+++ b/cpp/src/gandiva/literal_holder.h
@@ -22,11 +22,14 @@
#include <boost/variant.hpp>
+#include <arrow/type.h>
+#include "gandiva/decimal_full.h"
+
namespace gandiva {
+// Value of a literal expression node, one alternative per supported literal type.
+// Decimal128Full carries a decimal literal together with its precision and scale.
using LiteralHolder =
boost::variant<bool, float, double, int8_t, int16_t, int32_t, int64_t, uint8_t,
- uint16_t, uint32_t, uint64_t, std::string>;
+ uint16_t, uint32_t, uint64_t, std::string, Decimal128Full>;
} // namespace gandiva
diff --git a/cpp/src/gandiva/llvm_generator.cc b/cpp/src/gandiva/llvm_generator.cc
index 50f147b..9ddbe93 100644
--- a/cpp/src/gandiva/llvm_generator.cc
+++ b/cpp/src/gandiva/llvm_generator.cc
@@ -399,6 +399,17 @@ llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name,
return value;
}
+std::shared_ptr<DecimalLValue> LLVMGenerator::BuildDecimalLValue(llvm::Value* value,
+                                                                 DataTypePtr arrow_type) {
+  // Only 128-bit decimals are supported.
+  DCHECK(is_decimal_128(arrow_type));
+  const auto* dec_type =
+      arrow::internal::checked_cast<arrow::DecimalType*>(arrow_type.get());
+
+  // Precision and scale are fixed by the arrow type, so emit them as i32 constants.
+  llvm::Value* precision_ir = types()->i32_constant(dec_type->precision());
+  llvm::Value* scale_ir = types()->i32_constant(dec_type->scale());
+
+  // Second slot is passed as nullptr (presumably the length, unused for
+  // fixed-width decimals).
+  return std::make_shared<DecimalLValue>(value, nullptr, precision_ir, scale_ir);
+}
+
#define ADD_VISITOR_TRACE(...) \
if (generator_->enable_ir_traces_) { \
generator_->AddTrace(__VA_ARGS__); \
@@ -422,20 +433,33 @@ LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* functi
void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) {
llvm::IRBuilder<>* builder = ir_builder();
-
llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field());
-
llvm::Value* slot_value;
- if (dex.FieldType()->id() == arrow::Type::BOOL) {
- slot_value = generator_->GetPackedBitValue(slot_ref, loop_var_);
- } else {
- llvm::Value* slot_offset = builder->CreateGEP(slot_ref, loop_var_);
- slot_value = builder->CreateLoad(slot_offset, dex.FieldName());
- }
+ std::shared_ptr<LValue> lvalue;
+
+ switch (dex.FieldType()->id()) {
+ case arrow::Type::BOOL:
+ slot_value = generator_->GetPackedBitValue(slot_ref, loop_var_);
+ lvalue = std::make_shared<LValue>(slot_value);
+ break;
+ case arrow::Type::DECIMAL: {
+ auto slot_offset = builder->CreateGEP(slot_ref, loop_var_);
+ slot_value = builder->CreateLoad(slot_offset, dex.FieldName());
+ lvalue = generator_->BuildDecimalLValue(slot_value, dex.FieldType());
+ break;
+ }
+
+ default: {
+ auto slot_offset = builder->CreateGEP(slot_ref, loop_var_);
+ slot_value = builder->CreateLoad(slot_offset, dex.FieldName());
+ lvalue = std::make_shared<LValue>(slot_value);
+ break;
+ }
+ }
ADD_VISITOR_TRACE("visit fixed-len data vector " + dex.FieldName() + " value %T",
slot_value);
- result_.reset(new LValue(slot_value));
+ result_ = lvalue;
}
void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) {
@@ -572,6 +596,19 @@ void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) {
value = types->i64_constant(boost::get<int64_t>(dex.holder()));
break;
+ case arrow::Type::DECIMAL: {
+ // build code for struct
+ auto decimal_value = boost::get<Decimal128Full>(dex.holder());
+ auto int_value =
+ llvm::ConstantInt::get(llvm::Type::getInt128Ty(*generator_->context()),
+ decimal_value.value().ToIntegerString(), 10);
+ auto type = arrow::decimal(decimal_value.precision(), decimal_value.scale());
+ auto lvalue = generator_->BuildDecimalLValue(int_value, type);
+ // set it as the l-value and return.
+ result_ = lvalue;
+ return;
+ }
+
default:
DCHECK(0);
}
@@ -589,13 +626,14 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
auto params = BuildParams(dex.function_holder().get(), dex.args(), false,
native_function->NeedsContext());
+ auto arrow_return_type = dex.func_descriptor()->return_type();
if (native_function->CanReturnErrors()) {
// slow path : if a function can return errors, skip invoking the function
// unless all of the input args are valid. Otherwise, it can cause spurious errors.
llvm::IRBuilder<>* builder = ir_builder();
LLVMTypes* types = generator_->types();
- auto arrow_type_id = native_function->signature().ret_type()->id();
+ auto arrow_type_id = arrow_return_type->id();
auto result_type = types->IRType(arrow_type_id);
// Build combined validity of the args.
@@ -609,7 +647,7 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
auto then_lambda = [&] {
ADD_VISITOR_TRACE("fn " + function_name +
" can return errors : all args valid, invoke fn");
- return BuildFunctionCall(native_function, ¶ms);
+ return BuildFunctionCall(native_function, arrow_return_type, ¶ms);
};
// else block
@@ -624,10 +662,10 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
return std::make_shared<LValue>(else_value, else_value_len);
};
- result_ = BuildIfElse(is_valid, then_lambda, else_lambda, result_type);
+ result_ = BuildIfElse(is_valid, then_lambda, else_lambda, arrow_return_type);
} else {
// fast path : invoke function without computing validities.
- result_ = BuildFunctionCall(native_function, ¶ms);
+ result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms);
}
}
@@ -639,7 +677,8 @@ void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex& dex) {
auto params = BuildParams(dex.function_holder().get(), dex.args(), true,
native_function->NeedsContext());
- result_ = BuildFunctionCall(native_function, ¶ms);
+ auto arrow_return_type = dex.func_descriptor()->return_type();
+ result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms);
}
void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) {
@@ -659,7 +698,8 @@ void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) {
new llvm::AllocaInst(types->i8_type(), 0, "result_valid", entry_block_);
params.push_back(result_valid_ptr);
- result_ = BuildFunctionCall(native_function, ¶ms);
+ auto arrow_return_type = dex.func_descriptor()->return_type();
+ result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms);
// load the result validity and truncate to i1.
llvm::Value* result_valid_i8 = builder->CreateLoad(result_valid_ptr);
@@ -672,7 +712,6 @@ void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) {
void LLVMGenerator::Visitor::Visit(const IfDex& dex) {
ADD_VISITOR_TRACE("visit IfExpression");
llvm::IRBuilder<>* builder = ir_builder();
- LLVMTypes* types = generator_->types();
// Evaluate condition.
LValuePtr if_condition = BuildValueAndValidity(dex.condition_vv());
@@ -714,9 +753,8 @@ void LLVMGenerator::Visitor::Visit(const IfDex& dex) {
};
// build the if-else condition.
- auto result_type = types->IRType(dex.result_type()->id());
- result_ = BuildIfElse(validAndMatched, then_lambda, else_lambda, result_type);
- if (result_type == types->i8_ptr_type()) {
+ result_ = BuildIfElse(validAndMatched, then_lambda, else_lambda, dex.result_type());
+ if (arrow::is_binary_like(dex.result_type()->id())) {
ADD_VISITOR_TRACE("IfElse result length %T", result_->length());
}
ADD_VISITOR_TRACE("IfElse result value %T", result_->data());
@@ -906,7 +944,7 @@ void LLVMGenerator::Visitor::VisitInExpression(const InExprDexBase<Type>& dex) {
LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition,
std::function<LValuePtr()> then_func,
std::function<LValuePtr()> else_func,
- llvm::Type* result_type) {
+ DataTypePtr result_type) {
llvm::IRBuilder<>* builder = ir_builder();
llvm::LLVMContext* context = generator_->context();
LLVMTypes* types = generator_->types();
@@ -936,17 +974,31 @@ LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition,
// Emit the merge block.
builder->SetInsertPoint(merge_bb);
- llvm::PHINode* result_value = builder->CreatePHI(result_type, 2, "res_value");
+ auto llvm_type = types->IRType(result_type->id());
+ llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value");
result_value->addIncoming(then_lvalue->data(), then_bb);
result_value->addIncoming(else_lvalue->data(), else_bb);
- llvm::PHINode* result_length = nullptr;
- if (result_type == types->i8_ptr_type()) {
- result_length = builder->CreatePHI(types->i32_type(), 2, "res_length");
- result_length->addIncoming(then_lvalue->length(), then_bb);
- result_length->addIncoming(else_lvalue->length(), else_bb);
+ LValuePtr ret;
+ switch (result_type->id()) {
+ case arrow::Type::STRING: {
+ llvm::PHINode* result_length;
+ result_length = builder->CreatePHI(types->i32_type(), 2, "res_length");
+ result_length->addIncoming(then_lvalue->length(), then_bb);
+ result_length->addIncoming(else_lvalue->length(), else_bb);
+ ret = std::make_shared<LValue>(result_value, result_length);
+ break;
+ }
+
+ case arrow::Type::DECIMAL:
+ ret = generator_->BuildDecimalLValue(result_value, result_type);
+ break;
+
+ default:
+ ret = std::make_shared<LValue>(result_value);
+ break;
}
- return std::make_shared<LValue>(result_value, result_length);
+ return ret;
}
LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair& pair) {
@@ -963,25 +1015,46 @@ LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair&
}
LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,
+ DataTypePtr arrow_return_type,
std::vector<llvm::Value*>* params) {
- auto arrow_return_type = func->signature().ret_type()->id();
- auto llvm_return_type = generator_->types()->IRType(arrow_return_type);
-
- // add extra arg for return length for variable len return types (alloced on stack).
- llvm::AllocaInst* result_len_ptr = nullptr;
- if (arrow::is_binary_like(arrow_return_type)) {
- result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0,
- "result_len", entry_block_);
- params->push_back(result_len_ptr);
- has_arena_allocs_ = true;
- }
+ auto types = generator_->types();
+ auto arrow_return_type_id = arrow_return_type->id();
+ auto llvm_return_type = types->IRType(arrow_return_type_id);
+
+ if (arrow_return_type_id == arrow::Type::DECIMAL) {
+ // For decimal fns, the output precision/scale are passed along as parameters.
+ //
+ // convert from this :
+ // out = add_decimal(v1, p1, s1, v2, p2, s2)
+ // to:
+ // out = add_decimal(v1, p1, s1, v2, p2, s2, out_p, out_s)
+
+ // Append the out_precision and out_scale
+ auto ret_lvalue = generator_->BuildDecimalLValue(nullptr, arrow_return_type);
+ params->push_back(ret_lvalue->precision());
+ params->push_back(ret_lvalue->scale());
+
+ // Make the function call
+ auto out = generator_->AddFunctionCall(func->pc_name(), llvm_return_type, *params);
+ ret_lvalue->set_data(out);
+ return ret_lvalue;
+ } else {
+ // add extra arg for return length for variable len return types (alloced on stack).
+ llvm::AllocaInst* result_len_ptr = nullptr;
+ if (arrow::is_binary_like(arrow_return_type_id)) {
+ result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0,
+ "result_len", entry_block_);
+ params->push_back(result_len_ptr);
+ has_arena_allocs_ = true;
+ }
- // Make the function call
- llvm::IRBuilder<>* builder = ir_builder();
- auto value = generator_->AddFunctionCall(func->pc_name(), llvm_return_type, *params);
- auto value_len =
- (result_len_ptr == nullptr) ? nullptr : builder->CreateLoad(result_len_ptr);
- return std::make_shared<LValue>(value, value_len);
+ // Make the function call
+ llvm::IRBuilder<>* builder = ir_builder();
+ auto value = generator_->AddFunctionCall(func->pc_name(), llvm_return_type, *params);
+ auto value_len =
+ (result_len_ptr == nullptr) ? nullptr : builder->CreateLoad(result_len_ptr);
+ return std::make_shared<LValue>(value, value_len);
+ }
}
std::vector<llvm::Value*> LLVMGenerator::Visitor::BuildParams(
@@ -1007,12 +1080,9 @@ std::vector<llvm::Value*> LLVMGenerator::Visitor::BuildParams(
DexPtr value_expr = pair->value_expr();
value_expr->Accept(*this);
LValue& result_ref = *result();
- params.push_back(result_ref.data());
- // build length (for var len data types)
- if (result_ref.length() != nullptr) {
- params.push_back(result_ref.length());
- }
+ // append all the parameters corresponding to this LValue.
+ result_ref.AppendFunctionParams(¶ms);
// build validity.
if (with_validity) {
diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h
index 49f209d..937e5ac 100644
--- a/cpp/src/gandiva/llvm_generator.h
+++ b/cpp/src/gandiva/llvm_generator.h
@@ -119,12 +119,13 @@ class LLVMGenerator {
bool with_validity, bool with_context);
// Generate code to onvoke a function call.
- LValuePtr BuildFunctionCall(const NativeFunction* func,
+ LValuePtr BuildFunctionCall(const NativeFunction* func, DataTypePtr arrow_return_type,
std::vector<llvm::Value*>* params);
// Generate code for an if-else condition.
LValuePtr BuildIfElse(llvm::Value* condition, std::function<LValuePtr()> then_func,
- std::function<LValuePtr()> else_func, llvm::Type* result_type);
+ std::function<LValuePtr()> else_func,
+ DataTypePtr arrow_return_type);
// Switch to the entry_block and get reference of the validity/value/offsets buffer
llvm::Value* GetBufferReference(int idx, BufferType buffer_type, FieldPtr field);
@@ -184,6 +185,10 @@ class LLVMGenerator {
void ClearPackedBitValueIfFalse(llvm::Value* bitmap, llvm::Value* position,
llvm::Value* value);
+ // Generate code to build a DecimalLValue with specified value/precision/scale.
+ std::shared_ptr<DecimalLValue> BuildDecimalLValue(llvm::Value* value,
+ DataTypePtr arrow_type);
+
/// Generate code to make a function call (to a pre-compiled IR function) which takes
/// 'args' and has a return type 'ret_type'.
llvm::Value* AddFunctionCall(const std::string& full_name, llvm::Type* ret_type,
diff --git a/cpp/src/gandiva/llvm_types.cc b/cpp/src/gandiva/llvm_types.cc
index 0b89d96..18ff627 100644
--- a/cpp/src/gandiva/llvm_types.cc
+++ b/cpp/src/gandiva/llvm_types.cc
@@ -40,6 +40,7 @@ LLVMTypes::LLVMTypes(llvm::LLVMContext& context) : context_(context) {
{arrow::Type::type::TIMESTAMP, i64_type()},
{arrow::Type::type::STRING, i8_ptr_type()},
{arrow::Type::type::BINARY, i8_ptr_type()},
+ {arrow::Type::type::DECIMAL, i128_type()},
};
}
diff --git a/cpp/src/gandiva/llvm_types.h b/cpp/src/gandiva/llvm_types.h
index dab47d0..9cf4dd5 100644
--- a/cpp/src/gandiva/llvm_types.h
+++ b/cpp/src/gandiva/llvm_types.h
@@ -43,6 +43,8 @@ class LLVMTypes {
llvm::Type* i64_type() { return llvm::Type::getInt64Ty(context_); }
+ llvm::Type* i128_type() { return llvm::Type::getInt128Ty(context_); }
+
llvm::Type* float_type() { return llvm::Type::getFloatTy(context_); }
llvm::Type* double_type() { return llvm::Type::getDoubleTy(context_); }
@@ -53,12 +55,19 @@ class LLVMTypes {
llvm::PointerType* i64_ptr_type() { return llvm::PointerType::get(i64_type(), 0); }
- llvm::PointerType* ptr_type(llvm::Type* base_type) {
- return llvm::PointerType::get(base_type, 0);
+ llvm::PointerType* i128_ptr_type() { return llvm::PointerType::get(i128_type(), 0); }
+
+ llvm::StructType* i128_split_type() {
+ // struct with high/low bits (see decimal_wrapper.cc:DecimalSplit)
+ return llvm::StructType::get(context_, {i64_type(), i64_type()}, false);
}
llvm::Type* void_type() { return llvm::Type::getVoidTy(context_); }
+ llvm::PointerType* ptr_type(llvm::Type* base_type) {
+ return llvm::PointerType::get(base_type, 0);
+ }
+
llvm::Constant* true_constant() {
return llvm::ConstantInt::get(context_, llvm::APInt(1, 1));
}
@@ -87,6 +96,18 @@ class LLVMTypes {
return llvm::ConstantInt::get(context_, llvm::APInt(64, val));
}
+ llvm::Constant* i128_constant(int64_t val) {
+ return llvm::ConstantInt::get(context_, llvm::APInt(128, val));
+ }
+
+ llvm::Constant* i128_zero() {
+ return llvm::ConstantInt::get(context_, llvm::APInt(128, 0));
+ }
+
+ llvm::Constant* i128_one() {
+ return llvm::ConstantInt::get(context_, llvm::APInt(128, 1));
+ }
+
llvm::Constant* float_constant(float val) {
return llvm::ConstantFP::get(float_type(), val);
}
diff --git a/cpp/src/gandiva/lvalue.h b/cpp/src/gandiva/lvalue.h
index 2ff03dc..ce5040f 100644
--- a/cpp/src/gandiva/lvalue.h
+++ b/cpp/src/gandiva/lvalue.h
@@ -18,9 +18,11 @@
#ifndef GANDIVA_LVALUE_H
#define GANDIVA_LVALUE_H
-#include "arrow/util/macros.h"
+#include <vector>
#include <llvm/IR/IRBuilder.h>
+#include "arrow/util/macros.h"
+#include "gandiva/logging.h"
namespace gandiva {
@@ -30,17 +32,48 @@ class LValue {
explicit LValue(llvm::Value* data, llvm::Value* length = NULLPTR,
llvm::Value* validity = NULLPTR)
: data_(data), length_(length), validity_(validity) {}
+ virtual ~LValue() = default;
llvm::Value* data() { return data_; }
llvm::Value* length() { return length_; }
llvm::Value* validity() { return validity_; }
+ void set_data(llvm::Value* data) { data_ = data; }
+
+ // Append the params required when passing this as a function parameter.
+ virtual void AppendFunctionParams(std::vector<llvm::Value*>* params) {
+ params->push_back(data_);
+ if (length_ != NULLPTR) {
+ params->push_back(length_);
+ }
+ }
+
private:
llvm::Value* data_;
llvm::Value* length_;
llvm::Value* validity_;
};
+class DecimalLValue : public LValue {
+ public:
+ DecimalLValue(llvm::Value* data, llvm::Value* validity, llvm::Value* precision,
+ llvm::Value* scale)
+ : LValue(data, NULLPTR, validity), precision_(precision), scale_(scale) {}
+
+ llvm::Value* precision() { return precision_; }
+ llvm::Value* scale() { return scale_; }
+
+ void AppendFunctionParams(std::vector<llvm::Value*>* params) override {
+ LValue::AppendFunctionParams(params);
+ params->push_back(precision_);
+ params->push_back(scale_);
+ }
+
+ private:
+ llvm::Value* precision_;
+ llvm::Value* scale_;
+};
+
} // namespace gandiva
#endif // GANDIVA_LVALUE_H
diff --git a/cpp/src/gandiva/precompiled/CMakeLists.txt b/cpp/src/gandiva/precompiled/CMakeLists.txt
index 21a74bd..eab0b90 100644
--- a/cpp/src/gandiva/precompiled/CMakeLists.txt
+++ b/cpp/src/gandiva/precompiled/CMakeLists.txt
@@ -20,12 +20,16 @@ project(gandiva)
set(PRECOMPILED_SRCS
arithmetic_ops.cc
bitmap.cc
+ decimal_ops.cc
+ decimal_wrapper.cc
extended_math_ops.cc
hash.cc
print.cc
string_ops.cc
time.cc
- timestamp_arithmetic.cc)
+ timestamp_arithmetic.cc
+ ../../arrow/status.cc
+ ../../arrow/util/decimal.cc)
# Create bitcode for each of the source files.
foreach(SRC_FILE ${PRECOMPILED_SRCS})
@@ -35,7 +39,10 @@ foreach(SRC_FILE ${PRECOMPILED_SRCS})
add_custom_command(
OUTPUT ${BC_FILE}
COMMAND ${CLANG_EXECUTABLE}
- -std=c++11 -emit-llvm -O2 -c ${ABSOLUTE_SRC} -o ${BC_FILE}
+ -std=c++11 -emit-llvm
+ -DNDEBUG # DCHECK macros not implemented in precompiled code
+ -fno-use-cxa-atexit # Workaround for unresolved __dso_handle
+ -O3 -c ${ABSOLUTE_SRC} -o ${BC_FILE}
-I${CMAKE_SOURCE_DIR}/src
DEPENDS ${SRC_FILE})
list(APPEND BC_FILES ${BC_FILE})
@@ -77,4 +84,5 @@ if (ARROW_BUILD_TESTS)
add_precompiled_unit_test(string_ops_test.cc string_ops.cc ../context_helper.cc)
add_precompiled_unit_test(arithmetic_ops_test.cc arithmetic_ops.cc ../context_helper.cc)
add_precompiled_unit_test(extended_math_ops_test.cc extended_math_ops.cc ../context_helper.cc)
+ add_precompiled_unit_test(decimal_ops_test.cc decimal_ops.cc ../decimal_type_util.cc)
endif()
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.cc b/cpp/src/gandiva/precompiled/decimal_ops.cc
new file mode 100644
index 0000000..57cb83e
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/decimal_ops.cc
@@ -0,0 +1,219 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Algorithms adapted from Apache Impala
+
+#include "gandiva/precompiled/decimal_ops.h"
+
+#include <algorithm>
+
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/logging.h"
+
+namespace gandiva {
+namespace decimalops {
+
+static Decimal128 CheckAndIncreaseScale(Decimal128 in, int32_t delta) {
+ return (delta <= 0) ? in : in.IncreaseScaleBy(delta);
+}
+
+static Decimal128 CheckAndReduceScale(Decimal128 in, int32_t delta) {
+ return (delta <= 0) ? in : in.ReduceScaleBy(delta);
+}
+
+/// Adjust x and y to the same scale, and add them.
+static Decimal128 AddFastPath(const Decimal128Full& x, const Decimal128Full& y,
+ int32_t out_scale) {
+ auto higher_scale = std::max(x.scale(), y.scale());
+
+ auto x_scaled = CheckAndIncreaseScale(x.value(), higher_scale - x.scale());
+ auto y_scaled = CheckAndIncreaseScale(y.value(), higher_scale - y.scale());
+ return x_scaled + y_scaled;
+}
+
+/// Add x and y, caller has ensured there can be no overflow.
+static Decimal128 AddNoOverflow(const Decimal128Full& x, const Decimal128Full& y,
+ int32_t out_scale) {
+ auto higher_scale = std::max(x.scale(), y.scale());
+ auto sum = AddFastPath(x, y, out_scale);
+ return CheckAndReduceScale(sum, higher_scale - out_scale);
+}
+
+/// Both x_value and y_value must be >= 0
+static Decimal128 AddLargePositive(const Decimal128Full& x, const Decimal128Full& y,
+ int32_t out_scale) {
+ DCHECK_GE(x.value(), 0);
+ DCHECK_GE(y.value(), 0);
+
+ // separate out whole/fractions.
+ Decimal128 x_left, x_right, y_left, y_right;
+ x.value().GetWholeAndFraction(x.scale(), &x_left, &x_right);
+ y.value().GetWholeAndFraction(y.scale(), &y_left, &y_right);
+
+ // Adjust fractional parts to higher scale.
+ auto higher_scale = std::max(x.scale(), y.scale());
+ auto x_right_scaled = CheckAndIncreaseScale(x_right, higher_scale - x.scale());
+ auto y_right_scaled = CheckAndIncreaseScale(y_right, higher_scale - y.scale());
+
+ Decimal128 right;
+ Decimal128 carry_to_left;
+ auto multiplier = Decimal128::GetScaleMultiplier(higher_scale);
+ if (x_right_scaled >= multiplier - y_right_scaled) {
+ right = x_right_scaled - (multiplier - y_right_scaled);
+ carry_to_left = 1;
+ } else {
+ right = x_right_scaled + y_right_scaled;
+ carry_to_left = 0;
+ }
+ right = CheckAndReduceScale(right, higher_scale - out_scale);
+
+ auto left = x_left + y_left + carry_to_left;
+ return (left * Decimal128::GetScaleMultiplier(out_scale)) + right;
+}
+
+/// x_value and y_value cannot be 0, and one must be positive and the other negative.
+static Decimal128 AddLargeNegative(const Decimal128Full& x, const Decimal128Full& y,
+ int32_t out_scale) {
+ DCHECK_NE(x.value(), 0);
+ DCHECK_NE(y.value(), 0);
+ DCHECK((x.value() < 0 && y.value() > 0) || (x.value() > 0 && y.value() < 0));
+
+ // separate out whole/fractions.
+ Decimal128 x_left, x_right, y_left, y_right;
+ x.value().GetWholeAndFraction(x.scale(), &x_left, &x_right);
+ y.value().GetWholeAndFraction(y.scale(), &y_left, &y_right);
+
+ // Adjust fractional parts to higher scale.
+ auto higher_scale = std::max(x.scale(), y.scale());
+ x_right = CheckAndIncreaseScale(x_right, higher_scale - x.scale());
+ y_right = CheckAndIncreaseScale(y_right, higher_scale - y.scale());
+
+ // Overflow not possible because one is +ve and the other is -ve.
+ auto left = x_left + y_left;
+ auto right = x_right + y_right;
+
+ // If the whole and fractional parts have different signs, then we need to make the
+ // fractional part have the same sign as the whole part. If either left or right is
+ // zero, then nothing needs to be done.
+ if (left < 0 && right > 0) {
+ left += 1;
+ right -= Decimal128::GetScaleMultiplier(higher_scale);
+ } else if (left > 0 && right < 0) {
+ left -= 1;
+ right += Decimal128::GetScaleMultiplier(higher_scale);
+ }
+ right = CheckAndReduceScale(right, higher_scale - out_scale);
+ return (left * Decimal128::GetScaleMultiplier(out_scale)) + right;
+}
+
+static Decimal128 AddLarge(const Decimal128Full& x, const Decimal128Full& y,
+ int32_t out_scale) {
+ if (x.value() >= 0 && y.value() >= 0) {
+ // both positive or 0
+ return AddLargePositive(x, y, out_scale);
+ } else if (x.value() <= 0 && y.value() <= 0) {
+ // both negative or 0
+ Decimal128Full x_neg(-x.value(), x.precision(), x.scale());
+ Decimal128Full y_neg(-y.value(), y.precision(), y.scale());
+ return -AddLargePositive(x_neg, y_neg, out_scale);
+ } else {
+ // one positive and the other negative
+ return AddLargeNegative(x, y, out_scale);
+ }
+}
+
+// Suppose we have a number that requires x bits to be represented and we scale it up by
+// 10^scale_by. Let's say now y bits are required to represent it. This function returns
+// the maximum possible y - x for a given 'scale_by'.
+inline int32_t MaxBitsRequiredIncreaseAfterScaling(int32_t scale_by) {
+ // We rely on the following formula:
+ // bits_required(x * 10^y) <= bits_required(x) + floor(log2(10^y)) + 1
+ // We precompute floor(log2(10^x)) + 1 for x = 0, 1, 2...75, 76
+ DCHECK_GE(scale_by, 0);
+ DCHECK_LE(scale_by, 76);
+ static const int32_t floor_log2_plus_one[] = {
+ 0, 4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40, 44, 47, 50,
+ 54, 57, 60, 64, 67, 70, 74, 77, 80, 84, 87, 90, 94, 97, 100, 103,
+ 107, 110, 113, 117, 120, 123, 127, 130, 133, 137, 140, 143, 147, 150, 153, 157,
+ 160, 163, 167, 170, 173, 177, 180, 183, 187, 190, 193, 196, 200, 203, 206, 210,
+ 213, 216, 220, 223, 226, 230, 233, 236, 240, 243, 246, 250, 253};
+ return floor_log2_plus_one[scale_by];
+}
+
+// If we have a number with 'num_lz' leading zeros, and we scale it up by 10^scale_by,
+// this function returns the minimum number of leading zeros the result can have.
+inline int32_t MinLeadingZerosAfterScaling(int32_t num_lz, int32_t scale_by) {
+ DCHECK_GE(scale_by, 0);
+ DCHECK_LE(scale_by, 76);
+ int32_t result = num_lz - MaxBitsRequiredIncreaseAfterScaling(scale_by);
+ return result;
+}
+
+// Returns the maximum possible number of bits required to represent num * 10^scale_by.
+inline int32_t MaxBitsRequiredAfterScaling(const Decimal128Full& num, int32_t scale_by) {
+ auto value = num.value();
+ auto value_abs = value.Abs();
+
+ int32_t num_occupied = 128 - value_abs.CountLeadingBinaryZeros();
+ DCHECK_GE(scale_by, 0);
+ DCHECK_LE(scale_by, 76);
+ return num_occupied + MaxBitsRequiredIncreaseAfterScaling(scale_by);
+}
+
+// Returns the minimum number of leading zeros x or y would have after one of them gets
+// scaled up to match the scale of the other one.
+inline int32_t MinLeadingZeros(const Decimal128Full& x, const Decimal128Full& y) {
+ auto x_value = x.value();
+ auto x_value_abs = x_value.Abs();
+
+ auto y_value = y.value();
+ auto y_value_abs = y_value.Abs();
+
+ int32_t x_lz = x_value_abs.CountLeadingBinaryZeros();
+ int32_t y_lz = y_value_abs.CountLeadingBinaryZeros();
+ if (x.scale() < y.scale()) {
+ x_lz = MinLeadingZerosAfterScaling(x_lz, y.scale() - x.scale());
+ } else if (x.scale() > y.scale()) {
+ y_lz = MinLeadingZerosAfterScaling(y_lz, x.scale() - y.scale());
+ }
+ return std::min(x_lz, y_lz);
+}
+
+Decimal128 Add(const Decimal128Full& x, const Decimal128Full& y, int32_t out_precision,
+ int32_t out_scale) {
+ if (out_precision < DecimalTypeUtil::kMaxPrecision) {
+ // fast-path add : out_precision < kMaxPrecision, so align scales and add directly.
+ return AddFastPath(x, y, out_scale);
+ } else {
+ int32_t min_lz = MinLeadingZeros(x, y);
+ if (min_lz >= 3) {
+ // If both numbers have at least 3 leading zeros (after aligning their scales, as
+ // computed by MinLeadingZeros), we can add them directly without risk of overflow.
+ // We want the result to have at least 2 leading zeros, which ensures that it fits
+ // into the maximum decimal because 2^126 - 1 < 10^38 - 1. If both x and y have at
+ // least 3 leading zeros, then we are guaranteed that the result will have at least 2
+ // leading zeros.
+ return AddNoOverflow(x, y, out_scale);
+ } else {
+ // slow path : overflow is possible, so add whole/fraction parts separately and combine.
+ return AddLarge(x, y, out_scale);
+ }
+ }
+}
+
+} // namespace decimalops
+} // namespace gandiva
diff --git a/cpp/src/gandiva/literal_holder.h b/cpp/src/gandiva/precompiled/decimal_ops.h
similarity index 65%
copy from cpp/src/gandiva/literal_holder.h
copy to cpp/src/gandiva/precompiled/decimal_ops.h
index 0a65ea2..25f094e 100644
--- a/cpp/src/gandiva/literal_holder.h
+++ b/cpp/src/gandiva/precompiled/decimal_ops.h
@@ -15,19 +15,23 @@
// specific language governing permissions and limitations
// under the License.
-#ifndef GANDIVA_LITERAL_HOLDER
-#define GANDIVA_LITERAL_HOLDER
+#ifndef DECIMAL_SQL_H
+#define DECIMAL_SQL_H
+#include <cstdint>
#include <string>
-
-#include <boost/variant.hpp>
+#include "gandiva/decimal_full.h"
namespace gandiva {
+namespace decimalops {
-using LiteralHolder =
- boost::variant<bool, float, double, int8_t, int16_t, int32_t, int64_t, uint8_t,
- uint16_t, uint32_t, uint64_t, std::string>;
+/// Return the sum of 'x' and 'y'.
+/// out_precision and out_scale are passed along for efficiency; they must match
+/// the rules in DecimalTypeUtil::GetResultType.
+Decimal128 Add(const Decimal128Full& x, const Decimal128Full& y, int32_t out_precision,
+ int32_t out_scale);
+} // namespace decimalops
} // namespace gandiva
-#endif // GANDIVA_LITERAL_HOLDER
+#endif // DECIMAL_SQL_H
diff --git a/cpp/src/gandiva/precompiled/decimal_ops_test.cc b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
new file mode 100644
index 0000000..7daf734
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
@@ -0,0 +1,75 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <memory>
+
+#include "arrow/test-util.h"
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/precompiled/decimal_ops.h"
+#include "gandiva/precompiled/types.h"
+
+namespace gandiva {
+
+class TestDecimalSql : public ::testing::Test {
+ protected:
+ static void AddAndVerify(const Decimal128Full& x, const Decimal128Full& y,
+ const Decimal128Full& expected);
+};
+
+#define EXPECT_DECIMAL_EQ(x, y, expected, actual) \
+ EXPECT_EQ(expected, actual) << (x).ToString() << " + " << (y).ToString() \
+ << " expected : " << expected.ToString() << " actual " \
+ << actual.ToString()
+
+void TestDecimalSql::AddAndVerify(const Decimal128Full& x, const Decimal128Full& y,
+ const Decimal128Full& expected) {
+ auto t1 = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
+ auto t2 = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
+
+ Decimal128TypePtr out_type;
+ EXPECT_OK(DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, {t1, t2}, &out_type));
+
+ auto out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale());
+ EXPECT_DECIMAL_EQ(x, y, expected,
+ Decimal128Full(out_value, out_type->precision(), out_type->scale()));
+}
+
+TEST_F(TestDecimalSql, Add) {
+ // fast-path
+ AddAndVerify(Decimal128Full{"201", 30, 3}, // x
+ Decimal128Full{"301", 30, 3}, // y
+ Decimal128Full{"502", 31, 3}); // expected
+
+ // max precision
+ AddAndVerify(Decimal128Full{"09999999999999999999999999999999000000", 38, 5}, // x
+ Decimal128Full{"100", 38, 7}, // y
+ Decimal128Full{"99999999999999999999999999999990000010", 38, 6});
+
+ // Both -ve
+ AddAndVerify(Decimal128Full{"-201", 30, 3}, // x
+ Decimal128Full{"-301", 30, 2}, // y
+ Decimal128Full{"-3211", 32, 3}); // expected
+
+ // -ve and max precision
+ AddAndVerify(Decimal128Full{"-09999999999999999999999999999999000000", 38, 5}, // x
+ Decimal128Full{"-100", 38, 7}, // y
+ Decimal128Full{"-99999999999999999999999999999990000010", 38, 6});
+}
+
+} // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_wrapper.cc b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
new file mode 100644
index 0000000..fdc751f
--- /dev/null
+++ b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/precompiled/decimal_ops.h"
+#include "gandiva/precompiled/types.h"
+
+extern "C" {
+
+/// TODO : Passing around structs in IR can be fragile due to c-abi compatibility issues.
+/// This seems to work for now, but will need to revisit if we hit issues.
+struct DecimalSplit {
+ int64_t high_bits;
+ uint64_t low_bits;
+};
+
+FORCE_INLINE
+DecimalSplit add_large_decimal128_decimal128(int64_t x_high, uint64_t x_low,
+ int32_t x_precision, int32_t x_scale,
+ int64_t y_high, uint64_t y_low,
+ int32_t y_precision, int32_t y_scale,
+ int32_t out_precision, int32_t out_scale) {
+ gandiva::Decimal128Full x(x_high, x_low, x_precision, x_scale);
+ gandiva::Decimal128Full y(y_high, y_low, y_precision, y_scale);
+
+ arrow::Decimal128 out = gandiva::decimalops::Add(x, y, out_precision, out_scale);
+ return DecimalSplit{out.high_bits(), out.low_bits()};
+}
+
+} // extern "C"
diff --git a/cpp/src/gandiva/projector.cc b/cpp/src/gandiva/projector.cc
index 4cb352f..8fc5b8c 100644
--- a/cpp/src/gandiva/projector.cc
+++ b/cpp/src/gandiva/projector.cc
@@ -143,7 +143,8 @@ Status Projector::Evaluate(const arrow::RecordBatch& batch, arrow::MemoryPool* p
// TODO : handle variable-len vectors
Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records,
arrow::MemoryPool* pool, ArrayDataPtr* array_data) {
- ARROW_RETURN_IF(!arrow::is_primitive(type->id()),
+ const auto* fw_type = dynamic_cast<const arrow::FixedWidthType*>(type.get());
+ ARROW_RETURN_IF(fw_type == nullptr,
Status::Invalid("Unsupported output data type ", type));
std::shared_ptr<arrow::Buffer> null_bitmap;
@@ -151,8 +152,7 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records,
ARROW_RETURN_NOT_OK(arrow::AllocateBuffer(pool, bitmap_bytes, &null_bitmap));
std::shared_ptr<arrow::Buffer> data;
- const auto& fw_type = dynamic_cast<const arrow::FixedWidthType&>(*type);
- int64_t data_len = arrow::BitUtil::BytesForBits(num_records * fw_type.bit_width());
+ int64_t data_len = arrow::BitUtil::BytesForBits(num_records * fw_type->bit_width());
ARROW_RETURN_NOT_OK(arrow::AllocateBuffer(pool, data_len, &data));
// This is not strictly required but valgrind gets confused and detects this
diff --git a/cpp/src/gandiva/proto/Types.proto b/cpp/src/gandiva/proto/Types.proto
index ac19d0f..7474065 100644
--- a/cpp/src/gandiva/proto/Types.proto
+++ b/cpp/src/gandiva/proto/Types.proto
@@ -146,6 +146,13 @@ message BinaryNode {
optional bytes value = 1;
}
+message DecimalNode {
+ optional string value = 1;
+ optional int32 precision = 2;
+ optional int32 scale = 3;
+}
+
+
message TreeNode {
optional FieldNode fieldNode = 1;
optional FunctionNode fnNode = 2;
@@ -164,6 +171,7 @@ message TreeNode {
optional DoubleNode doubleNode = 16;
optional StringNode stringNode = 17;
optional BinaryNode binaryNode = 18;
+ optional DecimalNode decimalNode = 19;
}
message ExpressionRoot {
diff --git a/cpp/src/gandiva/tests/CMakeLists.txt b/cpp/src/gandiva/tests/CMakeLists.txt
index 9558fc0..b47e5fd 100644
--- a/cpp/src/gandiva/tests/CMakeLists.txt
+++ b/cpp/src/gandiva/tests/CMakeLists.txt
@@ -27,11 +27,17 @@ ADD_GANDIVA_TEST(to_string_test)
ADD_GANDIVA_TEST(hash_test)
ADD_GANDIVA_TEST(in_expr_test)
ADD_GANDIVA_TEST(null_validity_test)
+ADD_GANDIVA_TEST(decimal_test)
+ADD_GANDIVA_TEST(decimal_single_test)
ADD_GANDIVA_TEST(projector_test_static
SOURCES projector_test.cc
USE_STATIC_LINKING)
-ADD_BENCHMARK(micro_benchmarks
+ADD_GANDIVA_TEST(decimal_single_test_static
+ SOURCES decimal_single_test.cc
+ USE_STATIC_LINKING)
+
+ADD_ARROW_BENCHMARK(micro_benchmarks
PREFIX "gandiva"
EXTRA_LINK_LIBS gandiva_static)
diff --git a/cpp/src/gandiva/tests/decimal_single_test.cc b/cpp/src/gandiva/tests/decimal_single_test.cc
new file mode 100644
index 0000000..728ccb7
--- /dev/null
+++ b/cpp/src/gandiva/tests/decimal_single_test.cc
@@ -0,0 +1,224 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+
+#include "gandiva/decimal_full.h"
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+using arrow::Decimal128;
+
+namespace gandiva {
+
+// Checks that `actual` equals `expected`; on failure, reports both operands of
+// the addition plus the expected and actual values (each via ToString()).
+#define EXPECT_DECIMAL_SUM_EQUALS(x, y, expected, actual) \
+ EXPECT_EQ(expected, actual) << (x).ToString() << " + " << (y).ToString() \
+ << " expected : " << (expected).ToString() \
+ << " actual : " << (actual).ToString();
+
+// Convenience wrapper: builds a Decimal128Full from a C-string of unscaled
+// digits together with the desired precision and scale.
+Decimal128Full decimal_literal(const char* value, int precision, int scale) {
+  std::string digits(value);
+  return Decimal128Full(digits, precision, scale);
+}
+
+// Fixture for single-row decimal arithmetic tests. Each check builds its own
+// projector and evaluates it using arrow's default memory pool.
+class TestDecimalOps : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ // Wraps `in` in a one-element, non-null Decimal128 array of matching type.
+ ArrayPtr MakeDecimalVector(const Decimal128Full& in);
+ // Evaluates `x + y` through a projector and compares it with `expected`.
+ void AddAndVerify(const Decimal128Full& x, const Decimal128Full& y,
+ const Decimal128Full& expected);
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+// Builds a one-element, non-null Decimal128 array holding `in`, typed with
+// the precision and scale carried by `in`.
+ArrayPtr TestDecimalOps::MakeDecimalVector(const Decimal128Full& in) {
+  const Decimal128 decimal_value = in.value();
+
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(in.precision(), in.scale());
+  return MakeArrowArrayDecimal(decimal_type, {decimal_value}, {true});
+}
+
+// Evaluates `x + y` through a Gandiva projector on single-row inputs and checks
+// the result against `expected`. ASSERT_* is used for every step whose failure
+// would otherwise leave a null projector or an empty output vector to be
+// dereferenced further down.
+void TestDecimalOps::AddAndVerify(const Decimal128Full& x, const Decimal128Full& y,
+                                  const Decimal128Full& expected) {
+  auto x_type = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
+  auto y_type = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
+  auto field_x = field("x", x_type);
+  auto field_y = field("y", y_type);
+  auto schema = arrow::schema({field_x, field_y});
+
+  // Result type of x + y according to DecimalTypeUtil's rules for addition.
+  Decimal128TypePtr output_type;
+  auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, {x_type, y_type},
+                                               &output_type);
+  ASSERT_OK(status);
+
+  // output fields
+  auto res = field("res", output_type);
+
+  // build expression : x + y
+  auto expr = TreeExprBuilder::MakeExpression("add", {field_x, field_y}, res);
+
+  // Build a projector for the expression.
+  std::shared_ptr<Projector> projector;
+  status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+  ASSERT_OK(status);
+
+  // Create a single-row batch from the two inputs.
+  auto array_a = MakeDecimalVector(x);
+  auto array_b = MakeDecimalVector(y);
+  auto in_batch = arrow::RecordBatch::Make(schema, 1 /*num_records*/, {array_a, array_b});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  ASSERT_OK(status);
+  ASSERT_EQ(1U, outputs.size());
+
+  // Validate results: rebuild a Decimal128Full from the unscaled output digits
+  // (ToString(0)) and the output array's own precision/scale.
+  auto out_array = dynamic_cast<arrow::Decimal128Array*>(outputs[0].get());
+  ASSERT_TRUE(out_array != nullptr) << "output is not a Decimal128Array";
+  const Decimal128 out_value(out_array->GetValue(0));
+
+  auto dtype = dynamic_cast<arrow::Decimal128Type*>(out_array->type().get());
+  ASSERT_TRUE(dtype != nullptr) << "output type is not decimal128";
+  std::string value_string = out_value.ToString(0);
+  Decimal128Full actual{value_string, dtype->precision(), dtype->scale()};
+
+  EXPECT_DECIMAL_SUM_EQUALS(x, y, expected, actual);
+}
+
+// Table-driven checks for decimal addition. Each AddAndVerify call takes x, y
+// and the expected sum, each written as (unscaled digits, precision, scale);
+// the expected precision/scale is the result type the projector should produce.
+TEST_F(TestDecimalOps, TestAdd) {
+ // fast-path
+ AddAndVerify(decimal_literal("201", 30, 3), // x
+ decimal_literal("301", 30, 3), // y
+ decimal_literal("502", 31, 3)); // expected
+
+ AddAndVerify(decimal_literal("201", 30, 3), // x
+ decimal_literal("301", 30, 2), // y
+ decimal_literal("3211", 32, 3)); // expected
+
+ AddAndVerify(decimal_literal("201", 30, 3), // x
+ decimal_literal("301", 30, 4), // y
+ decimal_literal("2311", 32, 4)); // expected
+
+ // max precision, but no overflow
+ AddAndVerify(decimal_literal("201", 38, 3), // x
+ decimal_literal("301", 38, 3), // y
+ decimal_literal("502", 38, 3)); // expected
+
+ AddAndVerify(decimal_literal("201", 38, 3), // x
+ decimal_literal("301", 38, 2), // y
+ decimal_literal("3211", 38, 3)); // expected
+
+ AddAndVerify(decimal_literal("201", 38, 3), // x
+ decimal_literal("301", 38, 4), // y
+ decimal_literal("2311", 38, 4)); // expected
+
+ AddAndVerify(decimal_literal("201", 38, 3), // x
+ decimal_literal("301", 38, 7), // y
+ decimal_literal("201030", 38, 6)); // expected
+
+ AddAndVerify(decimal_literal("1201", 38, 3), // x
+ decimal_literal("1801", 38, 3), // y
+ decimal_literal("3002", 38, 3)); // carry-over from fractional
+
+ // max precision
+ AddAndVerify(decimal_literal("09999999999999999999999999999999000000", 38, 5), // x
+ decimal_literal("100", 38, 7), // y
+ decimal_literal("99999999999999999999999999999990000010", 38, 6));
+
+ AddAndVerify(decimal_literal("-09999999999999999999999999999999000000", 38, 5), // x
+ decimal_literal("100", 38, 7), // y
+ decimal_literal("-99999999999999999999999999999989999990", 38, 6));
+
+ AddAndVerify(decimal_literal("09999999999999999999999999999999000000", 38, 5), // x
+ decimal_literal("-100", 38, 7), // y
+ decimal_literal("99999999999999999999999999999989999990", 38, 6));
+
+ AddAndVerify(decimal_literal("-09999999999999999999999999999999000000", 38, 5), // x
+ decimal_literal("-100", 38, 7), // y
+ decimal_literal("-99999999999999999999999999999990000010", 38, 6));
+
+ AddAndVerify(decimal_literal("09999999999999999999999999999999999999", 38, 6), // x
+ decimal_literal("89999999999999999999999999999999999999", 38, 7), // y
+ decimal_literal("18999999999999999999999999999999999999", 38, 6));
+
+ // Both -ve
+ AddAndVerify(decimal_literal("-201", 30, 3), // x
+ decimal_literal("-301", 30, 2), // y
+ decimal_literal("-3211", 32, 3)); // expected
+
+ AddAndVerify(decimal_literal("-201", 38, 3), // x
+ decimal_literal("-301", 38, 4), // y
+ decimal_literal("-2311", 38, 4)); // expected
+
+ // Mix of +ve and -ve
+ AddAndVerify(decimal_literal("-201", 30, 3), // x
+ decimal_literal("301", 30, 2), // y
+ decimal_literal("2809", 32, 3)); // expected
+
+ AddAndVerify(decimal_literal("-201", 38, 3), // x
+ decimal_literal("301", 38, 4), // y
+ decimal_literal("-1709", 38, 4)); // expected
+
+ AddAndVerify(decimal_literal("201", 38, 3), // x
+ decimal_literal("-301", 38, 7), // y
+ decimal_literal("200970", 38, 6)); // expected
+
+ AddAndVerify(decimal_literal("-1901", 38, 4), // x
+ decimal_literal("1801", 38, 4), // y
+ decimal_literal("-100", 38, 4)); // expected
+
+ AddAndVerify(decimal_literal("1801", 38, 4), // x
+ decimal_literal("-1901", 38, 4), // y
+ decimal_literal("-100", 38, 4)); // expected
+
+ // rounding +ve
+ AddAndVerify(decimal_literal("1000999", 38, 6), // x
+ decimal_literal("10000999", 38, 7), // y
+ decimal_literal("2001099", 38, 6));
+
+ AddAndVerify(decimal_literal("1000999", 38, 6), // x
+ decimal_literal("10000995", 38, 7), // y
+ decimal_literal("2001099", 38, 6));
+
+ AddAndVerify(decimal_literal("1000999", 38, 6), // x
+ decimal_literal("10000992", 38, 7), // y
+ decimal_literal("2001098", 38, 6));
+
+ // rounding -ve
+ AddAndVerify(decimal_literal("-1000999", 38, 6), // x
+ decimal_literal("-10000999", 38, 7), // y
+ decimal_literal("-2001099", 38, 6));
+
+ AddAndVerify(decimal_literal("-1000999", 38, 6), // x
+ decimal_literal("-10000995", 38, 7), // y
+ decimal_literal("-2001099", 38, 6));
+
+ AddAndVerify(decimal_literal("-1000999", 38, 6), // x
+ decimal_literal("-10000992", 38, 7), // y
+ decimal_literal("-2001098", 38, 6));
+}
+} // namespace gandiva
diff --git a/cpp/src/gandiva/tests/decimal_test.cc b/cpp/src/gandiva/tests/decimal_test.cc
new file mode 100644
index 0000000..f048fd2
--- /dev/null
+++ b/cpp/src/gandiva/tests/decimal_test.cc
@@ -0,0 +1,237 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <sstream>
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "arrow/status.h"
+#include "arrow/util/decimal.h"
+
+#include "gandiva/decimal_type_util.h"
+#include "gandiva/projector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+using arrow::Decimal128;
+
+namespace gandiva {
+
+// Fixture for multi-row decimal projector tests (add, literal, if-else),
+// evaluated using arrow's default memory pool.
+class TestDecimal : public ::testing::Test {
+ public:
+ void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ // Parses each string as a decimal and rescales it to `scale`.
+ std::vector<Decimal128> MakeDecimalVector(std::vector<std::string> values,
+ int32_t scale);
+
+ protected:
+ arrow::MemoryPool* pool_;
+};
+
+// Parses each string in `values` as a decimal and rescales it to `scale`,
+// returning the unscaled Decimal128 values (e.g. "1.6" at scale 18 becomes
+// 1600000000000000000). Both the parse and the rescale are checked with
+// EXPECT_OK so failures are reported in release builds too (DCHECK_OK is
+// compiled out there); previously the Rescale status was ignored altogether.
+std::vector<Decimal128> TestDecimal::MakeDecimalVector(std::vector<std::string> values,
+                                                       int32_t scale) {
+  std::vector<arrow::Decimal128> ret;
+  ret.reserve(values.size());
+  for (const auto& str : values) {
+    Decimal128 str_value;
+    int32_t str_precision;
+    int32_t str_scale;
+
+    auto status = Decimal128::FromString(str, &str_value, &str_precision, &str_scale);
+    EXPECT_OK(status);
+
+    Decimal128 scaled_value;
+    status = str_value.Rescale(str_scale, scale, &scaled_value);
+    EXPECT_OK(status);
+    ret.push_back(scaled_value);
+  }
+  return ret;
+}
+
+// Evaluates a + b + c over three decimal(36, 18) columns; row 0 is null
+// because a and b are null there.
+TEST_F(TestDecimal, TestSimple) {
+  // schema for input fields
+  constexpr int32_t precision = 36;
+  constexpr int32_t scale = 18;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+  auto field_a = field("a", decimal_type);
+  auto field_b = field("b", decimal_type);
+  auto field_c = field("c", decimal_type);
+  auto schema = arrow::schema({field_a, field_b, field_c});
+
+  // Result types of (a + b) and of ((a + b) + c). Statuses are asserted so a
+  // failure is reported instead of continuing with a null type.
+  Decimal128TypePtr add2_type;
+  auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd,
+                                               {decimal_type, decimal_type}, &add2_type);
+  ASSERT_OK(status);
+
+  Decimal128TypePtr output_type;
+  status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd,
+                                          {add2_type, decimal_type}, &output_type);
+  ASSERT_OK(status);
+
+  // output fields
+  auto res = field("res0", output_type);
+
+  // build expression : a + b + c
+  auto node_a = TreeExprBuilder::MakeField(field_a);
+  auto node_b = TreeExprBuilder::MakeField(field_b);
+  auto node_c = TreeExprBuilder::MakeField(field_c);
+  auto add2 = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, add2_type);
+  auto add3 = TreeExprBuilder::MakeFunction("add", {add2, node_c}, output_type);
+  auto expr = TreeExprBuilder::MakeExpression(add3, res);
+
+  // Build a projector for the expression. ASSERT_OK (not DCHECK_OK) so the
+  // check is active in release builds and stops before a null projector is used.
+  std::shared_ptr<Projector> projector;
+  status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 4;
+  auto array_a =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "4"}, scale),
+                            {false, true, true, true});
+  auto array_b =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"2", "3", "4", "5"}, scale),
+                            {false, true, true, true});
+  auto array_c =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"3", "4", "5", "6"}, scale),
+                            {true, true, true, true});
+
+  // prepare input record batch
+  auto in_batch =
+      arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c});
+
+  auto expected =
+      MakeArrowArrayDecimal(output_type, MakeDecimalVector({"6", "9", "12", "15"}, scale),
+                            {false, true, true, true});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(expected, outputs[0]);
+}
+
+// Evaluates a + 0.6 where 0.6 is a decimal literal (unscaled 6, precision 2,
+// scale 1) and a is a decimal(36, 18) column.
+TEST_F(TestDecimal, TestLiteral) {
+  // schema for input fields
+  constexpr int32_t precision = 36;
+  constexpr int32_t scale = 18;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+  auto field_a = field("a", decimal_type);
+  auto schema = arrow::schema({
+      field_a,
+  });
+
+  Decimal128TypePtr add2_type;
+  auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd,
+                                               {decimal_type, decimal_type}, &add2_type);
+  ASSERT_OK(status);
+
+  // output fields
+  auto res = field("res0", add2_type);
+
+  // build expression : a + 0.6
+  auto node_a = TreeExprBuilder::MakeField(field_a);
+  static std::string decimal_point_six = "6";
+  Decimal128Full literal(decimal_point_six, 2, 1);
+  auto node_b = TreeExprBuilder::MakeDecimalLiteral(literal);
+  auto add2 = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, add2_type);
+  auto expr = TreeExprBuilder::MakeExpression(add2, res);
+
+  // Build a projector for the expression.
+  std::shared_ptr<Projector> projector;
+  status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 4;
+  auto array_a =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "4"}, scale),
+                            {false, true, true, true});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+  auto expected = MakeArrowArrayDecimal(
+      add2_type, MakeDecimalVector({"1.6", "2.6", "3.6", "4.6"}, scale),
+      {false, true, true, true});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(expected, outputs[0]);
+}
+
+// Evaluates `if (c) a else b` over decimal columns a, b and boolean c; row 0
+// selects the null a value, so the expected output is null there.
+TEST_F(TestDecimal, TestIfElse) {
+  // schema for input fields
+  constexpr int32_t precision = 36;
+  constexpr int32_t scale = 18;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+  auto field_a = field("a", decimal_type);
+  auto field_b = field("b", decimal_type);
+  auto field_c = field("c", arrow::boolean());
+  auto schema = arrow::schema({field_a, field_b, field_c});
+
+  // output fields
+  auto field_result = field("res", decimal_type);
+
+  // build expression.
+  // if (c)
+  //   a
+  // else
+  //   b
+  auto node_a = TreeExprBuilder::MakeField(field_a);
+  auto node_b = TreeExprBuilder::MakeField(field_b);
+  auto node_c = TreeExprBuilder::MakeField(field_c);
+  auto if_node = TreeExprBuilder::MakeIf(node_c, node_a, node_b, decimal_type);
+
+  auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+
+  // Build a projector for the expressions. ASSERT_OK (not DCHECK_OK) so the
+  // check also runs in release builds.
+  std::shared_ptr<Projector> projector;
+  Status status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 4;
+  auto array_a =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "4"}, scale),
+                            {false, true, true, true});
+  auto array_b =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"2", "3", "4", "5"}, scale),
+                            {true, true, true, true});
+
+  auto array_c = MakeArrowArrayBool({true, false, true, false}, {true, true, true, true});
+
+  // expected output
+  auto exp =
+      MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"0", "3", "3", "5"}, scale),
+                            {false, true, true, true});
+
+  // prepare input record batch
+  auto in_batch =
+      arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
+} // namespace gandiva
diff --git a/cpp/src/gandiva/tests/generate_data.h b/cpp/src/gandiva/tests/generate_data.h
index 01665b8..3980575 100644
--- a/cpp/src/gandiva/tests/generate_data.h
+++ b/cpp/src/gandiva/tests/generate_data.h
@@ -19,6 +19,8 @@
#include <random>
#include <string>
+#include "arrow/util/decimal.h"
+
#ifndef GANDIVA_GENERATE_DATA_H
#define GANDIVA_GENERATE_DATA_H
@@ -79,6 +81,24 @@ class Int64DataGenerator : public DataGenerator<int64_t> {
Random random_;
};
+// Produces random Decimal128 values from two random draws used as the high
+// and low 64-bit halves. When `large` is set, 2^62 is added to the high half
+// so the generated values have a large magnitude.
+// NOTE(review): the range of Random::next() is not visible here; if it can
+// return values near INT64_MAX, `high += (1ull << 62)` wraps -- confirm.
+class Decimal128DataGenerator : public DataGenerator<arrow::Decimal128> {
+ public:
+ explicit Decimal128DataGenerator(bool large) : large_(large) {}
+
+ arrow::Decimal128 GenerateData() {
+ uint64_t low = random_.next();
+ int64_t high = random_.next();
+ if (large_) {
+ high += (1ull << 62);
+ }
+ return arrow::Decimal128(high, low);
+ }
+
+ protected:
+ // Whether to bias generated values towards a large magnitude.
+ bool large_;
+ Random random_;
+};
+
class FastUtf8DataGenerator : public DataGenerator<std::string> {
public:
explicit FastUtf8DataGenerator(int max_len) : max_len_(max_len), cur_char_('a') {}
diff --git a/cpp/src/gandiva/tests/micro_benchmarks.cc b/cpp/src/gandiva/tests/micro_benchmarks.cc
index ce86bf0..e0794a2 100644
--- a/cpp/src/gandiva/tests/micro_benchmarks.cc
+++ b/cpp/src/gandiva/tests/micro_benchmarks.cc
@@ -19,6 +19,7 @@
#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "benchmark/benchmark.h"
+#include "gandiva/decimal_type_util.h"
#include "gandiva/projector.h"
#include "gandiva/tests/test_util.h"
#include "gandiva/tests/timed_evaluate.h"
@@ -31,10 +32,6 @@ using arrow::int32;
using arrow::int64;
using arrow::utf8;
-// TODO : the base numbers are from a mac. they need to be caliberated
-// for the hardware used by travis.
-float tolerance_ratio = 6.0;
-
static void TimedTestAdd3(benchmark::State& state) {
// schema for input fields
auto field0 = field("f0", int64());
@@ -280,6 +277,119 @@ static void TimedTestInExpr(benchmark::State& state) {
ASSERT_OK(status);
}
+// Benchmarks evaluating `f0 + (f1 + f2)` over three Decimal128 columns of the
+// given precision/scale, with random data from Decimal128DataGenerator(large).
+// Setup failures are fatal (ASSERT_OK) so a null projector is never timed.
+static void DoDecimalAdd3(benchmark::State& state, int32_t precision, int32_t scale,
+                          bool large = false) {
+  // schema for input fields
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+  auto field0 = field("f0", decimal_type);
+  auto field1 = field("f1", decimal_type);
+  auto field2 = field("f2", decimal_type);
+  auto schema = arrow::schema({field0, field1, field2});
+
+  Decimal128TypePtr add2_type;
+  auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd,
+                                               {decimal_type, decimal_type}, &add2_type);
+  ASSERT_OK(status);
+
+  Decimal128TypePtr output_type;
+  status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd,
+                                          {add2_type, decimal_type}, &output_type);
+  ASSERT_OK(status);
+
+  // output field
+  auto field_sum = field("add", output_type);
+
+  // Build expression
+  auto part_sum = TreeExprBuilder::MakeFunction(
+      "add", {TreeExprBuilder::MakeField(field1), TreeExprBuilder::MakeField(field2)},
+      add2_type);
+  auto sum = TreeExprBuilder::MakeFunction(
+      "add", {TreeExprBuilder::MakeField(field0), part_sum}, output_type);
+
+  auto sum_expr = TreeExprBuilder::MakeExpression(sum, field_sum);
+
+  std::shared_ptr<Projector> projector;
+  status = Projector::Make(schema, {sum_expr}, TestConfiguration(), &projector);
+  ASSERT_OK(status);
+
+  Decimal128DataGenerator data_generator(large);
+  ProjectEvaluator evaluator(projector);
+
+  status = TimedEvaluate<arrow::Decimal128Type, arrow::Decimal128>(
+      schema, evaluator, data_generator, arrow::default_memory_pool(), 1 * MILLION,
+      16 * THOUSAND, state);
+  ASSERT_OK(status);
+}
+
+// Benchmarks evaluating `f0 + f1` over two Decimal128 columns of the given
+// precision/scale, with random data from Decimal128DataGenerator(large).
+// Setup failures are fatal (ASSERT_OK) so a null projector is never timed.
+static void DoDecimalAdd2(benchmark::State& state, int32_t precision, int32_t scale,
+                          bool large = false) {
+  // schema for input fields
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+  auto field0 = field("f0", decimal_type);
+  auto field1 = field("f1", decimal_type);
+  auto schema = arrow::schema({field0, field1});
+
+  Decimal128TypePtr output_type;
+  auto status = DecimalTypeUtil::GetResultType(
+      DecimalTypeUtil::kOpAdd, {decimal_type, decimal_type}, &output_type);
+  ASSERT_OK(status);
+
+  // output field
+  auto field_sum = field("add", output_type);
+
+  // Build expression
+  auto sum = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum);
+
+  std::shared_ptr<Projector> projector;
+  status = Projector::Make(schema, {sum}, TestConfiguration(), &projector);
+  ASSERT_OK(status);
+
+  Decimal128DataGenerator data_generator(large);
+  ProjectEvaluator evaluator(projector);
+
+  status = TimedEvaluate<arrow::Decimal128Type, arrow::Decimal128>(
+      schema, evaluator, data_generator, arrow::default_memory_pool(), 1 * MILLION,
+      16 * THOUSAND, state);
+  ASSERT_OK(status);
+}
+
+// Benchmark variants for 2- and 3-operand decimal adds:
+//  - Fast: precision kMaxPrecision - 6, scale 18.
+//  - LeadingZeroes: precision kMaxPrecision, scale 6.
+//  - LeadingZeroesWithDiv: precision kMaxPrecision, scale 18.
+//  - Large: like WithDiv, but the data generator adds 2^62 to the high bits
+//    so the values have a large magnitude.
+// NOTE(review): which code path each precision/scale combination exercises
+// is inferred from the names and comments; confirm against the add kernel.
+static void DecimalAdd2Fast(benchmark::State& state) {
+ // use lesser precision to test the fast-path
+ DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision - 6, 18);
+}
+
+static void DecimalAdd2LeadingZeroes(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision, 6);
+}
+
+static void DecimalAdd2LeadingZeroesWithDiv(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision, 18);
+}
+
+static void DecimalAdd2Large(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision, 18, true);
+}
+
+static void DecimalAdd3Fast(benchmark::State& state) {
+ // use lesser precision to test the fast-path
+ DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision - 6, 18);
+}
+
+static void DecimalAdd3LeadingZeroes(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision, 6);
+}
+
+static void DecimalAdd3LeadingZeroesWithDiv(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision, 18);
+}
+
+static void DecimalAdd3Large(benchmark::State& state) {
+ // use max precision to test the large-integer-path
+ DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision, 18, true);
+}
+
BENCHMARK(TimedTestAdd3)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
BENCHMARK(TimedTestBigNested)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
BENCHMARK(TimedTestBigNested)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
@@ -289,5 +399,13 @@ BENCHMARK(TimedTestFilterLike)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
BENCHMARK(TimedTestAllocs)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
BENCHMARK(TimedTestMultiOr)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
BENCHMARK(TimedTestInExpr)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2Fast)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2LeadingZeroes)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2LeadingZeroesWithDiv)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2Large)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3Fast)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3LeadingZeroes)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3LeadingZeroesWithDiv)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3Large)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
} // namespace gandiva
diff --git a/cpp/src/gandiva/tests/test_util.h b/cpp/src/gandiva/tests/test_util.h
index 72b45b1..0e0e27a 100644
--- a/cpp/src/gandiva/tests/test_util.h
+++ b/cpp/src/gandiva/tests/test_util.h
@@ -21,6 +21,7 @@
#include <vector>
#include "arrow/test-util.h"
#include "gandiva/arrow.h"
+#include "gandiva/configuration.h"
#ifndef GANDIVA_TEST_UTIL_H
#define GANDIVA_TEST_UTIL_H
@@ -47,6 +48,14 @@ static ArrayPtr MakeArrowArray(std::vector<C_TYPE> values) {
}
template <typename TYPE, typename C_TYPE>
+// Builds an array of the explicit `type` (needed for parameterised types such
+// as decimal) from `values`, with `validity[i] == false` marking slot i null.
+// Inputs are taken by const reference, like MakeArrowTypeArray, to avoid
+// copying the vectors.
+static ArrayPtr MakeArrowArray(const std::shared_ptr<arrow::DataType>& type,
+                               const std::vector<C_TYPE>& values,
+                               const std::vector<bool>& validity) {
+  ArrayPtr out;
+  arrow::ArrayFromVector<TYPE, C_TYPE>(type, validity, values, &out);
+  return out;
+}
+
+template <typename TYPE, typename C_TYPE>
static ArrayPtr MakeArrowTypeArray(const std::shared_ptr<arrow::DataType>& type,
const std::vector<C_TYPE>& values,
const std::vector<bool>& validity) {
@@ -68,11 +77,16 @@ static ArrayPtr MakeArrowTypeArray(const std::shared_ptr<arrow::DataType>& type,
#define MakeArrowArrayFloat64 MakeArrowArray<arrow::DoubleType, double>
#define MakeArrowArrayUtf8 MakeArrowArray<arrow::StringType, std::string>
#define MakeArrowArrayBinary MakeArrowArray<arrow::BinaryType, std::string>
+#define MakeArrowArrayDecimal MakeArrowArray<arrow::Decimal128Type, arrow::Decimal128>
#define EXPECT_ARROW_ARRAY_EQUALS(a, b) \
EXPECT_TRUE((a)->Equals(b)) << "expected array: " << (a)->ToString() \
<< " actual array: " << (b)->ToString();
+#define EXPECT_ARROW_TYPE_EQUALS(a, b) \
+ EXPECT_TRUE((a)->Equals(b)) << "expected type: " << (a)->ToString() \
+ << " actual type: " << (b)->ToString();
+
std::shared_ptr<Configuration> TestConfiguration() {
auto builder = ConfigurationBuilder();
builder.set_byte_code_file_path(GANDIVA_BYTE_COMPILE_FILE_PATH);
diff --git a/cpp/src/gandiva/tests/timed_evaluate.h b/cpp/src/gandiva/tests/timed_evaluate.h
index dab47c2..9db7d88 100644
--- a/cpp/src/gandiva/tests/timed_evaluate.h
+++ b/cpp/src/gandiva/tests/timed_evaluate.h
@@ -100,7 +100,9 @@ Status TimedEvaluate(SchemaPtr schema, BaseEvaluator& evaluator,
for (int col = 0; col < num_fields; col++) {
std::vector<C_TYPE> data = GenerateData<C_TYPE>(batch_size, data_generator);
std::vector<bool> validity(batch_size, true);
- ArrayPtr col_data = MakeArrowArray<TYPE, C_TYPE>(data, validity);
+ ArrayPtr col_data =
+ MakeArrowArray<TYPE, C_TYPE>(schema->field(col)->type(), data, validity);
+
columns.push_back(col_data);
}
diff --git a/cpp/src/gandiva/tree_expr_builder.cc b/cpp/src/gandiva/tree_expr_builder.cc
index 86a2824..23a49e2 100644
--- a/cpp/src/gandiva/tree_expr_builder.cc
+++ b/cpp/src/gandiva/tree_expr_builder.cc
@@ -19,6 +19,7 @@
#include <utility>
+#include "gandiva/decimal_type_util.h"
#include "gandiva/gandiva_aliases.h"
#include "gandiva/node.h"
@@ -49,6 +50,11 @@ NodePtr TreeExprBuilder::MakeBinaryLiteral(const std::string& value) {
return std::make_shared<LiteralNode>(arrow::binary(), LiteralHolder(value), false);
}
+// Creates a non-null literal node whose arrow decimal type carries the same
+// precision and scale as `value`.
+NodePtr TreeExprBuilder::MakeDecimalLiteral(const Decimal128Full& value) {
+  auto literal_type = arrow::decimal(value.precision(), value.scale());
+  return std::make_shared<LiteralNode>(literal_type, LiteralHolder(value),
+                                       false /*is_null*/);
+}
+
NodePtr TreeExprBuilder::MakeNull(DataTypePtr data_type) {
static const std::string empty;
@@ -92,6 +98,10 @@ NodePtr TreeExprBuilder::MakeNull(DataTypePtr data_type) {
return std::make_shared<LiteralNode>(data_type, LiteralHolder((int64_t)0), true);
case arrow::Type::TIMESTAMP:
return std::make_shared<LiteralNode>(data_type, LiteralHolder((int64_t)0), true);
+ case arrow::Type::DECIMAL: {
+ Decimal128Full literal(0, 0);
+ return std::make_shared<LiteralNode>(data_type, LiteralHolder(literal), true);
+ }
default:
return nullptr;
}
diff --git a/cpp/src/gandiva/tree_expr_builder.h b/cpp/src/gandiva/tree_expr_builder.h
index cd261c8..ae5f7fb 100644
--- a/cpp/src/gandiva/tree_expr_builder.h
+++ b/cpp/src/gandiva/tree_expr_builder.h
@@ -23,7 +23,9 @@
#include <unordered_set>
#include <vector>
+#include "arrow/type.h"
#include "gandiva/condition.h"
+#include "gandiva/decimal_full.h"
#include "gandiva/expression.h"
namespace gandiva {
@@ -45,6 +47,7 @@ class TreeExprBuilder {
static NodePtr MakeLiteral(double value);
static NodePtr MakeStringLiteral(const std::string& value);
static NodePtr MakeBinaryLiteral(const std::string& value);
+ static NodePtr MakeDecimalLiteral(const Decimal128Full& value);
/// \brief create a node on a null literal.
/// returns null if data_type is null or if it's not a supported datatype.
diff --git a/cpp/valgrind.supp b/cpp/valgrind.supp
index 08076aa..8d2d5da 100644
--- a/cpp/valgrind.supp
+++ b/cpp/valgrind.supp
@@ -22,6 +22,12 @@
fun:*CastFunctor*BooleanType*
}
{
+ <llvm>:Conditional jump or move depends on uninitialised value(s)
+ Memcheck:Cond
+ ...
+ fun:*llvm*PassManager*
+}
+{
<re2>:Conditional jump or move depends on uninitialised value(s)
Memcheck:Cond
...
diff --git a/java/gandiva/pom.xml b/java/gandiva/pom.xml
index d365eb9..285ea86 100644
--- a/java/gandiva/pom.xml
+++ b/java/gandiva/pom.xml
@@ -29,7 +29,7 @@
<protobuf.version>2.5.0</protobuf.version>
<dep.guava.version>18.0</dep.guava.version>
<checkstyle.failOnViolation>true</checkstyle.failOnViolation>
- <gandiva.cpp.build.dir>../../cpp/debug</gandiva.cpp.build.dir>
+ <gandiva.cpp.build.dir>../../cpp/debug/debug</gandiva.cpp.build.dir>
</properties>
<dependencies>
<dependency>
@@ -68,6 +68,11 @@
<version>2.10</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>net.java.dev.jna</groupId>
+ <artifactId>jna</artifactId>
+ <version>4.5.0</version>
+ </dependency>
</dependencies>
<profiles>
<profile>
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java
index 96788b3..46deee9 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java
@@ -17,8 +17,6 @@
package org.apache.arrow.gandiva.evaluator;
-import org.apache.arrow.gandiva.exceptions.GandivaException;
-
/**
* Used to construct gandiva configuration objects.
*/
@@ -26,16 +24,6 @@ public class ConfigurationBuilder {
private String byteCodeFilePath = "";
- private static volatile long defaultConfiguration = 0L;
-
- /**
- * Ctor - ensure that gandiva is loaded.
- * @throws GandivaException - if library cannot be loaded.
- */
- public ConfigurationBuilder() throws GandivaException {
- JniWrapper.getInstance();
- }
-
public ConfigurationBuilder withByteCodeFilePath(final String byteCodeFilePath) {
this.byteCodeFilePath = byteCodeFilePath;
return this;
@@ -45,26 +33,6 @@ public class ConfigurationBuilder {
return byteCodeFilePath;
}
- /**
- * Get the default configuration to invoke gandiva.
- * @return default configuration
- * @throws GandivaException if unable to get native builder instance.
- */
- static long getDefaultConfiguration() throws GandivaException {
- if (defaultConfiguration == 0L) {
- synchronized (ConfigurationBuilder.class) {
- if (defaultConfiguration == 0L) {
- String defaultByteCodeFilePath = JniWrapper.getInstance().getByteCodeFilePath();
-
- defaultConfiguration = new ConfigurationBuilder()
- .withByteCodeFilePath(defaultByteCodeFilePath)
- .buildConfigInstance();
- }
- }
- }
- return defaultConfiguration;
- }
-
public native long buildConfigInstance();
public native void releaseConfigInstance(long configId);
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java
new file mode 100644
index 0000000..37dd0f6
--- /dev/null
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.ArrowType.Decimal;
+
+/**
+ * Computes the result type (precision and scale) of arithmetic operations on decimals.
+ */
+public class DecimalTypeUtil {
+
+  /** Arithmetic operations whose decimal result type can be derived. */
+  public enum OperationType {
+    ADD,
+    SUBTRACT,
+    MULTIPLY,
+    DIVIDE,
+    MOD
+  }
+
+  /// Lower bound on the scale kept when a result has to be narrowed to MAX_PRECISION
+  /// (unless the unadjusted scale is already smaller).
+  private static final int MIN_ADJUSTED_SCALE = 6;
+  /// The maximum precision representable by a 16-byte decimal
+  private static final int MAX_PRECISION = 38;
+
+  /**
+   * Derive the decimal type of {@code operand1 <operation> operand2}.
+   *
+   * @param operation the arithmetic operation
+   * @param operand1 type of the left operand
+   * @param operand2 type of the right operand
+   * @return the result type; when the natural precision exceeds 38 it is capped at 38 and
+   *     the excess digits are taken from the scale (see adjustScaleIfNeeded).
+   * @throws RuntimeException if the operation is not handled
+   */
+  public static Decimal getResultTypeForOperation(OperationType operation, Decimal operand1,
+                                                  Decimal operand2) {
+    int s1 = operand1.getScale();
+    int s2 = operand2.getScale();
+    int p1 = operand1.getPrecision();
+    int p2 = operand2.getPrecision();
+    int resultScale;
+    int resultPrecision;
+    switch (operation) {
+      case ADD:
+      case SUBTRACT:
+        resultScale = Math.max(s1, s2);
+        // widest integral part of the two operands, plus one digit for a carry.
+        resultPrecision = resultScale + Math.max(p1 - s1, p2 - s2) + 1;
+        break;
+      case MULTIPLY:
+        resultScale = s1 + s2;
+        resultPrecision = p1 + p2 + 1;
+        break;
+      case DIVIDE:
+        resultScale = Math.max(MIN_ADJUSTED_SCALE, s1 + p2 + 1);
+        resultPrecision = p1 - s1 + s2 + resultScale;
+        break;
+      case MOD:
+        resultScale = Math.max(s1, s2);
+        resultPrecision = Math.min(p1 - s1, p2 - s2) + resultScale;
+        break;
+      default:
+        throw new RuntimeException("Needs support");
+    }
+    return adjustScaleIfNeeded(resultPrecision, resultScale);
+  }
+
+  /**
+   * Cap precision at MAX_PRECISION. On overflow the excess digits are removed from the scale,
+   * but the scale is never reduced below min(scale, MIN_ADJUSTED_SCALE).
+   */
+  private static Decimal adjustScaleIfNeeded(int precision, int scale) {
+    if (precision > MAX_PRECISION) {
+      int minScale = Math.min(scale, MIN_ADJUSTED_SCALE);
+      int delta = precision - MAX_PRECISION;
+      precision = MAX_PRECISION;
+      scale = Math.max(scale - delta, minScale);
+    }
+    return new Decimal(precision, scale);
+  }
+
+}
+
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java
index 9c41c19..b998679 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java
@@ -70,7 +70,7 @@ public class ExpressionRegistry {
synchronized (ExpressionRegistry.class) {
if (INSTANCE == null) {
// ensure library is setup.
- JniWrapper.getInstance();
+ JniLoader.getInstance();
Set<ArrowType> typesFromGandiva = getSupportedTypesFromGandiva();
Set<FunctionSignature> functionsFromGandiva = getSupportedFunctionsFromGandiva();
INSTANCE = new ExpressionRegistry(typesFromGandiva, functionsFromGandiva);
@@ -173,10 +173,11 @@ public class ExpressionRegistry {
BIT_WIDTH_64);
case GandivaType.NONE_VALUE:
return new ArrowType.Null();
+ case GandivaType.DECIMAL_VALUE:
+ return new ArrowType.Decimal(0,0);
case GandivaType.FIXED_SIZE_BINARY_VALUE:
case GandivaType.MAP_VALUE:
case GandivaType.INTERVAL_VALUE:
- case GandivaType.DECIMAL_VALUE:
case GandivaType.DICTIONARY_VALUE:
case GandivaType.LIST_VALUE:
case GandivaType.STRUCT_VALUE:
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java
index 25904d3..46508b1 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java
@@ -43,11 +43,13 @@ public class Filter {
private static final Logger logger = LoggerFactory.getLogger(Filter.class);
+ private final JniWrapper wrapper;
private final long moduleId;
private final Schema schema;
private boolean closed;
- private Filter(long moduleId, Schema schema) {
+ private Filter(JniWrapper wrapper, long moduleId, Schema schema) {
+ this.wrapper = wrapper;
this.moduleId = moduleId;
this.schema = schema;
this.closed = false;
@@ -63,7 +65,7 @@ public class Filter {
* @return A native filter object that can be used to invoke on a RecordBatch
*/
public static Filter make(Schema schema, Condition condition) throws GandivaException {
- return make(schema, condition, ConfigurationBuilder.getDefaultConfiguration());
+ // JniLoader builds the default configuration on first use and caches it.
+ return make(schema, condition, JniLoader.getDefaultConfiguration());
}
/**
@@ -81,11 +83,11 @@ public class Filter {
// Invoke the JNI layer to create the LLVM module representing the filter.
GandivaTypes.Condition conditionBuf = condition.toProtobuf();
GandivaTypes.Schema schemaBuf = ArrowTypeHelper.arrowSchemaToProtobuf(schema);
- JniWrapper gandivaBridge = JniWrapper.getInstance();
- long moduleId = gandivaBridge.buildFilter(schemaBuf.toByteArray(),
+ JniWrapper wrapper = JniLoader.getInstance().getWrapper();
+ long moduleId = wrapper.buildFilter(schemaBuf.toByteArray(),
conditionBuf.toByteArray(), configurationId);
logger.info("Created module for the projector with id {}", moduleId);
- return new Filter(moduleId, schema);
+ return new Filter(wrapper, moduleId, schema);
}
/**
@@ -144,7 +146,7 @@ public class Filter {
bufSizes[idx++] = bufLayout.getSize();
}
- int numRecords = JniWrapper.getInstance().evaluateFilter(this.moduleId, numRows,
+ int numRecords = wrapper.evaluateFilter(this.moduleId, numRows,
bufAddrs, bufSizes,
selectionVector.getType().getNumber(),
selectionVector.getBuffer().memoryAddress(), selectionVector.getBuffer().capacity());
@@ -161,7 +163,7 @@ public class Filter {
return;
}
- JniWrapper.getInstance().closeFilter(this.moduleId);
+ wrapper.closeFilter(this.moduleId);
this.closed = true;
}
}
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java
new file mode 100644
index 0000000..3491b28
--- /dev/null
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import static java.util.UUID.randomUUID;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.StandardCopyOption;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+
+import com.sun.jna.NativeLibrary;
+
+/**
+ * Loads the gandiva jni library (and extracts the IR helper byte code) from the jar, and
+ * provides the {@link JniWrapper} through which the native functions are invoked.
+ */
+class JniLoader {
+  private static final String LIBRARY_NAME = "gandiva_jni";
+  private static final String IRHELPERS_BC = "irhelpers.bc";
+
+  private static volatile JniLoader INSTANCE;
+  private static volatile long defaultConfiguration = 0L;
+
+  private final String byteCodeFilePath;
+  private final JniWrapper wrapper;
+
+  private JniLoader(String byteCodeFilePath) {
+    this.byteCodeFilePath = byteCodeFilePath;
+    this.wrapper = new JniWrapper();
+  }
+
+  /**
+   * Returns the shared loader, extracting and loading the native library on first use.
+   * @throws GandivaException if the library or the byte code file cannot be set up.
+   */
+  static JniLoader getInstance() throws GandivaException {
+    if (INSTANCE == null) {
+      synchronized (JniLoader.class) {
+        if (INSTANCE == null) {
+          INSTANCE = setupInstance();
+        }
+      }
+    }
+    return INSTANCE;
+  }
+
+  private static JniLoader setupInstance() throws GandivaException {
+    try {
+      String tempDir = System.getProperty("java.io.tmpdir");
+      loadGandivaLibraryFromJar(tempDir);
+      File byteCodeFile = moveFileFromJarToTemp(tempDir, IRHELPERS_BC);
+      return new JniLoader(byteCodeFile.getAbsolutePath());
+    } catch (IOException ioException) {
+      throw new GandivaException("unable to create native instance", ioException);
+    }
+  }
+
+  private static void loadGandivaLibraryFromJar(final String tmpDir)
+      throws IOException, GandivaException {
+    final String libraryToLoad = System.mapLibraryName(LIBRARY_NAME);
+    final File libraryFile = moveFileFromJarToTemp(tmpDir, libraryToLoad);
+    // This is required to load the library with RTLD_GLOBAL flags. Otherwise, the symbols in
+    // the libgandiva.so aren't visible to the JIT.
+    NativeLibrary.getInstance(libraryFile.getAbsolutePath());
+    System.load(libraryFile.getAbsolutePath());
+  }
+
+  private static File moveFileFromJarToTemp(final String tmpDir, String libraryToLoad)
+      throws IOException, GandivaException {
+    final File temp = setupFile(tmpDir, libraryToLoad);
+    try (final InputStream is = JniLoader.class.getClassLoader()
+        .getResourceAsStream(libraryToLoad)) {
+      if (is == null) {
+        throw new GandivaException(libraryToLoad + " was not found inside JAR.");
+      } else {
+        Files.copy(is, temp.toPath(), StandardCopyOption.REPLACE_EXISTING);
+      }
+    }
+    return temp;
+  }
+
+  private static File setupFile(String tmpDir, String libraryToLoad)
+      throws IOException, GandivaException {
+    // accommodate multiple processes running with gandiva jar.
+    // length should be ok since uuid is only 36 characters.
+    final String randomizeFileName = libraryToLoad + randomUUID();
+    final File temp = new File(tmpDir, randomizeFileName);
+    if (temp.exists() && !temp.delete()) {
+      throw new GandivaException("File: " + temp.getAbsolutePath() +
+          " already exists and cannot be removed.");
+    }
+    if (!temp.createNewFile()) {
+      throw new GandivaException("File: " + temp.getAbsolutePath() +
+          " could not be created.");
+    }
+    temp.deleteOnExit();
+    return temp;
+  }
+
+  /**
+   * Returns the byte code file path extracted from jar.
+   */
+  public String getByteCodeFilePath() {
+    return byteCodeFilePath;
+  }
+
+  /**
+   * Returns the jni wrapper.
+   */
+  JniWrapper getWrapper() throws GandivaException {
+    return wrapper;
+  }
+
+  /**
+   * Get the default configuration to invoke gandiva.
+   * The configuration is built once, using the byte code file extracted by this loader.
+   * @return default configuration
+   * @throws GandivaException if unable to get native builder instance.
+   */
+  static long getDefaultConfiguration() throws GandivaException {
+    if (defaultConfiguration == 0L) {
+      // Same monitor as getInstance() (reentrant), rather than a lock on an unrelated class.
+      synchronized (JniLoader.class) {
+        if (defaultConfiguration == 0L) {
+          String defaultByteCodeFilePath = JniLoader.getInstance().getByteCodeFilePath();
+
+          defaultConfiguration = new ConfigurationBuilder()
+              .withByteCodeFilePath(defaultByteCodeFilePath)
+              .buildConfigInstance();
+        }
+      }
+    }
+    return defaultConfiguration;
+  }
+}
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java
index eea42f6..f00b0fb 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java
@@ -17,100 +17,15 @@
package org.apache.arrow.gandiva.evaluator;
-import static java.util.UUID.randomUUID;
-
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.file.Files;
-import java.nio.file.StandardCopyOption;
-
import org.apache.arrow.gandiva.exceptions.GandivaException;
/**
* This class is implemented in JNI. This provides the Java interface
- * to invoke functions in JNI
+ * to invoke functions in JNI.
+ * This file is used to generate the .h files required for jni. Avoid all
+ * external dependencies in this file.
*/
-class JniWrapper {
- private static final String LIBRARY_NAME = "gandiva_jni";
- private static final String IRHELPERS_BC = "irhelpers.bc";
-
- private static volatile JniWrapper INSTANCE;
-
- private final String byteCodeFilePath;
-
- private JniWrapper(String byteCodeFilePath) {
- this.byteCodeFilePath = byteCodeFilePath;
- }
-
- static JniWrapper getInstance() throws GandivaException {
- if (INSTANCE == null) {
- synchronized (JniWrapper.class) {
- if (INSTANCE == null) {
- INSTANCE = setupInstance();
- }
- }
- }
- return INSTANCE;
- }
-
- private static JniWrapper setupInstance() throws GandivaException {
- try {
- String tempDir = System.getProperty("java.io.tmpdir");
- loadGandivaLibraryFromJar(tempDir);
- File byteCodeFile = moveFileFromJarToTemp(tempDir, IRHELPERS_BC);
- return new JniWrapper(byteCodeFile.getAbsolutePath());
- } catch (IOException ioException) {
- throw new GandivaException("unable to create native instance", ioException);
- }
- }
-
- private static void loadGandivaLibraryFromJar(final String tmpDir)
- throws IOException, GandivaException {
- final String libraryToLoad = System.mapLibraryName(LIBRARY_NAME);
- final File libraryFile = moveFileFromJarToTemp(tmpDir, libraryToLoad);
- System.load(libraryFile.getAbsolutePath());
- }
-
-
- private static File moveFileFromJarToTemp(final String tmpDir, String libraryToLoad)
- throws IOException, GandivaException {
- final File temp = setupFile(tmpDir, libraryToLoad);
- try (final InputStream is = JniWrapper.class.getClassLoader()
- .getResourceAsStream(libraryToLoad)) {
- if (is == null) {
- throw new GandivaException(libraryToLoad + " was not found inside JAR.");
- } else {
- Files.copy(is, temp.toPath(), StandardCopyOption.REPLACE_EXISTING);
- }
- }
- return temp;
- }
-
- private static File setupFile(String tmpDir, String libraryToLoad)
- throws IOException, GandivaException {
- // accommodate multiple processes running with gandiva jar.
- // length should be ok since uuid is only 36 characters.
- final String randomizeFileName = libraryToLoad + randomUUID();
- final File temp = new File(tmpDir, randomizeFileName);
- if (temp.exists() && !temp.delete()) {
- throw new GandivaException("File: " + temp.getAbsolutePath() +
- " already exists and cannot be removed.");
- }
- if (!temp.createNewFile()) {
- throw new GandivaException("File: " + temp.getAbsolutePath() +
- " could not be created.");
- }
- temp.deleteOnExit();
- return temp;
- }
-
- /**
- * Returns the byte code file path extracted from jar.
- */
- public String getByteCodeFilePath() {
- return byteCodeFilePath;
- }
+public class JniWrapper {
/**
* Generates the projector module to evaluate the expressions with
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java
index d757893..af1a4ca 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java
@@ -46,12 +46,14 @@ public class Projector {
private static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(Projector.class);
+  private final JniWrapper wrapper;
private final long moduleId;
private final Schema schema;
private final int numExprs;
private boolean closed;
- private Projector(long moduleId, Schema schema, int numExprs) {
+ private Projector(JniWrapper wrapper, long moduleId, Schema schema, int numExprs) {
+ this.wrapper = wrapper;
this.moduleId = moduleId;
this.schema = schema;
this.numExprs = numExprs;
@@ -71,7 +73,7 @@ public class Projector {
*/
public static Projector make(Schema schema, List<ExpressionTree> exprs)
throws GandivaException {
- return make(schema, exprs, ConfigurationBuilder.getDefaultConfiguration());
+ // JniLoader builds the default configuration on first use and caches it.
+ return make(schema, exprs, JniLoader.getDefaultConfiguration());
}
/**
@@ -96,11 +98,11 @@ public class Projector {
// Invoke the JNI layer to create the LLVM module representing the expressions
GandivaTypes.Schema schemaBuf = ArrowTypeHelper.arrowSchemaToProtobuf(schema);
- JniWrapper gandivaBridge = JniWrapper.getInstance();
- long moduleId = gandivaBridge.buildProjector(schemaBuf.toByteArray(), builder.build()
- .toByteArray(), configurationId);
+ JniWrapper wrapper = JniLoader.getInstance().getWrapper();
+ long moduleId = wrapper.buildProjector(schemaBuf.toByteArray(),
+ builder.build().toByteArray(), configurationId);
logger.info("Created module for the projector with id {}", moduleId);
- return new Projector(moduleId, schema, exprs.size());
+ return new Projector(wrapper, moduleId, schema, exprs.size());
}
/**
@@ -175,9 +177,7 @@ public class Projector {
valueVector.setValueCount(numRows);
}
- JniWrapper.getInstance().evaluateProjector(this.moduleId, numRows,
- bufAddrs, bufSizes,
- outAddrs, outSizes);
+ wrapper.evaluateProjector(this.moduleId, numRows, bufAddrs, bufSizes, outAddrs, outSizes);
}
/**
@@ -188,7 +188,7 @@ public class Projector {
return;
}
- JniWrapper.getInstance().closeProjector(this.moduleId);
+ wrapper.closeProjector(this.moduleId);
this.closed = true;
}
}
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DecimalNode.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DecimalNode.java
new file mode 100644
index 0000000..1b908b9
--- /dev/null
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DecimalNode.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import java.nio.charset.Charset;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+import com.google.protobuf.ByteString;
+
+
+/**
+ * Expression tree node for a decimal constant, e.g. the literal in {@code (x + 5.0)}.
+ * The value string, precision and scale are copied verbatim into the protobuf message.
+ */
+class DecimalNode implements TreeNode {
+  private final String value;
+  private final int precision;
+  private final int scale;
+
+  DecimalNode(String value, int precision, int scale) {
+    this.value = value;
+    this.precision = precision;
+    this.scale = scale;
+  }
+
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.DecimalNode decimalMsg = GandivaTypes.DecimalNode.newBuilder()
+        .setValue(value)
+        .setPrecision(precision)
+        .setScale(scale)
+        .build();
+    return GandivaTypes.TreeNode.newBuilder()
+        .setDecimalNode(decimalMsg)
+        .build();
+  }
+}
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
index f556859..a220c54 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
@@ -55,6 +55,10 @@ public class TreeBuilder {
return new BinaryNode(binaryConstant);
}
+  /**
+   * create a decimal literal with the given precision and scale.
+   * NOTE(review): the string is forwarded as-is; in the decimal projector test
+   * ("6", 2, 1) behaves as 0.6, i.e. it appears to be read as the unscaled value -- confirm
+   * against the native parser.
+   */
+  public static TreeNode makeDecimalLiteral(String decimalConstant, int precision, int scale) {
+    TreeNode literal = new DecimalNode(decimalConstant, precision, scale);
+    return literal;
+  }
+
/**
* create a null literal.
*/
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
index aeb3d41..97c2883 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
@@ -17,6 +17,8 @@
package org.apache.arrow.gandiva.evaluator;
+import java.math.BigDecimal;
+import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
@@ -27,6 +29,7 @@ import org.apache.arrow.gandiva.expression.Condition;
import org.apache.arrow.gandiva.expression.ExpressionTree;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
@@ -229,6 +232,18 @@ class BaseEvaluatorTest {
return buffer;
}
+  /**
+   * Creates a decimal vector, allocated from the test allocator, holding {@code values}
+   * parsed with {@code new BigDecimal(String)}; the value count is set to values.length.
+   * NOTE(review): each value presumably must carry exactly {@code scale} fractional digits to
+   * be accepted by setSafe -- confirm against DecimalVector.
+   */
+  DecimalVector decimalVector(String[] values, int precision, int scale) {
+    String name = "decimal" + Math.random();
+    DecimalVector vector = new DecimalVector(name, allocator, precision, scale);
+    vector.allocateNew();
+    int row = 0;
+    for (String text : values) {
+      vector.setSafe(row++, new BigDecimal(text));
+    }
+    vector.setValueCount(values.length);
+    return vector;
+  }
+
ArrowBuf longBuf(long[] longs) {
ArrowBuf buffer = allocator.buffer(longs.length * 8);
for (int i = 0; i < longs.length; i++) {
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java
new file mode 100644
index 0000000..4a4fb82
--- /dev/null
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Checks the precision/scale rules implemented by {@link DecimalTypeUtil}.
+ */
+public class DecimalTypeUtilTest {
+
+  @Test
+  public void testOutputTypesForAdd() {
+    assertResultType(DecimalTypeUtil.OperationType.ADD, getDecimal(30, 10), getDecimal(30, 10),
+        getDecimal(31, 10));
+    assertResultType(DecimalTypeUtil.OperationType.ADD, getDecimal(30, 6), getDecimal(30, 5),
+        getDecimal(32, 6));
+    // natural precision 39: capped at 38, one digit taken from the scale.
+    assertResultType(DecimalTypeUtil.OperationType.ADD, getDecimal(30, 10), getDecimal(38, 10),
+        getDecimal(38, 9));
+    assertResultType(DecimalTypeUtil.OperationType.ADD, getDecimal(38, 10), getDecimal(38, 38),
+        getDecimal(38, 9));
+    // the scale is not reduced below 6.
+    assertResultType(DecimalTypeUtil.OperationType.ADD, getDecimal(38, 10), getDecimal(38, 2),
+        getDecimal(38, 6));
+  }
+
+  @Test
+  public void testOutputTypesForSubtract() {
+    assertResultType(DecimalTypeUtil.OperationType.SUBTRACT, getDecimal(30, 6), getDecimal(30, 5),
+        getDecimal(32, 6));
+  }
+
+  @Test
+  public void testOutputTypesForMultiply() {
+    assertResultType(DecimalTypeUtil.OperationType.MULTIPLY, getDecimal(30, 10),
+        getDecimal(30, 10), getDecimal(38, 6));
+    assertResultType(DecimalTypeUtil.OperationType.MULTIPLY, getDecimal(38, 10),
+        getDecimal(9, 2), getDecimal(38, 6));
+  }
+
+  @Test
+  public void testOutputTypesForDivide() {
+    // scale = max(6, s1 + p2 + 1) = 13; precision = p1 - s1 + s2 + scale = 23.
+    assertResultType(DecimalTypeUtil.OperationType.DIVIDE, getDecimal(10, 2),
+        getDecimal(10, 2), getDecimal(23, 13));
+    assertResultType(DecimalTypeUtil.OperationType.DIVIDE, getDecimal(38, 10),
+        getDecimal(38, 10), getDecimal(38, 6));
+  }
+
+  @Test
+  public void testOutputTypesForMod() {
+    assertResultType(DecimalTypeUtil.OperationType.MOD, getDecimal(30, 10), getDecimal(28, 7),
+        getDecimal(30, 10));
+  }
+
+  /** Asserts the result type of {@code operand1 <operation> operand2}, reporting both types. */
+  private void assertResultType(DecimalTypeUtil.OperationType operation,
+                                ArrowType.Decimal operand1, ArrowType.Decimal operand2,
+                                ArrowType.Decimal expected) {
+    ArrowType.Decimal resultType =
+        DecimalTypeUtil.getResultTypeForOperation(operation, operand1, operand2);
+    Assert.assertEquals(expected, resultType);
+  }
+
+  private ArrowType.Decimal getDecimal(int precision, int scale) {
+    return new ArrowType.Decimal(precision, scale);
+  }
+
+}
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
new file mode 100644
index 0000000..a3a0b48
--- /dev/null
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.expression.ExpressionTree;
+import org.apache.arrow.gandiva.expression.TreeBuilder;
+import org.apache.arrow.gandiva.expression.TreeNode;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+public class ProjectorDecimalTest extends org.apache.arrow.gandiva.evaluator.BaseEvaluatorTest {
+
+  @Test
+  public void test_add() throws GandivaException {
+    int precision = 38;
+    int scale = 8;
+    ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale);
+    Field a = Field.nullable("a", decimal);
+    Field b = Field.nullable("b", decimal);
+    List<Field> args = Lists.newArrayList(a, b);
+
+    ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil
+        .OperationType.ADD, decimal, decimal);
+    Field retType = Field.nullable("c", outputType);
+    ExpressionTree root = TreeBuilder.makeExpression("add", args, retType);
+
+    List<ExpressionTree> exprs = Lists.newArrayList(root);
+
+    Schema schema = new Schema(args);
+    Projector eval = Projector.make(schema, exprs);
+    ArrowRecordBatch batch = null;
+    List<ValueVector> output = new ArrayList<ValueVector>();
+    try {
+      int numRows = 4;
+      String[] aValues = new String[]{"1.12345678", "2.12345678", "3.12345678", "4.12345678"};
+      String[] bValues = new String[]{"2.12345678", "3.12345678", "4.12345678", "5.12345678"};
+
+      DecimalVector valuesa = decimalVector(aValues, precision, scale);
+      DecimalVector valuesb = decimalVector(bValues, precision, scale);
+      batch =
+          new ArrowRecordBatch(
+              numRows,
+              Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
+              Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(),
+                  valuesb.getValidityBuffer(), valuesb.getDataBuffer()));
+
+      DecimalVector outVector = new DecimalVector("decimal_output", allocator,
+          outputType.getPrecision(), outputType.getScale());
+      outVector.allocateNew(numRows);
+      output.add(outVector);
+      eval.evaluate(batch, output);
+
+      // should have scaled down: the result type has one fractional digit fewer than the
+      // inputs, so e.g. 3.24691356 comes back as 3.2469136.
+      BigDecimal[] expOutput = new BigDecimal[]{new BigDecimal("3.2469136"),
+          new BigDecimal("5.2469136"), new BigDecimal("7.2469136"), new BigDecimal("9.2469136")};
+      assertDecimalsEqual(expOutput, outVector);
+    } finally {
+      // free buffers and the native module even if an assertion above fails.
+      if (batch != null) {
+        releaseRecordBatch(batch);
+      }
+      releaseValueVectors(output);
+      eval.close();
+    }
+  }
+
+  @Test
+  public void test_add_literal() throws GandivaException {
+    int precision = 2;
+    int scale = 0;
+    ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale);
+    ArrowType.Decimal literalType = new ArrowType.Decimal(2, 1);
+    Field a = Field.nullable("a", decimal);
+
+    ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil
+        .OperationType.ADD, decimal, literalType);
+    Field retType = Field.nullable("c", outputType);
+    TreeNode field = TreeBuilder.makeField(a);
+    // keep the literal's precision/scale in sync with the type used to derive outputType.
+    TreeNode literal = TreeBuilder.makeDecimalLiteral("6", literalType.getPrecision(),
+        literalType.getScale());
+    List<TreeNode> args = Lists.newArrayList(field, literal);
+    TreeNode root = TreeBuilder.makeFunction("add", args, outputType);
+    ExpressionTree tree = TreeBuilder.makeExpression(root, retType);
+
+    List<ExpressionTree> exprs = Lists.newArrayList(tree);
+
+    Schema schema = new Schema(Lists.newArrayList(a));
+    Projector eval = Projector.make(schema, exprs);
+    ArrowRecordBatch batch = null;
+    List<ValueVector> output = new ArrayList<ValueVector>();
+    try {
+      int numRows = 4;
+      String[] aValues = new String[]{"1", "2", "3", "4"};
+
+      DecimalVector valuesa = decimalVector(aValues, precision, scale);
+      batch =
+          new ArrowRecordBatch(
+              numRows,
+              Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+              Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer()));
+
+      DecimalVector outVector = new DecimalVector("decimal_output", allocator,
+          outputType.getPrecision(), outputType.getScale());
+      outVector.allocateNew(numRows);
+      output.add(outVector);
+      eval.evaluate(batch, output);
+
+      BigDecimal[] expOutput = new BigDecimal[]{new BigDecimal("1.6"), new BigDecimal("2.6"),
+          new BigDecimal("3.6"), new BigDecimal("4.6")};
+      assertDecimalsEqual(expOutput, outVector);
+    } finally {
+      // free buffers and the native module even if an assertion above fails.
+      if (batch != null) {
+        releaseRecordBatch(batch);
+      }
+      releaseValueVectors(output);
+      eval.close();
+    }
+  }
+
+  /**
+   * Asserts that the first expected.length slots of {@code actual} are non-null and
+   * numerically equal (compareTo, so scale differences are ignored) to {@code expected}.
+   */
+  private void assertDecimalsEqual(BigDecimal[] expected, DecimalVector actual) {
+    for (int i = 0; i < expected.length; i++) {
+      assertFalse("index : " + i + " is null", actual.isNull(i));
+      BigDecimal value = actual.getObject(i);
+      assertTrue("index : " + i + " failed compare, expected " + expected[i] + " got " + value,
+          expected[i].compareTo(value) == 0);
+    }
+  }
+}
diff --git a/python/pyarrow/gandiva.pyx b/python/pyarrow/gandiva.pyx
index 76e55d6..715ff9d 100644
--- a/python/pyarrow/gandiva.pyx
+++ b/python/pyarrow/gandiva.pyx
@@ -19,6 +19,8 @@
# distutils: language = c++
# cython: embedsignature = True
+import os
+
from libcpp cimport bool as c_bool, nullptr
from libcpp.memory cimport shared_ptr, unique_ptr, make_shared
from libcpp.string cimport string as c_string
@@ -73,6 +75,14 @@ from pyarrow.includes.libgandiva cimport (
CFunctionSignature,
GetRegisteredFunctionSignatures)
+if os.name == 'posix':
+    # Re-open this extension module with RTLD_GLOBAL so that symbols from
+    # gandiva.so and its child libraries (such as libstdc++) can be reached
+    # while JIT-compiled code executes.
+    # sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_NOW) would also work, but it
+    # would change the flags for every C extension loaded in the process.
+    import ctypes
+    _dll = ctypes.CDLL(__file__, mode=ctypes.RTLD_GLOBAL)
cdef class Node:
cdef: