You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text rendering).
Posted to commits@arrow.apache.org by ra...@apache.org on 2022/04/19 03:29:25 UTC

[arrow] branch master updated: ARROW-16186: [C++][GANDIVA] Add alias and tests for decimal, quarter, xor, etc...

This is an automated email from the ASF dual-hosted git repository.

ravindra pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 4798effe1a ARROW-16186: [C++][GANDIVA] Add alias and tests for decimal, quarter, xor, etc...
4798effe1a is described below

commit 4798effe1a11edec67596d4598411fd486a4288a
Author: Vinicius Roque <ho...@gmail.com>
AuthorDate: Tue Apr 19 08:59:12 2022 +0530

    ARROW-16186: [C++][GANDIVA] Add alias and tests for decimal, quarter, xor, etc...
    
    Closes #12875 from ViniciusSouzaRoque/feature/add-some-alias-and-tests
    
    Lead-authored-by: Vinicius Roque <ho...@gmail.com>
    Co-authored-by: ViniciusSouzaRoque <vi...@dremio.com>
    Co-authored-by: João Pedro <jo...@simbioseventures.com>
    Signed-off-by: Pindikura Ravindra <ra...@dremio.com>
---
 cpp/src/gandiva/function_registry_arithmetic.cc | 16 +++---
 cpp/src/gandiva/function_registry_datetime.cc   | 21 ++++----
 cpp/src/gandiva/function_registry_string.cc     | 32 +++++------
 cpp/src/gandiva/tests/projector_test.cc         | 72 +++++++++++++++++++++++++
 4 files changed, 107 insertions(+), 34 deletions(-)

diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc b/cpp/src/gandiva/function_registry_arithmetic.cc
index 983f7fe397..829ca3e1bc 100644
--- a/cpp/src/gandiva/function_registry_arithmetic.cc
+++ b/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -82,12 +82,12 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
       UNARY_CAST_TO_FLOAT64(float32), UNARY_CAST_TO_FLOAT64(decimal128),
 
       // cast to decimal
-      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, int32, decimal128),
-      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, int64, decimal128),
-      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, float32, decimal128),
-      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, float64, decimal128),
-      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, decimal128, decimal128),
-      UNARY_UNSAFE_NULL_IF_NULL(castDECIMAL, {}, utf8, decimal128),
+      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {"decimal"}, int32, decimal128),
+      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {"decimal"}, int64, decimal128),
+      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {"decimal"}, float32, decimal128),
+      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {"decimal"}, float64, decimal128),
+      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {"decimal"}, decimal128, decimal128),
+      UNARY_UNSAFE_NULL_IF_NULL(castDECIMAL, {"decimal"}, utf8, decimal128),
 
       NativeFunction("castDECIMALNullOnOverflow", {}, DataTypeVector{decimal128()},
                      decimal128(), kResultNullInternal,
@@ -119,8 +119,8 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
       BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_and, {}, int64),
       BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, int32),
       BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, int64),
-      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {}, int32),
-      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {}, int64),
+      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {"xor"}, int32),
+      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {"xor"}, int64),
       UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, int32, int32),
       UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, int64, int64),
 
