You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ra...@apache.org on 2019/07/08 06:15:46 UTC

[arrow] branch master updated: ARROW-5758: [C++][Gandiva][Java] Support casting decimals to varchar and vice versa

This is an automated email from the ASF dual-hosted git repository.

ravindra pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f9238e  ARROW-5758: [C++][Gandiva][Java] Support casting decimals to varchar and vice versa
1f9238e is described below

commit 1f9238e3ebb9b9347f7881396215093d0e585c66
Author: Prudhvi Porandla <pr...@icloud.com>
AuthorDate: Mon Jul 8 11:45:27 2019 +0530

    ARROW-5758: [C++][Gandiva][Java] Support casting decimals to varchar and vice versa
    
    Support the `castVARCHAR(decimal, out_str_len)` and `castDECIMAL(string, out_precision, out_scale)` functions in Gandiva.
    
    Author: Prudhvi Porandla <pr...@icloud.com>
    
    Closes #4803 from pprudhvi/decimal-varchar and squashes the following commits:
    
    72f479f95 <Prudhvi Porandla> assert status
    a222ef28c <Prudhvi Porandla> set error in context if string is invalid decimal
    b4eecbc61 <Prudhvi Porandla> lint
    dd6ba7cc8 <Prudhvi Porandla> return "-" instead of "0" for negative numbers when out_len is 1
    c013a1581 <Prudhvi Porandla> revert arrow version change
    bf6106d8f <Prudhvi Porandla> lint
    a6ac92b9e <Prudhvi Porandla> lint
    19bfb5fac <Prudhvi Porandla> cast size_t to int32_t
    c14dc74f9 <Prudhvi Porandla> temp fix - change arrow version to 0.x*
    428e87029 <Prudhvi Porandla> lint
    9d94c4972 <Prudhvi Porandla> do toString, FromString in gdv function stubs
    727c5d72d <Prudhvi Porandla> implement cast decimal<->varchar in decimal wrapper
---
 cpp/src/gandiva/decimal_ir.cc                      |  21 +++
 cpp/src/gandiva/function_registry_arithmetic.cc    |   1 +
 cpp/src/gandiva/function_registry_string.cc        |   4 +
 cpp/src/gandiva/gdv_function_stubs.cc              |  58 ++++++
 cpp/src/gandiva/gdv_function_stubs.h               |   7 +
 cpp/src/gandiva/precompiled/decimal_wrapper.cc     |  37 ++++
 cpp/src/gandiva/tests/decimal_test.cc              | 207 +++++++++++++++++++++
 .../arrow/gandiva/evaluator/BaseEvaluatorTest.java |  12 ++
 .../gandiva/evaluator/ProjectorDecimalTest.java    | 127 +++++++++++++
 9 files changed, 474 insertions(+)

diff --git a/cpp/src/gandiva/decimal_ir.cc b/cpp/src/gandiva/decimal_ir.cc
index 6e4bb56..463e448 100644
--- a/cpp/src/gandiva/decimal_ir.cc
+++ b/cpp/src/gandiva/decimal_ir.cc
@@ -560,6 +560,8 @@ Status DecimalIR::AddFunctions(Engine* engine) {
   auto i1 = decimal_ir->types()->i1_type();
   auto i64 = decimal_ir->types()->i64_type();
   auto f64 = decimal_ir->types()->double_type();
+  auto i8_ptr = decimal_ir->types()->i8_ptr_type();
+  auto i32_ptr = decimal_ir->types()->i32_ptr_type();
 
   // Populate global variables used by decimal operations.
   decimal_ir->AddGlobals(engine);
@@ -822,6 +824,25 @@ Status DecimalIR::AddFunctions(Engine* engine) {
                                            {"y_isvalid", i1},
                                        }));
 
