You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/06/28 11:23:44 UTC

[arrow] branch main updated: GH-36182: [Gandiva][C++] Fix substring_index function when index is negative. (#36184)

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new aa5592edc2 GH-36182: [Gandiva][C++] Fix substring_index function when index is negative. (#36184)
aa5592edc2 is described below

commit aa5592edc2771cd39adc3c0861a55c0489948cef
Author: lriggs <lo...@gmail.com>
AuthorDate: Wed Jun 28 04:23:37 2023 -0700

    GH-36182: [Gandiva][C++] Fix substring_index function when index is negative. (#36184)
    
    ### Rationale for this change
    
    substring_index("Abc.DE.fGh", '.', -2) returns "fGh", but it should return "DE.fGh" (i.e., everything after the second occurrence of the delimiter, counting from the right). The proposed behavior matches the behavior of other databases.
    
    ### What changes are included in this PR?
    
    Fixed reverse index calculation and updated unit tests.
    
    ### Are these changes tested?
    
    Yes, unit tests and integration testing.
    
    ### Are there any user-facing changes?
    
    Yes: the substring_index function now returns a different result when the index is negative.
    
    * Closes: #36182
    
    Authored-by: Projjal Chanda <ia...@pchanda.com>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 cpp/src/gandiva/gdv_function_stubs_test.cc         | 10 +++++++---
 cpp/src/gandiva/gdv_string_function_stubs.cc       |  7 +++++--
 cpp/src/gandiva/precompiled/arithmetic_ops_test.cc |  4 ++--
 cpp/src/gandiva/tests/projector_test.cc            |  3 ++-
 4 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc
index b4b717c023..a8dfcd088a 100644
--- a/cpp/src/gandiva/gdv_function_stubs_test.cc
+++ b/cpp/src/gandiva/gdv_function_stubs_test.cc
@@ -464,7 +464,7 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
   EXPECT_FALSE(ctx.has_error());
 
   out_str = gdv_fn_substring_index(ctx_ptr, "Abc.DE.fGh", 10, ".", 1, -2, &out_len);
-  EXPECT_EQ(std::string(out_str, out_len), "fGh");
+  EXPECT_EQ(std::string(out_str, out_len), "DE.fGh");
   EXPECT_FALSE(ctx.has_error());
 
   out_str = gdv_fn_substring_index(ctx_ptr, "S;DCGS;JO!L", 11, ";", 1, 1, &out_len);
@@ -472,7 +472,7 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
   EXPECT_FALSE(ctx.has_error());
 
   out_str = gdv_fn_substring_index(ctx_ptr, "S;DCGS;JO!L", 11, ";", 1, -1, &out_len);
-  EXPECT_EQ(std::string(out_str, out_len), "DCGS;JO!L");
+  EXPECT_EQ(std::string(out_str, out_len), "JO!L");
   EXPECT_FALSE(ctx.has_error());
 
   out_str = gdv_fn_substring_index(ctx_ptr, "www.mysql.com", 13, "Q", 1, 1, &out_len);
@@ -496,7 +496,7 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
   EXPECT_FALSE(ctx.has_error());
 
   out_str = gdv_fn_substring_index(ctx_ptr, "www||mysql||com", 15, "||", 2, -2, &out_len);
-  EXPECT_EQ(std::string(out_str, out_len), "com");
+  EXPECT_EQ(std::string(out_str, out_len), "mysql||com");
   EXPECT_FALSE(ctx.has_error());
 
   out_str = gdv_fn_substring_index(ctx_ptr, "MÜNCHEN", 8, "Ü", 2, 1, &out_len);
@@ -507,6 +507,10 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
   EXPECT_EQ(std::string(out_str, out_len), "NCHEN");
   EXPECT_FALSE(ctx.has_error());
 
+  out_str = gdv_fn_substring_index(ctx_ptr, "MÜëCHEN", 9, "Ü", 2, -1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "ëCHEN");
+  EXPECT_FALSE(ctx.has_error());
+
   out_str = gdv_fn_substring_index(ctx_ptr, "citroën", 8, "ë", 2, -1, &out_len);
   EXPECT_EQ(std::string(out_str, out_len), "n");
   EXPECT_FALSE(ctx.has_error());
diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc b/cpp/src/gandiva/gdv_string_function_stubs.cc
index cf04de3a8e..3bfb297af1 100644
--- a/cpp/src/gandiva/gdv_string_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_string_function_stubs.cc
@@ -413,10 +413,13 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
     return out;
   } else if (static_cast<int32_t>(abs(cnt)) <= static_cast<int32_t>(occ.size()) &&
              cnt < 0) {
+    int32_t sz = static_cast<int32_t>(occ.size());
     int32_t temp = static_cast<int32_t>(abs(cnt));
-    memcpy(out, txt + occ[temp - 1] + pat_len, txt_len - occ[temp - 1] - pat_len);
-    *out_len = txt_len - occ[temp - 1] - pat_len;
+
+    memcpy(out, txt + occ[sz - temp] + pat_len, txt_len - occ[sz - temp] - pat_len);
+    *out_len = txt_len - occ[sz - temp] - pat_len;
     return out;
+
   } else {
     *out_len = txt_len;
     memcpy(out, txt, txt_len);
diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
index 64bcef34be..02fc68713b 100644
--- a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
@@ -681,14 +681,14 @@ TEST(TestArithmeticOps, TestCeilingFloatDouble) {
 }
 
 TEST(TestArithmeticOps, TestFloorFloatDouble) {
-  // ceiling from floats
+  // floor from floats
   EXPECT_EQ(floor_float32(6.6f), 6.0f);
   EXPECT_EQ(floor_float32(-6.6f), -7.0f);
   EXPECT_EQ(floor_float32(-6.3f), -7.0f);
   EXPECT_EQ(floor_float32(0.0f), 0.0f);
   EXPECT_EQ(floor_float32(-0), 0.0);
 
-  // ceiling from doubles
+  // floor from doubles
   EXPECT_EQ(floor_float64(6.6), 6.0);
   EXPECT_EQ(floor_float64(-6.6), -7.0);
   EXPECT_EQ(floor_float64(-6.3), -7.0);
diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc
index 25afa68c56..462fae6439 100644
--- a/cpp/src/gandiva/tests/projector_test.cc
+++ b/cpp/src/gandiva/tests/projector_test.cc
@@ -3202,7 +3202,8 @@ TEST_F(TestProjector, TestSubstringIndex) {
 
   auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array1, array2, array3});
 
-  auto out_1 = MakeArrowArrayUtf8({"www||mysql", "com", "DCGS;JO!L"}, {true, true, true});
+  auto out_1 =
+      MakeArrowArrayUtf8({"www||mysql", "mysql||com", "JO!L"}, {true, true, true});
 
   arrow::ArrayVector outputs;