diff --git a/cpp/src/gandiva/function_registry_datetime.cc b/cpp/src/gandiva/function_registry_datetime.cc
index 1b428492b6..be9cadbb80 100644
--- a/cpp/src/gandiva/function_registry_datetime.cc
+++ b/cpp/src/gandiva/function_registry_datetime.cc
@@ -21,14 +21,15 @@
 
 namespace gandiva {
 
-#define DATE_EXTRACTION_TRUNCATION_FNS(INNER, name)                                    \
-  DATE_TYPES(INNER, name##Millennium, {}), DATE_TYPES(INNER, name##Century, {}),       \
-      DATE_TYPES(INNER, name##Decade, {}), DATE_TYPES(INNER, name##Year, {"year"}),    \
-      DATE_TYPES(INNER, name##Quarter, {}), DATE_TYPES(INNER, name##Month, {"month"}), \
-      DATE_TYPES(INNER, name##Week, ({"weekofyear", "yearweek"})),                     \
-      DATE_TYPES(INNER, name##Day, ({"day", "dayofmonth"})),                           \
-      DATE_TYPES(INNER, name##Hour, {"hour"}),                                         \
-      DATE_TYPES(INNER, name##Minute, {"minute"}),                                     \
+#define DATE_EXTRACTION_TRUNCATION_FNS(INNER, name)                                 \
+  DATE_TYPES(INNER, name##Millennium, {}), DATE_TYPES(INNER, name##Century, {}),    \
+      DATE_TYPES(INNER, name##Decade, {}), DATE_TYPES(INNER, name##Year, {"year"}), \
+      DATE_TYPES(INNER, name##Quarter, ({"quarter"})),                              \
+      DATE_TYPES(INNER, name##Month, {"month"}),                                    \
+      DATE_TYPES(INNER, name##Week, ({"weekofyear", "yearweek"})),                  \
+      DATE_TYPES(INNER, name##Day, ({"day", "dayofmonth"})),                        \
+      DATE_TYPES(INNER, name##Hour, {"hour"}),                                      \
+      DATE_TYPES(INNER, name##Minute, {"minute"}),                                  \
       DATE_TYPES(INNER, name##Second, {"second"})
 
 #define TO_TIMESTAMP_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE)                       \
@@ -65,8 +66,8 @@ std::vector<NativeFunction> GetDateTimeFunctionRegistry() {
                      kResultNullIfNull, "castTIMESTAMP_utf8",
                      NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
 
-      NativeFunction("castVARCHAR", {}, DataTypeVector{timestamp(), int64()}, utf8(),
-                     kResultNullIfNull, "castVARCHAR_timestamp_int64",
+      NativeFunction("castVARCHAR", {"varchar"}, DataTypeVector{timestamp(), int64()},
+                     utf8(), kResultNullIfNull, "castVARCHAR_timestamp_int64",
                      NativeFunction::kNeedsContext),
 
       NativeFunction("to_date", {}, DataTypeVector{utf8(), utf8()}, date64(),
diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc
index b0c52a2295..a56d8f07d1 100644
--- a/cpp/src/gandiva/function_registry_string.cc
+++ b/cpp/src/gandiva/function_registry_string.cc
@@ -52,7 +52,7 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
 
       UNARY_OCTET_LEN_FN(octet_length, {}), UNARY_OCTET_LEN_FN(bit_length, {}),
 
-      UNARY_UNSAFE_NULL_IF_NULL(char_length, {}, utf8, int32),
+      UNARY_UNSAFE_NULL_IF_NULL(char_length, {"character_length"}, utf8, int32),
       UNARY_UNSAFE_NULL_IF_NULL(length, {}, utf8, int32),
       UNARY_UNSAFE_NULL_IF_NULL(lengthUtf8, {}, binary, int32),
       UNARY_UNSAFE_NULL_IF_NULL(reverse, {}, utf8, utf8),
@@ -163,40 +163,40 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
                      kResultNullIfNull, "gdv_fn_castFLOAT8_varbinary",
                      NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
 
-      NativeFunction("castVARCHAR", {}, DataTypeVector{boolean(), int64()}, utf8(),
-                     kResultNullIfNull, "castVARCHAR_bool_int64",
+      NativeFunction("castVARCHAR", {"varchar"}, DataTypeVector{boolean(), int64()},
+                     utf8(), kResultNullIfNull, "castVARCHAR_bool_int64",
                      NativeFunction::kNeedsContext),
 
-      NativeFunction("castVARCHAR", {}, DataTypeVector{utf8(), int64()}, utf8(),
+      NativeFunction("castVARCHAR", {"varchar"}, DataTypeVector{utf8(), int64()}, utf8(),
                      kResultNullIfNull, "castVARCHAR_utf8_int64",
                      NativeFunction::kNeedsContext),
 
-      NativeFunction("castVARCHAR", {}, DataTypeVector{binary(), int64()}, utf8(),
-                     kResultNullIfNull, "castVARCHAR_binary_int64",
+      NativeFunction("castVARCHAR", {"varchar"}, DataTypeVector{binary(), int64()},
+                     utf8(), kResultNullIfNull, "castVARCHAR_binary_int64",
                      NativeFunction::kNeedsContext),
 
-      NativeFunction("castVARCHAR", {}, DataTypeVector{int32(), int64()}, utf8(),
+      NativeFunction("castVARCHAR", {"varchar"}, DataTypeVector{int32(), int64()}, utf8(),
                      kResultNullIfNull, "gdv_fn_castVARCHAR_int32_int64",
                      NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
 
-      NativeFunction("castVARCHAR", {}, DataTypeVector{int64(), int64()}, utf8(),
+      NativeFunction("castVARCHAR", {"varchar"}, DataTypeVector{int64(), int64()}, utf8(),
                      kResultNullIfNull, "gdv_fn_castVARCHAR_int64_int64",
                      NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
 
-      NativeFunction("castVARCHAR", {}, DataTypeVector{date64(), int64()}, utf8(),
-                     kResultNullIfNull, "gdv_fn_castVARCHAR_date64_int64",
+      NativeFunction("castVARCHAR", {"varchar"}, DataTypeVector{date64(), int64()},
+                     utf8(), kResultNullIfNull, "gdv_fn_castVARCHAR_date64_int64",
                      NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
 
-      NativeFunction("castVARCHAR", {}, DataTypeVector{float32(), int64()}, utf8(),
-                     kResultNullIfNull, "gdv_fn_castVARCHAR_float32_int64",
+      NativeFunction("castVARCHAR", {"varchar"}, DataTypeVector{float32(), int64()},
+                     utf8(), kResultNullIfNull, "gdv_fn_castVARCHAR_float32_int64",
                      NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
 
-      NativeFunction("castVARCHAR", {}, DataTypeVector{float64(), int64()}, utf8(),
-                     kResultNullIfNull, "gdv_fn_castVARCHAR_float64_int64",
+      NativeFunction("castVARCHAR", {"varchar"}, DataTypeVector{float64(), int64()},
+                     utf8(), kResultNullIfNull, "gdv_fn_castVARCHAR_float64_int64",
                      NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
 
-      NativeFunction("castVARCHAR", {}, DataTypeVector{decimal128(), int64()}, utf8(),
-                     kResultNullIfNull, "castVARCHAR_decimal128_int64",
+      NativeFunction("castVARCHAR", {"varchar"}, DataTypeVector{decimal128(), int64()},
+                     utf8(), kResultNullIfNull, "castVARCHAR_decimal128_int64",
                      NativeFunction::kNeedsContext),
 
       NativeFunction("crc32", {}, DataTypeVector{utf8()}, int64(), kResultNullIfNull,
diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc
index c65504b5b8..7a040502f7 100644
--- a/cpp/src/gandiva/tests/projector_test.cc
+++ b/cpp/src/gandiva/tests/projector_test.cc
@@ -784,6 +784,41 @@ TEST_F(TestProjector, TestDivideZero) {
   EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
 }
 
+TEST_F(TestProjector, TestXor) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto field1 = field("f1", int32());
+  auto schema = arrow::schema({field0, field1});
+
+  // output fields
+  auto field_xor = field("xor", int32());
+
+  // Build expression
+  auto mod_expr = TreeExprBuilder::MakeExpression("xor", {field0, field1}, field_xor);
+
+  std::shared_ptr<Projector> projector;
+  auto status = Projector::Make(schema, {mod_expr}, TestConfiguration(), &projector);
+  EXPECT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data
+  int num_records = 4;
+  auto array0 = MakeArrowArrayInt32({2, 3, 1, 20}, {true, true, true, true});
+  auto array1 = MakeArrowArrayInt32({4, 1, 3, 0}, {true, true, true, true});
+  // expected output
+  auto exp_mod = MakeArrowArrayInt32({6, 2, 2, 20}, {true, true, true, true});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  EXPECT_TRUE(status.ok()) << status.message();
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp_mod, outputs.at(0));
+}
+
 TEST_F(TestProjector, TestSoundex) {
   // schema for input fields
   auto field0 = field("f0", arrow::utf8());
@@ -2011,6 +2046,43 @@ TEST_F(TestProjector, TestDayOfMonth) {
   EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
 }
 
+TEST_F(TestProjector, TestQuarter) {
+  // input fields
+  // schema for input fields
+  auto field0 = field("f0", arrow::date64());
+  auto schema = arrow::schema({field0});
+
+  // output fields
+  auto field_result = field("quarter", arrow::int64());
+
+  // Build expression
+  auto myexpr = TreeExprBuilder::MakeExpression("quarter", {field0}, field_result);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+  auto status = Projector::Make(schema, {myexpr}, TestConfiguration(), &projector);
+  EXPECT_TRUE(status.ok());
+
+  // Create a row-batch with some sample data
+  int num_records = 4;
+  auto array0 =
+      MakeArrowArrayDate64({1604293200000, 1409648400000, 921783012000, 1338369900000},
+                           {true, true, true, true});
+  // expected output
+  auto exp = MakeArrowArrayInt64({4, 3, 1, 2}, {true, true, true, true});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  EXPECT_TRUE(status.ok());
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
 TEST_F(TestProjector, TestBround) {
   // schema for input fields
   auto field0 = field("f0", arrow::float64());