+  ARROW_RETURN_NOT_OK(decimal_ir->BuildDecimalFunction("castDECIMAL_utf8", i128,
+                                                       {
+                                                           {"context", i64},
+                                                           {"in", i8_ptr},
+                                                           {"in_len", i32},
+                                                           {"out_precision", i32},
+                                                           {"out_scale", i32},
+                                                       }));
+  ARROW_RETURN_NOT_OK(decimal_ir->BuildDecimalFunction("castVARCHAR_decimal128_int64",
+                                                       i8_ptr,
+                                                       {
+                                                           {"context", i64},
+                                                           {"x_value", i128},
+                                                           {"x_precision", i32},
+                                                           {"x_scale", i32},
+                                                           {"out_len_param", i64},
+                                                           {"out_length", i32_ptr},
+                                                       }));
+
   return Status::OK();
 }
 
diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc
index b6c5819..f7f88bb 100644
--- a/cpp/src/gandiva/function_registry_arithmetic.cc
+++ b/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -49,6 +49,7 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
       UNARY_SAFE_NULL_IF_NULL(castDECIMAL, int64, decimal128),
       UNARY_SAFE_NULL_IF_NULL(castDECIMAL, float64, decimal128),
       UNARY_SAFE_NULL_IF_NULL(castDECIMAL, decimal128, decimal128),
+      UNARY_UNSAFE_NULL_IF_NULL(castDECIMAL, utf8, decimal128),
 
       UNARY_SAFE_NULL_IF_NULL(castDATE, int64, date64),
 
diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc
index 19e31c8..af1f37f 100644
--- a/cpp/src/gandiva/function_registry_string.cc
+++ b/cpp/src/gandiva/function_registry_string.cc
@@ -61,6 +61,10 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
                      kResultNullIfNull, "castVARCHAR_utf8_int64",
                      NativeFunction::kNeedsContext),
 
+      NativeFunction("castVARCHAR", DataTypeVector{decimal128(), int64()}, utf8(),
+                     kResultNullIfNull, "castVARCHAR_decimal128_int64",
+                     NativeFunction::kNeedsContext),
+
       NativeFunction("like", DataTypeVector{utf8(), utf8()}, boolean(), kResultNullIfNull,
                      "gdv_fn_like_utf8_utf8", NativeFunction::kNeedsFunctionHolder)};
 
diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc
index 570e026..1d8fe9b 100644
--- a/cpp/src/gandiva/gdv_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_function_stubs.cc
@@ -97,6 +97,36 @@ int32_t gdv_fn_populate_varlen_vector(int64_t context_ptr, int8_t* data_ptr,
   offsets[slot + 1] = offset + entry_len;
   return 0;
 }
