You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ra...@apache.org on 2019/07/08 06:15:46 UTC
[arrow] branch master updated: ARROW-5758: [C++][Gandiva][Java]
Support casting decimals to varchar and vice versa
This is an automated email from the ASF dual-hosted git repository.
ravindra pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 1f9238e ARROW-5758: [C++][Gandiva][Java] Support casting decimals to varchar and vice versa
1f9238e is described below
commit 1f9238e3ebb9b9347f7881396215093d0e585c66
Author: Prudhvi Porandla <pr...@icloud.com>
AuthorDate: Mon Jul 8 11:45:27 2019 +0530
ARROW-5758: [C++][Gandiva][Java] Support casting decimals to varchar and vice versa
support `castVARCHAR(decimal, out_str_len)`, `castDECIMAL(string, out_precision, out_scale)` functions in Gandiva
Author: Prudhvi Porandla <pr...@icloud.com>
Closes #4803 from pprudhvi/decimal-varchar and squashes the following commits:
72f479f95 <Prudhvi Porandla> assert status
a222ef28c <Prudhvi Porandla> set error in context if string is invalid decimal
b4eecbc61 <Prudhvi Porandla> lint
dd6ba7cc8 <Prudhvi Porandla> return "-" instead of "0" for negative numbers when out_len is 1
c013a1581 <Prudhvi Porandla> revert arrow version change
bf6106d8f <Prudhvi Porandla> lint
a6ac92b9e <Prudhvi Porandla> lint
19bfb5fac <Prudhvi Porandla> cast size_t to int32_t
c14dc74f9 <Prudhvi Porandla> temp fix - change arrow version to 0.x*
428e87029 <Prudhvi Porandla> lint
9d94c4972 <Prudhvi Porandla> do toString, FromString in gdv function stubs
727c5d72d <Prudhvi Porandla> implement cast decimal<->varchar in decimal wrapper
---
cpp/src/gandiva/decimal_ir.cc | 21 +++
cpp/src/gandiva/function_registry_arithmetic.cc | 1 +
cpp/src/gandiva/function_registry_string.cc | 4 +
cpp/src/gandiva/gdv_function_stubs.cc | 58 ++++++
cpp/src/gandiva/gdv_function_stubs.h | 7 +
cpp/src/gandiva/precompiled/decimal_wrapper.cc | 37 ++++
cpp/src/gandiva/tests/decimal_test.cc | 207 +++++++++++++++++++++
.../arrow/gandiva/evaluator/BaseEvaluatorTest.java | 12 ++
.../gandiva/evaluator/ProjectorDecimalTest.java | 127 +++++++++++++
9 files changed, 474 insertions(+)
diff --git a/cpp/src/gandiva/decimal_ir.cc b/cpp/src/gandiva/decimal_ir.cc
index 6e4bb56..463e448 100644
--- a/cpp/src/gandiva/decimal_ir.cc
+++ b/cpp/src/gandiva/decimal_ir.cc
@@ -560,6 +560,8 @@ Status DecimalIR::AddFunctions(Engine* engine) {
auto i1 = decimal_ir->types()->i1_type();
auto i64 = decimal_ir->types()->i64_type();
auto f64 = decimal_ir->types()->double_type();
+ auto i8_ptr = decimal_ir->types()->i8_ptr_type();
+ auto i32_ptr = decimal_ir->types()->i32_ptr_type();
// Populate global variables used by decimal operations.
decimal_ir->AddGlobals(engine);
@@ -822,6 +824,25 @@ Status DecimalIR::AddFunctions(Engine* engine) {
{"y_isvalid", i1},
}));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildDecimalFunction("castDECIMAL_utf8", i128,
+ {
+ {"context", i64},
+ {"in", i8_ptr},
+ {"in_len", i32},
+ {"out_precision", i32},
+ {"out_scale", i32},
+ }));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildDecimalFunction("castVARCHAR_decimal128_int64",
+ i8_ptr,
+ {
+ {"context", i64},
+ {"x_value", i128},
+ {"x_precision", i32},
+ {"x_scale", i32},
+ {"out_len_param", i64},
+ {"out_length", i32_ptr},
+ }));
+
return Status::OK();
}
diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc
index b6c5819..f7f88bb 100644
--- a/cpp/src/gandiva/function_registry_arithmetic.cc
+++ b/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -49,6 +49,7 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, int64, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, float64, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, decimal128, decimal128),
+ UNARY_UNSAFE_NULL_IF_NULL(castDECIMAL, utf8, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDATE, int64, date64),
diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc
index 19e31c8..af1f37f 100644
--- a/cpp/src/gandiva/function_registry_string.cc
+++ b/cpp/src/gandiva/function_registry_string.cc
@@ -61,6 +61,10 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
kResultNullIfNull, "castVARCHAR_utf8_int64",
NativeFunction::kNeedsContext),
+ NativeFunction("castVARCHAR", DataTypeVector{decimal128(), int64()}, utf8(),
+ kResultNullIfNull, "castVARCHAR_decimal128_int64",
+ NativeFunction::kNeedsContext),
+
NativeFunction("like", DataTypeVector{utf8(), utf8()}, boolean(), kResultNullIfNull,
"gdv_fn_like_utf8_utf8", NativeFunction::kNeedsFunctionHolder)};
diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc
index 570e026..1d8fe9b 100644
--- a/cpp/src/gandiva/gdv_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_function_stubs.cc
@@ -97,6 +97,36 @@ int32_t gdv_fn_populate_varlen_vector(int64_t context_ptr, int8_t* data_ptr,
offsets[slot + 1] = offset + entry_len;
return 0;
}
+
+/// Parse the `in_length` bytes starting at `in` as a decimal literal.
+///
+/// On success the high and low 64-bit halves of the parsed value are stored through
+/// `dec_high_from_str` / `dec_low_from_str` and 0 is returned; the precision and scale
+/// pointers are handed to arrow::Decimal128::FromString to fill in. On a parse failure
+/// the status message is passed to gdv_fn_context_set_error_msg, the high/low outputs
+/// are not written, and -1 is returned.
+int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_length,
+                               int32_t* precision_from_str, int32_t* scale_from_str,
+                               int64_t* dec_high_from_str, uint64_t* dec_low_from_str) {
+  const std::string text(in, in_length);
+  arrow::Decimal128 parsed;
+  const auto parse_status =
+      arrow::Decimal128::FromString(text, &parsed, precision_from_str, scale_from_str);
+  if (parse_status.ok()) {
+    *dec_high_from_str = parsed.high_bits();
+    *dec_low_from_str = parsed.low_bits();
+    return 0;
+  }
+  gdv_fn_context_set_error_msg(context, parse_status.message().data());
+  return -1;
+}
+
+/// Render the decimal whose 128-bit value is (`x_high`, `x_low`) as text at scale
+/// `x_scale`.
+///
+/// The characters are copied into a buffer obtained from gdv_fn_context_arena_malloc;
+/// no terminating NUL is written, the byte count is reported through `dec_str_len`.
+/// If the allocation fails an error message is set on the context, `*dec_str_len` is
+/// reset to 0 and NULLPTR is returned, so the reported length never describes a
+/// buffer that does not exist.
+char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low,
+                           int32_t x_scale, int32_t* dec_str_len) {
+  arrow::Decimal128 dec(arrow::BasicDecimal128(x_high, x_low));
+  std::string dec_str = dec.ToString(x_scale);
+  *dec_str_len = static_cast<int32_t>(dec_str.length());
+  char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *dec_str_len));
+  if (ret == NULLPTR) {
+    std::string err_msg = "Could not allocate memory for string: " + dec_str;
+    gdv_fn_context_set_error_msg(context, err_msg.data());
+    // Do not leave a non-zero length paired with a null buffer.
+    *dec_str_len = 0;
+    return NULLPTR;
+  }
+  memcpy(ret, dec_str.data(), *dec_str_len);
+  return ret;
+}
}
namespace gandiva {
@@ -105,6 +135,34 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
std::vector<llvm::Type*> args;
auto types = engine->types();
+ // gdv_fn_dec_from_string
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // const char* in
+ types->i32_type(), // int32_t in_length
+ types->i32_ptr_type(), // int32_t* precision_from_str
+ types->i32_ptr_type(), // int32_t* scale_from_str
+ types->i64_ptr_type(), // int64_t* dec_high_from_str
+ types->i64_ptr_type(), // int64_t* dec_low_from_str
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_dec_from_string",
+ types->i32_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_dec_from_string));
+
+ // gdv_fn_dec_to_string
+ args = {
+ types->i64_type(), // context
+ types->i64_type(), // int64_t x_high
+ types->i64_type(), // int64_t x_low
+ types->i32_type(), // int32_t x_scale
+ types->i64_ptr_type(), // int64_t* dec_str_len
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_dec_to_string",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_dec_to_string));
+
// gdv_fn_like_utf8_utf8
args = {types->i64_type(), // int64_t ptr
types->i8_ptr_type(), // const char* data
diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h
index 8f940ce..fcdf7d6 100644
--- a/cpp/src/gandiva/gdv_function_stubs.h
+++ b/cpp/src/gandiva/gdv_function_stubs.h
@@ -46,6 +46,13 @@ bool in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len, bool in_va
int gdv_fn_time_with_zone(int* time_fields, const char* zone, int zone_len,
int64_t* ret_time);
+
+int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_length,
+ int32_t* precision_from_str, int32_t* scale_from_str,
+ int64_t* dec_high_from_str, uint64_t* dec_low_from_str);
+
+char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low,
+ int32_t x_scale, int32_t* dec_str_len);
}
#endif // GDV_FUNCTION_STUBS_H
diff --git a/cpp/src/gandiva/precompiled/decimal_wrapper.cc b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
index 630fe8b..620b443 100644
--- a/cpp/src/gandiva/precompiled/decimal_wrapper.cc
+++ b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
@@ -358,4 +358,41 @@ boolean is_distinct_from_decimal128_decimal128_internal(
y_isvalid);
}
+/// Cast a UTF-8 decimal literal to a decimal with the requested precision and scale.
+///
+/// The text is parsed with gdv_fn_dec_from_string and the parsed value is rescaled
+/// with gandiva::decimalops::Convert. The 128-bit result is returned through
+/// `out_high` / `out_low`. Both outputs are zeroed before parsing so they hold a
+/// defined value when parsing fails and the function returns early.
+FORCE_INLINE
+void castDECIMAL_utf8_internal(int64_t context, const char* in, int32_t in_length,
+                               int32_t out_precision, int32_t out_scale,
+                               int64_t* out_high, uint64_t* out_low) {
+  // Defined result for the early-return path below.
+  *out_high = 0;
+  *out_low = 0;
+
+  int64_t dec_high_from_str;
+  uint64_t dec_low_from_str;
+  int32_t precision_from_str;
+  int32_t scale_from_str;
+  int32_t status =
+      gdv_fn_dec_from_string(context, in, in_length, &precision_from_str, &scale_from_str,
+                             &dec_high_from_str, &dec_low_from_str);
+  if (status != 0) {
+    return;
+  }
+
+  gandiva::BasicDecimalScalar128 x({dec_high_from_str, dec_low_from_str},
+                                   precision_from_str, scale_from_str);
+  // The overflow flag is not consulted; whatever Convert returns is stored as is.
+  bool overflow = false;
+  auto out = gandiva::decimalops::Convert(x, out_precision, out_scale, &overflow);
+  *out_high = out.high_bits();
+  *out_low = out.low_bits();
+}
+
+/// Cast a decimal to text, truncated to at most `out_len_param` bytes.
+///
+/// The full string is produced by gdv_fn_dec_to_string; the returned pointer is that
+/// buffer and `*out_length` is min(out_len_param, full length). A non-positive
+/// `out_len_param`, or a null buffer from gdv_fn_dec_to_string, yields a length of 0
+/// so that a negative or dangling length is never reported. `x_precision` is unused.
+FORCE_INLINE
+char* castVARCHAR_decimal128_int64_internal(int64_t context, int64_t x_high,
+                                            uint64_t x_low, int32_t x_precision,
+                                            int32_t x_scale, int64_t out_len_param,
+                                            int32_t* out_length) {
+  int32_t full_dec_str_len = 0;
+  char* dec_str =
+      gdv_fn_dec_to_string(context, x_high, x_low, x_scale, &full_dec_str_len);
+  if (dec_str == nullptr || out_len_param <= 0) {
+    *out_length = 0;
+    return dec_str;
+  }
+  // out_len_param is only narrowed when it is smaller than an int32_t value.
+  *out_length = out_len_param < full_dec_str_len ? static_cast<int32_t>(out_len_param)
+                                                 : full_dec_str_len;
+  return dec_str;
+}
+
} // extern "C"
diff --git a/cpp/src/gandiva/tests/decimal_test.cc b/cpp/src/gandiva/tests/decimal_test.cc
index 9bb08e1..1762d5b 100644
--- a/cpp/src/gandiva/tests/decimal_test.cc
+++ b/cpp/src/gandiva/tests/decimal_test.cc
@@ -29,6 +29,7 @@
using arrow::boolean;
using arrow::Decimal128;
+using arrow::utf8;
namespace gandiva {
@@ -843,4 +844,210 @@ TEST_F(TestDecimal, TestNullDecimalConstant) {
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
+// Verifies castVARCHAR(decimal, len): each decimal is rendered as text truncated to
+// `len` characters (5 and 1 here) and compared, via "equal", with the expected
+// strings that are supplied as extra input columns.
+TEST_F(TestDecimal, TestCastVarCharDecimal) {
+  // schema for input fields
+  constexpr int32_t precision = 38;
+  constexpr int32_t scale = 2;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+  auto field_dec = field("dec", decimal_type);
+  auto field_res_str = field("res_str", utf8());
+  auto field_res_str_1 = field("res_str_1", utf8());
+  auto schema = arrow::schema({field_dec, field_res_str, field_res_str_1});
+
+  // output fields
+  auto equals_res_bool = field("equals_res", boolean());
+
+  // build expressions.
+  auto node_dec = TreeExprBuilder::MakeField(field_dec);
+  auto node_res_str = TreeExprBuilder::MakeField(field_res_str);
+  auto node_res_str_1 = TreeExprBuilder::MakeField(field_res_str_1);
+  // maximum lengths of the strings produced by castVARCHAR
+  auto str_len_limit = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(5));
+  auto str_len_limit_1 = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(1));
+  auto cast_varchar =
+      TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit}, utf8());
+  auto cast_varchar_1 =
+      TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit_1}, utf8());
+  auto equals =
+      TreeExprBuilder::MakeFunction("equal", {cast_varchar, node_res_str}, boolean());
+  auto equals_1 =
+      TreeExprBuilder::MakeFunction("equal", {cast_varchar_1, node_res_str_1}, boolean());
+  auto expr = TreeExprBuilder::MakeExpression(equals, equals_res_bool);
+  auto expr_1 = TreeExprBuilder::MakeExpression(equals_1, equals_res_bool);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+
+  auto status = Projector::Make(schema, {expr, expr_1}, TestConfiguration(), &projector);
+  // Fatal: the projector is dereferenced below.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array_dec = MakeArrowArrayDecimal(
+      decimal_type,
+      MakeDecimalVector({"10.51", "1.23", "100.23", "-1000.23", "-0000.10"}, scale),
+      {true, false, true, true, true});
+  auto array_str_res = MakeArrowArrayUtf8({"10.51", "-null-", "100.2", "-1000", "-0.10"},
+                                          {true, false, true, true, true});
+  auto array_str_res_1 =
+      MakeArrowArrayUtf8({"1", "-null-", "1", "-", "-"}, {true, false, true, true, true});
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records,
+                                           {array_dec, array_str_res, array_str_res_1});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  // Fatal: outputs[0] and outputs[1] are indexed below.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  auto exp = MakeArrowArrayBool({true, false, true, true, true},
+                                {true, false, true, true, true});
+  auto exp_1 = MakeArrowArrayBool({true, false, true, true, true},
+                                  {true, false, true, true, true});
+  // Validate results: each output against its own expectation.
+  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[0]);
+  EXPECT_ARROW_ARRAY_EQUALS(exp_1, outputs[1]);
+}
+
+// Verifies castDECIMAL(utf8) with a decimal(4, 2) result type: inputs with more than
+// two fractional digits are expected at scale 2 ("10.5134" -> 10.51,
+// "10.516" -> 10.52), and "-1000" is expected to come back as 0.00.
+TEST_F(TestDecimal, TestCastDecimalVarChar) {
+  // schema for input fields
+  constexpr int32_t precision = 4;
+  constexpr int32_t scale = 2;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+  auto field_str = field("in_str", utf8());
+  auto schema = arrow::schema({field_str});
+
+  // output fields
+  auto res_dec = field("res_dec", decimal_type);
+
+  // build expressions.
+  auto node_str = TreeExprBuilder::MakeField(field_str);
+  auto cast_decimal =
+      TreeExprBuilder::MakeFunction("castDECIMAL", {node_str}, decimal_type);
+  auto expr = TreeExprBuilder::MakeExpression(cast_decimal, res_dec);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+
+  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+  // Fatal: the projector is dereferenced below.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+
+  auto array_str = MakeArrowArrayUtf8({"10.5134", "-0.0", "-0.1", "10.516", "-1000"},
+                                      {true, false, true, true, true});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_str});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  // Fatal: outputs[0] is indexed below.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  // The second slot is null in both input and expectation; its value is a placeholder.
+  auto array_dec = MakeArrowArrayDecimal(
+      decimal_type, MakeDecimalVector({"10.51", "1.23", "-0.10", "10.52", "0.00"}, scale),
+      {true, false, true, true, true});
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(array_dec, outputs[0]);
+}
+
+// Verifies that castDECIMAL(utf8) fails the evaluation when an input string
+// ("a10.5134") is not a decimal number, and that the error message says so.
+TEST_F(TestDecimal, TestCastDecimalVarCharInvalidInput) {
+  // schema for input fields
+  constexpr int32_t precision = 38;
+  constexpr int32_t scale = 0;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+  auto field_str = field("in_str", utf8());
+  auto schema = arrow::schema({field_str});
+
+  // output fields
+  auto res_dec = field("res_dec", decimal_type);
+
+  // build expressions.
+  auto node_str = TreeExprBuilder::MakeField(field_str);
+  auto cast_decimal =
+      TreeExprBuilder::MakeFunction("castDECIMAL", {node_str}, decimal_type);
+  auto expr = TreeExprBuilder::MakeExpression(cast_decimal, res_dec);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+
+  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+  // Fatal: the projector is dereferenced below.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+
+  // invalid input: the first value is not a decimal number
+  auto invalid_in = MakeArrowArrayUtf8({"a10.5134", "-0.0", "-0.1", "10.516", "-1000"},
+                                       {true, false, true, true, true});
+
+  // prepare input record batch
+  auto in_batch_1 = arrow::RecordBatch::Make(schema, num_records, {invalid_in});
+
+  // Evaluate expression; it is expected to fail.
+  arrow::ArrayVector outputs_1;
+  status = projector->Evaluate(*in_batch_1, pool_, &outputs_1);
+  EXPECT_FALSE(status.ok()) << status.message();
+  EXPECT_TRUE(status.message().find("not a valid decimal number") != std::string::npos);
+}
+
+// Verifies castDECIMAL(castVARCHAR(decimal, 5)): the decimal is first rendered as text
+// truncated to 5 characters and the truncated text is cast back to decimal(38, 2),
+// e.g. 100.23 -> "100.2" -> 100.20 and -1000.23 -> "-1000" -> -1000.00.
+TEST_F(TestDecimal, TestVarCharDecimalNestedCast) {
+  // schema for input fields
+  constexpr int32_t precision = 38;
+  constexpr int32_t scale = 2;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+  auto field_dec = field("dec", decimal_type);
+  auto schema = arrow::schema({field_dec});
+
+  // output fields
+  auto field_dec_res = field("dec_res", decimal_type);
+
+  // build expressions.
+  auto node_dec = TreeExprBuilder::MakeField(field_dec);
+
+  // maximum length of the intermediate string
+  auto str_len_limit = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(5));
+  auto cast_varchar =
+      TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit}, utf8());
+  auto cast_decimal =
+      TreeExprBuilder::MakeFunction("castDECIMAL", {cast_varchar}, decimal_type);
+
+  auto expr = TreeExprBuilder::MakeExpression(cast_decimal, field_dec_res);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+
+  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+  // Fatal: the projector is dereferenced below.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array_dec = MakeArrowArrayDecimal(
+      decimal_type,
+      MakeDecimalVector({"10.51", "1.23", "100.23", "-1000.23", "-0000.10"}, scale),
+      {true, false, true, true, true});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  // Fatal: outputs[0] is indexed below.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  // Validate results
+  auto array_dec_res = MakeArrowArrayDecimal(
+      decimal_type,
+      MakeDecimalVector({"10.51", "1.23", "100.20", "-1000.00", "-0.10"}, scale),
+      {true, false, true, true, true});
+  EXPECT_ARROW_ARRAY_EQUALS(array_dec_res, outputs[0]);
+}
+
} // namespace gandiva
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
index 9384cd4..c774b04 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
@@ -32,6 +32,7 @@ import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.FloatingPointPrecision;
@@ -243,6 +244,17 @@ class BaseEvaluatorTest {
return vector;
}
+  /**
+   * Builds a VarCharVector holding the UTF-8 encoding of each string in {@code values}.
+   *
+   * <p>The length handed to the vector is the encoded byte count rather than
+   * {@link String#length()}, so strings containing multi-byte characters are stored
+   * in full, and the encoding does not depend on the platform default charset.
+   *
+   * @param values strings to store, one per row; no element may be null
+   * @return a vector with {@code values.length} rows
+   */
+  VarCharVector varcharVector(String[] values) {
+    VarCharVector vector = new VarCharVector("VarCharVector" + Math.random(), allocator);
+    vector.allocateNew();
+    for (int i = 0; i < values.length; i++) {
+      byte[] encoded = values[i].getBytes(java.nio.charset.StandardCharsets.UTF_8);
+      vector.setSafe(i, encoded, 0, encoded.length);
+    }
+
+    vector.setValueCount(values.length);
+    return vector;
+  }
+
ArrowBuf longBuf(long[] longs) {
ArrowBuf buffer = allocator.buffer(longs.length * 8);
for (int i = 0; i < longs.length; i++) {
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
index aaacffd..c5cb8f7 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
@@ -35,6 +35,7 @@ import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.ArrowType;
@@ -625,4 +626,130 @@ public class ProjectorDecimalTest extends org.apache.arrow.gandiva.evaluator.Bas
eval.close();
}
}
+
+  /**
+   * Checks castVARCHAR(decimal, 5): every decimal is rendered as text truncated to five
+   * characters and compared, via "equal", with the expected string supplied in the
+   * second input column.
+   */
+  @Test
+  public void testCastToString() throws GandivaException {
+    Decimal decimalType = new Decimal(38, 2);
+    Field dec = Field.nullable("dec", decimalType);
+    Field str = Field.nullable("str", new ArrowType.Utf8());
+    TreeNode field = TreeBuilder.makeField(dec);
+    TreeNode literal = TreeBuilder.makeLiteral(5L);
+    List<TreeNode> args = Lists.newArrayList(field, literal);
+    TreeNode cast = TreeBuilder.makeFunction("castVARCHAR", args, new ArrowType.Utf8());
+    TreeNode root = TreeBuilder.makeFunction("equal",
+        Lists.newArrayList(cast, TreeBuilder.makeField(str)), new ArrowType.Bool());
+    ExpressionTree tree = TreeBuilder.makeExpression(root, Field.nullable("are_equal", new ArrowType.Bool()));
+
+    Schema schema = new Schema(Lists.newArrayList(dec, str));
+    Projector eval = Projector.make(schema, Lists.newArrayList(tree)
+    );
+
+    List<ValueVector> output = null;
+    ArrowRecordBatch batch = null;
+    try {
+      int numRows = 4;
+      String[] aValues = new String[]{"10.51", "100.23", "-1000.23", "-0000.10"};
+      String[] expected = {"10.51", "100.2", "-1000", "-0.10"};
+      DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale());
+      VarCharVector result = varcharVector(expected);
+      // One field node per schema column: "dec" and "str".
+      batch = new ArrowRecordBatch(
+          numRows,
+          Lists.newArrayList(
+              new ArrowFieldNode(numRows, 0),
+              new ArrowFieldNode(numRows, 0)
+          ),
+          Lists.newArrayList(
+              valuesa.getValidityBuffer(),
+              valuesa.getDataBuffer(),
+              result.getValidityBuffer(),
+              result.getOffsetBuffer(),
+              result.getDataBuffer()
+          )
+      );
+
+      BitVector resultVector = new BitVector("res", allocator);
+      resultVector.allocateNew();
+      output = new ArrayList<>(Arrays.asList(resultVector));
+
+      // evaluate expressions.
+      eval.evaluate(batch, output);
+
+      // compare the outputs.
+      for (int i = 0; i < numRows; i++) {
+        assertTrue(resultVector.getObject(i).booleanValue());
+      }
+    } finally {
+      // free buffers
+      if (batch != null) {
+        releaseRecordBatch(batch);
+      }
+      if (output != null) {
+        releaseValueVectors(output);
+      }
+      eval.close();
+    }
+  }
+
+  /**
+   * Checks castDECIMAL(utf8) with a decimal(4, 2) result type against the expected
+   * values 10.51, -0.10, 10.52 and 0.00 for the inputs "10.5134", "-0.1", "10.516"
+   * and "-1000".
+   */
+  @Test
+  public void testCastStringToDecimal() throws GandivaException {
+    Decimal decimalType = new Decimal(4, 2);
+    Field dec = Field.nullable("dec", decimalType);
+
+    Field str = Field.nullable("str", new ArrowType.Utf8());
+    List<TreeNode> castArgs = Lists.newArrayList(TreeBuilder.makeField(str));
+    TreeNode castNode = TreeBuilder.makeFunction("castDECIMAL", castArgs, decimalType);
+    ExpressionTree expression =
+        TreeBuilder.makeExpression(castNode, Field.nullable("dec_str", decimalType));
+
+    Schema schema = new Schema(Lists.newArrayList(str));
+    Projector eval = Projector.make(schema, Lists.newArrayList(expression));
+
+    List<ValueVector> output = null;
+    ArrowRecordBatch batch = null;
+    try {
+      final int numRows = 4;
+      VarCharVector input = varcharVector(new String[]{"10.5134", "-0.1", "10.516", "-1000"});
+
+      List<ArrowFieldNode> nodes = Lists.newArrayList(new ArrowFieldNode(numRows, 0));
+      batch = new ArrowRecordBatch(
+          numRows,
+          nodes,
+          Lists.newArrayList(
+              input.getValidityBuffer(),
+              input.getOffsetBuffer(),
+              input.getDataBuffer()));
+
+      DecimalVector resultVector = new DecimalVector("res", allocator,
+          decimalType.getPrecision(), decimalType.getScale());
+      resultVector.allocateNew();
+      output = new ArrayList<>(Arrays.asList(resultVector));
+
+      BigDecimal[] expected = {
+          BigDecimal.valueOf(10.51),
+          BigDecimal.valueOf(-0.10),
+          BigDecimal.valueOf(10.52),
+          BigDecimal.valueOf(0.00)
+      };
+
+      // run the projection
+      eval.evaluate(batch, output);
+
+      // every row must be numerically equal to its expectation
+      for (int row = 0; row < numRows; row++) {
+        String failure = "mismatch in result for " +
+            "field " + resultVector.getField().getName() +
+            " for row " + row +
+            " expected " + expected[row] +
+            ", got " + resultVector.getObject(row);
+        assertTrue(failure, expected[row].compareTo(resultVector.getObject(row)) == 0);
+      }
+    } finally {
+      // release everything that was allocated, even when an assertion failed
+      if (batch != null) {
+        releaseRecordBatch(batch);
+      }
+      if (output != null) {
+        releaseValueVectors(output);
+      }
+      eval.close();
+    }
+  }
}
+