+
+int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_length,
+                               int32_t* precision_from_str, int32_t* scale_from_str,
+                               int64_t* dec_high_from_str, uint64_t* dec_low_from_str) {
+  arrow::Decimal128 dec;  // parsed from the in_length bytes at `in` (no NUL required)
+  auto status = arrow::Decimal128::FromString(std::string(in, in_length), &dec,
+                                              precision_from_str, scale_from_str);
+  if (!status.ok()) {  // not a valid decimal: report the message via the context
+    gdv_fn_context_set_error_msg(context, status.message().data());
+    return -1;  // dec_high_from_str / dec_low_from_str are not written on failure
+  }
+  *dec_high_from_str = dec.high_bits();  // split the 128-bit value into two 64-bit words
+  *dec_low_from_str = dec.low_bits();
+  return 0;  // success
+}
+
+char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low,
+                           int32_t x_scale, int32_t* dec_str_len) {
+  arrow::Decimal128 dec(arrow::BasicDecimal128(x_high, x_low));  // reassemble 128 bits
+  std::string dec_str = dec.ToString(x_scale);  // full, untruncated text for x_scale
+  *dec_str_len = static_cast<int32_t>(dec_str.length());  // NOTE: set even if malloc fails
+  char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *dec_str_len));
+  if (ret == NULLPTR) {  // allocation failed: report an error and return null
+    std::string err_msg = "Could not allocate memory for string: " + dec_str;
+    gdv_fn_context_set_error_msg(context, err_msg.data());
+    return NULLPTR;
+  }
+  memcpy(ret, dec_str.data(), *dec_str_len);  // no NUL terminator is appended
+  return ret;
+}
 }
 
 namespace gandiva {
@@ -105,6 +135,34 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
   std::vector<llvm::Type*> args;
   auto types = engine->types();
 
+  // gdv_fn_dec_from_string
+  args = {
+      types->i64_type(),      // context
+      types->i8_ptr_type(),   // const char* in
+      types->i32_type(),      // int32_t in_length
+      types->i32_ptr_type(),  // int32_t* precision_from_str
+      types->i32_ptr_type(),  // int32_t* scale_from_str
+      types->i64_ptr_type(),  // int64_t* dec_high_from_str
+      types->i64_ptr_type(),  // uint64_t* dec_low_from_str (LLVM ints are signless)
+  };
+
+  engine->AddGlobalMappingForFunc("gdv_fn_dec_from_string",
+                                  types->i32_type() /*return_type*/, args,
+                                  reinterpret_cast<void*>(gdv_fn_dec_from_string));
+
+  // gdv_fn_dec_to_string
+  args = {
+      types->i64_type(),      // context
+      types->i64_type(),      // int64_t x_high
+      types->i64_type(),      // uint64_t x_low (LLVM ints are signless)
+      types->i32_type(),      // int32_t x_scale
+      types->i32_ptr_type(),  // int32_t* dec_str_len (must match the C signature)
+  };
+
+  engine->AddGlobalMappingForFunc("gdv_fn_dec_to_string",
+                                  types->i8_ptr_type() /*return_type*/, args,
+                                  reinterpret_cast<void*>(gdv_fn_dec_to_string));
+
   // gdv_fn_like_utf8_utf8
   args = {types->i64_type(),     // int64_t ptr
           types->i8_ptr_type(),  // const char* data
diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h
index 8f940ce..fcdf7d6 100644
--- a/cpp/src/gandiva/gdv_function_stubs.h
+++ b/cpp/src/gandiva/gdv_function_stubs.h
@@ -46,6 +46,13 @@ bool in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len, bool in_va
 
 int gdv_fn_time_with_zone(int* time_fields, const char* zone, int zone_len,
                           int64_t* ret_time);
+
+int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_length,
+                               int32_t* precision_from_str, int32_t* scale_from_str,
+                               int64_t* dec_high_from_str, uint64_t* dec_low_from_str);
+
+char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low,
+                           int32_t x_scale, int32_t* dec_str_len);
 }
 
 #endif  // GDV_FUNCTION_STUBS_H
diff --git a/cpp/src/gandiva/precompiled/decimal_wrapper.cc b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
index 630fe8b..620b443 100644
--- a/cpp/src/gandiva/precompiled/decimal_wrapper.cc
+++ b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
@@ -358,4 +358,41 @@ boolean is_distinct_from_decimal128_decimal128_internal(
       y_isvalid);
 }
 
+FORCE_INLINE  // parse `in` as a decimal, then convert it to (out_precision, out_scale)
+void castDECIMAL_utf8_internal(int64_t context, const char* in, int32_t in_length,
+                               int32_t out_precision, int32_t out_scale,
+                               int64_t* out_high, uint64_t* out_low) {
+  int64_t dec_high_from_str = 0;
+  uint64_t dec_low_from_str = 0;
+  int32_t precision_from_str = 0;
+  int32_t scale_from_str = 0;
+  *out_high = 0;  // keep the result well-defined when parsing fails
+  *out_low = 0;
+  if (gdv_fn_dec_from_string(context, in, in_length, &precision_from_str, &scale_from_str,
+                             &dec_high_from_str, &dec_low_from_str) != 0) {
+    return;  // the parse error has already been recorded in the context
+  }
+
+  gandiva::BasicDecimalScalar128 x({dec_high_from_str, dec_low_from_str},
+                                   precision_from_str, scale_from_str);
+  bool overflow = false;  // NOTE(review): overflow is not reported to the caller
+  auto out = gandiva::decimalops::Convert(x, out_precision, out_scale, &overflow);
+  *out_high = out.high_bits();
+  *out_low = out.low_bits();
+}
+
+FORCE_INLINE  // decimal -> string, truncated to at most out_len_param bytes
+char* castVARCHAR_decimal128_int64_internal(int64_t context, int64_t x_high,
+                                            uint64_t x_low, int32_t x_precision,
+                                            int32_t x_scale, int64_t out_len_param,
+                                            int32_t* out_length) {
+  int32_t full_dec_str_len = 0;  // length of the untruncated decimal string
+  char* dec_str =
+      gdv_fn_dec_to_string(context, x_high, x_low, x_scale, &full_dec_str_len);
+  // Report no bytes if allocation failed (dec_str is null); clamp a negative limit to 0.
+  int64_t cap = (dec_str == nullptr || out_len_param < 0) ? 0 : out_len_param;
+  *out_length = static_cast<int32_t>(cap < full_dec_str_len ? cap : full_dec_str_len);
+  return dec_str;  // the result is the prefix [0, *out_length)
+}
+
 }  // extern "C"
diff --git a/cpp/src/gandiva/tests/decimal_test.cc b/cpp/src/gandiva/tests/decimal_test.cc
index 9bb08e1..1762d5b 100644
--- a/cpp/src/gandiva/tests/decimal_test.cc
+++ b/cpp/src/gandiva/tests/decimal_test.cc
@@ -29,6 +29,7 @@
 
 using arrow::boolean;
 using arrow::Decimal128;
+using arrow::utf8;
 
 namespace gandiva {
 
@@ -843,4 +844,210 @@ TEST_F(TestDecimal, TestNullDecimalConstant) {
   EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
 }
 
+TEST_F(TestDecimal, TestCastVarCharDecimal) {
+  // schema for input fields: a decimal column plus the two expected-string columns
+  constexpr int32_t precision = 38;
+  constexpr int32_t scale = 2;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+  auto field_dec = field("dec", decimal_type);
+  auto field_res_str = field("res_str", utf8());
+  auto field_res_str_1 = field("res_str_1", utf8());
+  auto schema = arrow::schema({field_dec, field_res_str, field_res_str_1});
+
+  // output fields
+  auto res_str = field("res_str", utf8());
+  auto equals_res_bool = field("equals_res", boolean());
+
+  // build expressions.
+  auto node_dec = TreeExprBuilder::MakeField(field_dec);
+  auto node_res_str = TreeExprBuilder::MakeField(field_res_str);
+  auto node_res_str_1 = TreeExprBuilder::MakeField(field_res_str_1);
+  // maximum castVARCHAR output lengths under test: 5 and 1 characters
+  auto str_len_limit = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(5));
+  auto str_len_limit_1 = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(1));
+  auto cast_varchar =
+      TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit}, utf8());
+  auto cast_varchar_1 =
+      TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit_1}, utf8());
+  auto equals =
+      TreeExprBuilder::MakeFunction("equal", {cast_varchar, node_res_str}, boolean());
+  auto equals_1 =
+      TreeExprBuilder::MakeFunction("equal", {cast_varchar_1, node_res_str_1}, boolean());
+  auto expr = TreeExprBuilder::MakeExpression(equals, equals_res_bool);
+  auto expr_1 = TreeExprBuilder::MakeExpression(equals_1, equals_res_bool);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+
+  auto status = Projector::Make(schema, {expr, expr_1}, TestConfiguration(), &projector);
+  EXPECT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array_dec = MakeArrowArrayDecimal(
+      decimal_type,
+      MakeDecimalVector({"10.51", "1.23", "100.23", "-1000.23", "-0000.10"}, scale),
+      {true, false, true, true, true});
+  auto array_str_res = MakeArrowArrayUtf8({"10.51", "-null-", "100.2", "-1000", "-0.10"},
+                                          {true, false, true, true, true});
+  auto array_str_res_1 =
+      MakeArrowArrayUtf8({"1", "-null-", "1", "-", "-"}, {true, false, true, true, true});
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records,
+                                           {array_dec, array_str_res, array_str_res_1});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  EXPECT_TRUE(status.ok()) << status.message();
+
+  auto exp = MakeArrowArrayBool({true, false, true, true, true},
+                                {true, false, true, true, true});
+  auto exp_1 = MakeArrowArrayBool({true, false, true, true, true},
+                                  {true, false, true, true, true});
+  // Validate results: outputs[i] is the result of the i-th expression
+  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[0]);
+  EXPECT_ARROW_ARRAY_EQUALS(exp_1, outputs[1]);
+}
+
+TEST_F(TestDecimal, TestCastDecimalVarChar) {
+  // castDECIMAL(utf8) into decimal(4, 2); the input schema is a single string column
+  constexpr int32_t precision = 4;
+  constexpr int32_t scale = 2;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+  auto field_str = field("in_str", utf8());
+  auto schema = arrow::schema({field_str});
+
+  // output fields
+  auto res_dec = field("res_dec", decimal_type);
+
+  // build expressions.
+  auto node_str = TreeExprBuilder::MakeField(field_str);
+  auto cast_decimal =
+      TreeExprBuilder::MakeFunction("castDECIMAL", {node_str}, decimal_type);
+  auto expr = TreeExprBuilder::MakeExpression(cast_decimal, res_dec);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+
+  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+  EXPECT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data (row 1 is null)
+  int num_records = 5;
+
+  auto array_str = MakeArrowArrayUtf8({"10.5134", "-0.0", "-0.1", "10.516", "-1000"},
+                                      {true, false, true, true, true});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_str});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  EXPECT_TRUE(status.ok()) << status.message();
+
+  auto array_dec = MakeArrowArrayDecimal(  // row 1 is null, so "1.23" is a placeholder
+      decimal_type, MakeDecimalVector({"10.51", "1.23", "-0.10", "10.52", "0.00"}, scale),
+      {true, false, true, true, true});
+  // Validate: 10.516 is expected as 10.52; -1000 does not fit decimal(4, 2), expected 0.00
+  EXPECT_ARROW_ARRAY_EQUALS(array_dec, outputs[0]);
+}
+
+TEST_F(TestDecimal, TestCastDecimalVarCharInvalidInput) {
+  // castDECIMAL(utf8) into decimal(38, 0) with a non-numeric string in the batch
+  constexpr int32_t precision = 38;
+  constexpr int32_t scale = 0;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+  auto field_str = field("in_str", utf8());
+  auto schema = arrow::schema({field_str});
+
+  // output fields
+  auto res_dec = field("res_dec", decimal_type);
+
+  // build expressions.
+  auto node_str = TreeExprBuilder::MakeField(field_str);
+  auto cast_decimal =
+      TreeExprBuilder::MakeFunction("castDECIMAL", {node_str}, decimal_type);
+  auto expr = TreeExprBuilder::MakeExpression(cast_decimal, res_dec);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+
+  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+  EXPECT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+
+  // invalid input: "a10.5134" is not a decimal number
+  auto invalid_in = MakeArrowArrayUtf8({"a10.5134", "-0.0", "-0.1", "10.516", "-1000"},
+                                       {true, false, true, true, true});
+
+  // prepare input record batch
+  auto in_batch_1 = arrow::RecordBatch::Make(schema, num_records, {invalid_in});
+
+  // Evaluate expression; the parse error must surface in the returned status
+  arrow::ArrayVector outputs_1;
+  status = projector->Evaluate(*in_batch_1, pool_, &outputs_1);
+  EXPECT_FALSE(status.ok()) << status.message();
+  EXPECT_TRUE(status.message().find("not a valid decimal number") != std::string::npos);
+}
+
+TEST_F(TestDecimal, TestVarCharDecimalNestedCast) {
+  // round-trip castDECIMAL(castVARCHAR(dec, 5)) on decimal(38, 2)
+  constexpr int32_t precision = 38;
+  constexpr int32_t scale = 2;
+  auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale);
+
+  auto field_dec = field("dec", decimal_type);
+  auto schema = arrow::schema({field_dec});
+
+  // output fields
+  auto field_dec_res = field("dec_res", decimal_type);
+
+  // build expressions.
+  auto node_dec = TreeExprBuilder::MakeField(field_dec);
+
+  // maximum length of the intermediate string (5 characters)
+  auto str_len_limit = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(5));
+  auto cast_varchar =
+      TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit}, utf8());
+  auto cast_decimal =
+      TreeExprBuilder::MakeFunction("castDECIMAL", {cast_varchar}, decimal_type);
+
+  auto expr = TreeExprBuilder::MakeExpression(cast_decimal, field_dec_res);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+
+  auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+  EXPECT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array_dec = MakeArrowArrayDecimal(
+      decimal_type,
+      MakeDecimalVector({"10.51", "1.23", "100.23", "-1000.23", "-0000.10"}, scale),
+      {true, false, true, true, true});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  EXPECT_TRUE(status.ok()) << status.message();
+
+  // Validate results: truncation to 5 chars loses digits ("100.2" -> 100.20, "-1000")
+  auto array_dec_res = MakeArrowArrayDecimal(
+      decimal_type,
+      MakeDecimalVector({"10.51", "1.23", "100.20", "-1000.00", "-0.10"}, scale),
+      {true, false, true, true, true});
+  EXPECT_ARROW_ARRAY_EQUALS(array_dec_res, outputs[0]);
+}
+
 }  // namespace gandiva
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
index 9384cd4..c774b04 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
@@ -32,6 +32,7 @@ import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
 import org.apache.arrow.vector.types.FloatingPointPrecision;
@@ -243,6 +244,17 @@ class BaseEvaluatorTest {
     return vector;
   }
 
+  VarCharVector varcharVector(String[] values) {  // one non-null UTF-8 entry per value
+    VarCharVector vector = new VarCharVector("VarCharVector" + Math.random(), allocator);
+    vector.allocateNew();
+    for (int i = 0; i < values.length; i++) {
+      byte[] utf8 = values[i].getBytes(java.nio.charset.StandardCharsets.UTF_8);
+      vector.setSafe(i, utf8, 0, utf8.length);  // byte length, not String.length()
+    }
+    vector.setValueCount(values.length);
+    return vector;
+  }
+
   ArrowBuf longBuf(long[] longs) {
     ArrowBuf buffer = allocator.buffer(longs.length * 8);
     for (int i = 0; i < longs.length; i++) {
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
index aaacffd..c5cb8f7 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
@@ -35,6 +35,7 @@ import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
 import org.apache.arrow.vector.types.pojo.ArrowType;
@@ -625,4 +626,130 @@ public class ProjectorDecimalTest extends org.apache.arrow.gandiva.evaluator.Bas
       eval.close();
     }
   }
+
+  @Test
+  public void testCastToString() throws GandivaException {  // castVARCHAR(dec, 5) == str
+    Decimal decimalType = new Decimal(38, 2);
+    Field dec = Field.nullable("dec", decimalType);
+    Field str = Field.nullable("str", new ArrowType.Utf8());
+    TreeNode field = TreeBuilder.makeField(dec);
+    TreeNode literal = TreeBuilder.makeLiteral(5L);
+    List<TreeNode> args = Lists.newArrayList(field, literal);
+    TreeNode cast = TreeBuilder.makeFunction("castVARCHAR", args, new ArrowType.Utf8());
+    TreeNode root = TreeBuilder.makeFunction("equal",
+        Lists.newArrayList(cast, TreeBuilder.makeField(str)), new ArrowType.Bool());
+    ExpressionTree tree = TreeBuilder.makeExpression(root, Field.nullable("are_equal", new ArrowType.Bool()));
+
+    Schema schema = new Schema(Lists.newArrayList(dec, str));
+    Projector eval = Projector.make(schema, Lists.newArrayList(tree)
+    );
+
+    List<ValueVector> output = null;
+    ArrowRecordBatch batch = null;
+    try {
+      int numRows = 4;
+      String[] aValues = new String[]{"10.51", "100.23", "-1000.23", "-0000.10"};
+      String[] expected = {"10.51", "100.2", "-1000", "-0.10"};  // truncated to 5 chars
+      DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale());
+      VarCharVector result = varcharVector(expected);  // second input column "str"
+      batch = new ArrowRecordBatch(
+          numRows,
+          Lists.newArrayList(
+              new ArrowFieldNode(numRows, 0)
+          ),
+          Lists.newArrayList(
+              valuesa.getValidityBuffer(),
+              valuesa.getDataBuffer(),
+              result.getValidityBuffer(),
+              result.getOffsetBuffer(),
+              result.getDataBuffer()
+          )
+      );
+
+      BitVector resultVector = new BitVector("res", allocator);
+      resultVector.allocateNew();
+      output = new ArrayList<>(Arrays.asList(resultVector));
+
+      // evaluate expressions.
+      eval.evaluate(batch, output);
+
+      // compare the outputs: every row's cast must equal the expected string.
+      for (int i = 0; i < numRows; i++) {
+        assertTrue(resultVector.getObject(i).booleanValue());
+      }
+    } finally {
+      // free buffers held by the batch and the output vectors
+      if (batch != null) {
+        releaseRecordBatch(batch);
+      }
+      if (output != null) {
+        releaseValueVectors(output);
+      }
+      eval.close();
+    }
+  }
+
+  @Test
+  public void testCastStringToDecimal() throws GandivaException {  // castDECIMAL(str) -> (4, 2)
+    Decimal decimalType = new Decimal(4, 2);
+    Field dec = Field.nullable("dec", decimalType);  // NOTE(review): not referenced below
+
+    Field str = Field.nullable("str", new ArrowType.Utf8());
+    TreeNode field = TreeBuilder.makeField(str);
+    List<TreeNode> args = Lists.newArrayList(field);
+    TreeNode cast = TreeBuilder.makeFunction("castDECIMAL", args, decimalType);
+    ExpressionTree tree = TreeBuilder.makeExpression(cast, Field.nullable("dec_str", decimalType));
+
+    Schema schema = new Schema(Lists.newArrayList(str));
+    Projector eval = Projector.make(schema, Lists.newArrayList(tree)
+    );
+
+    List<ValueVector> output = null;
+    ArrowRecordBatch batch = null;
+    try {
+      int numRows = 4;
+      String[] aValues = new String[]{"10.5134", "-0.1", "10.516", "-1000"};
+      VarCharVector valuesa = varcharVector(aValues);
+      batch = new ArrowRecordBatch(
+          numRows,
+          Lists.newArrayList(
+              new ArrowFieldNode(numRows, 0)
+          ),
+          Lists.newArrayList(
+              valuesa.getValidityBuffer(),
+              valuesa.getOffsetBuffer(),
+              valuesa.getDataBuffer()
+          )
+      );
+
+      DecimalVector resultVector = new DecimalVector("res", allocator,
+          decimalType.getPrecision(), decimalType.getScale());
+      resultVector.allocateNew();
+      output = new ArrayList<>(Arrays.asList(resultVector));
+
+      BigDecimal[] expected = {BigDecimal.valueOf(10.51), BigDecimal.valueOf(-0.10),
+          BigDecimal.valueOf(10.52), BigDecimal.valueOf(0.00)};  // -1000 overflows (4, 2)
+      // evaluate expressions.
+      eval.evaluate(batch, output);
+
+      // compare the outputs numerically (compareTo ignores BigDecimal scale).
+      for (int i = 0; i < numRows; i++) {
+        assertTrue("mismatch in result for " +
+            "field " + resultVector.getField().getName() +
+            " for row " + i +
+            " expected " + expected[i] +
+            ", got " + resultVector.getObject(i),expected[i].compareTo(resultVector.getObject(i)) == 0);
+      }
+    } finally {
+      // free buffers held by the batch and the output vectors
+      if (batch != null) {
+        releaseRecordBatch(batch);
+      }
+      if (output != null) {
+        releaseValueVectors(output);
+      }
+      eval.close();
+    }
+  }
 }
+