You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ra...@apache.org on 2022/07/20 09:27:19 UTC

[arrow] branch master updated: ARROW-17067: [C++][Gandiva] Implement Substring_Index Function. (#13600)

This is an automated email from the ASF dual-hosted git repository.

ravindra pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 188efb7bda ARROW-17067: [C++][Gandiva] Implement Substring_Index Function. (#13600)
188efb7bda is described below

commit 188efb7bdafd7fe1541393eddfddcf5f1cb61fa9
Author: Sahaj Gupta <10...@users.noreply.github.com>
AuthorDate: Wed Jul 20 14:57:13 2022 +0530

    ARROW-17067: [C++][Gandiva] Implement Substring_Index Function. (#13600)
    
    Adding Substring_Index Function.
    
    Authored-by: SG011 <sa...@dremio.com>
    Signed-off-by: Pindikura Ravindra <ra...@dremio.com>
---
 cpp/src/gandiva/function_registry_string.cc  |   5 +-
 cpp/src/gandiva/gdv_function_stubs.h         |   5 ++
 cpp/src/gandiva/gdv_function_stubs_test.cc   |  71 ++++++++++++++++++
 cpp/src/gandiva/gdv_string_function_stubs.cc | 103 +++++++++++++++++++++++++++
 cpp/src/gandiva/tests/projector_test.cc      |  43 +++++++++++
 5 files changed, 226 insertions(+), 1 deletion(-)

diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc
index 21681775c6..c7f4c62e49 100644
--- a/cpp/src/gandiva/function_registry_string.cc
+++ b/cpp/src/gandiva/function_registry_string.cc
@@ -515,8 +515,11 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
 
       NativeFunction("translate", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(),
                      kResultNullIfNull, "translate_utf8_utf8_utf8",
-                     NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)};
+                     NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
 
+      NativeFunction("substring_index", {}, DataTypeVector{utf8(), utf8(), int32()},
+                     utf8(), kResultNullIfNull, "gdv_fn_substring_index",
+                     NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)};
   return string_fn_registry_;
 }
 
diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h
index c89720aa8c..42d360cfd5 100644
--- a/cpp/src/gandiva/gdv_function_stubs.h
+++ b/cpp/src/gandiva/gdv_function_stubs.h
@@ -346,4 +346,9 @@ gdv_timestamp to_utc_timezone_timestamp(int64_t context, gdv_timestamp time_mili
 GANDIVA_EXPORT
 gdv_timestamp from_utc_timezone_timestamp(int64_t context, gdv_timestamp time_miliseconds,
                                           const char* timezone, int32_t length);
+
+GANDIVA_EXPORT
+const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt_len,
+                                   const char* pat, int32_t pat_len, int32_t cnt,
+                                   int32_t* out_len);
 }
diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc
index f35e25d5ad..af6c8d4092 100644
--- a/cpp/src/gandiva/gdv_function_stubs_test.cc
+++ b/cpp/src/gandiva/gdv_function_stubs_test.cc
@@ -453,6 +453,77 @@ TEST(TestGdvFnStubs, TestCastVARCHARFromDouble) {
   EXPECT_FALSE(ctx.has_error());
 }
 
+TEST(TestGdvFnStubs, TestSubstringIndex) {
+  gandiva::ExecutionContext ctx;
+  uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+  gdv_int32 out_len = 0;
+  // cnt = 2: expects the text before the 2nd ".".
+  const char* out_str =
+      gdv_fn_substring_index(ctx_ptr, "Abc.DE.fGh", 10, ".", 1, 2, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "Abc.DE");
+  EXPECT_FALSE(ctx.has_error());
+  // cnt = -2: expects the text after the 2nd "." counted from the left.
+  out_str = gdv_fn_substring_index(ctx_ptr, "Abc.DE.fGh", 10, ".", 1, -2, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "fGh");
+  EXPECT_FALSE(ctx.has_error());
+  // cnt = 1: expects the text before the 1st ";".
+  out_str = gdv_fn_substring_index(ctx_ptr, "S;DCGS;JO!L", 11, ";", 1, 1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "S");
+  EXPECT_FALSE(ctx.has_error());
+  // cnt = -1: expects the text after the 1st ";" counted from the left.
+  out_str = gdv_fn_substring_index(ctx_ptr, "S;DCGS;JO!L", 11, ";", 1, -1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "DCGS;JO!L");
+  EXPECT_FALSE(ctx.has_error());
+  // Delimiter absent from the text: expects the input back unchanged.
+  out_str = gdv_fn_substring_index(ctx_ptr, "www.mysql.com", 13, "Q", 1, 1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "www.mysql.com");
+  EXPECT_FALSE(ctx.has_error());
+  // Two-character delimiter, cnt = 2: expects the text before the 2nd "||".
+  out_str = gdv_fn_substring_index(ctx_ptr, "www||mysql||com", 15, "||", 2, 2, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "www||mysql");
+  EXPECT_FALSE(ctx.has_error());
+  // Empty text: expects an empty result.
+  out_str = gdv_fn_substring_index(ctx_ptr, "", 0, ".", 1, 1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len).size(), 0);
+  EXPECT_FALSE(ctx.has_error());
+  // Empty delimiter: expects an empty result.
+  out_str = gdv_fn_substring_index(ctx_ptr, "www||mysql||com", 15, "", 0, 1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len).size(), 0);
+  EXPECT_FALSE(ctx.has_error());
+  // cnt = 0: expects an empty result.
+  out_str = gdv_fn_substring_index(ctx_ptr, "www||mysql||com", 15, "||", 2, 0, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len).size(), 0);
+  EXPECT_FALSE(ctx.has_error());
+  // cnt = -2: expects the text after the 2nd "||" counted from the left.
+  out_str = gdv_fn_substring_index(ctx_ptr, "www||mysql||com", 15, "||", 2, -2, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "com");
+  EXPECT_FALSE(ctx.has_error());
+  // Non-ASCII delimiter: the lengths passed (8 and 2) are UTF-8 byte counts.
+  out_str = gdv_fn_substring_index(ctx_ptr, "MÜNCHEN", 8, "Ü", 2, 1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "M");
+  EXPECT_FALSE(ctx.has_error());
+  // Same input with cnt = -1: expects the text after the delimiter.
+  out_str = gdv_fn_substring_index(ctx_ptr, "MÜNCHEN", 8, "Ü", 2, -1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "NCHEN");
+  EXPECT_FALSE(ctx.has_error());
+  // Non-ASCII delimiter close to the end of the text, cnt = -1.
+  out_str = gdv_fn_substring_index(ctx_ptr, "citroën", 8, "ë", 2, -1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "n");
+  EXPECT_FALSE(ctx.has_error());
+  // Same input with cnt = 1: expects the text before the delimiter.
+  out_str = gdv_fn_substring_index(ctx_ptr, "citroën", 8, "ë", 2, 1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "citro");
+  EXPECT_FALSE(ctx.has_error());
+  // Backslash delimiter preceded by multi-byte characters, cnt = 1.
+  out_str = gdv_fn_substring_index(ctx_ptr, "路学\\L", 8, "\\", 1, 1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "路学");
+  EXPECT_FALSE(ctx.has_error());
+  // Same input with cnt = -1: expects the text after the backslash.
+  out_str = gdv_fn_substring_index(ctx_ptr, "路学\\L", 8, "\\", 1, -1, &out_len);
+  EXPECT_EQ(std::string(out_str, out_len), "L");
+  EXPECT_FALSE(ctx.has_error());
+}
+
 TEST(TestGdvFnStubs, TestUpper) {
   gandiva::ExecutionContext ctx;
   uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc b/cpp/src/gandiva/gdv_string_function_stubs.cc
index f4bc5f8462..0c963f4417 100644
--- a/cpp/src/gandiva/gdv_string_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_string_function_stubs.cc
@@ -336,6 +336,94 @@ const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_le
   return out;
 }
 
+// substring_index(txt, pat, cnt): with cnt > 0, returns the text before the cnt-th
+// match of pat; with cnt < 0, the text after the |cnt|-th match. Matches are located
+// with a KMP scan and may overlap. An empty txt or pat, or cnt == 0, yields "", and
+// txt is returned unchanged when it holds fewer than |cnt| matches.
+// NOTE(review): for cnt < 0 the matches are counted from the left, whereas MySQL's
+// SUBSTRING_INDEX counts from the right -- confirm which behaviour is intended.
+GDV_FORCE_INLINE
+const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt_len,
+                                   const char* pat, int32_t pat_len, int32_t cnt,
+                                   int32_t* out_len) {
+  if (txt_len == 0 || pat_len == 0 || cnt == 0) {
+    *out_len = 0;
+    return "";
+  }
+
+  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, txt_len));
+  if (out == nullptr) {
+    gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+    *out_len = 0;
+    return "";
+  }
+
+  // lps[k] is the length of the longest proper prefix of pat[0..k] that is also a
+  // suffix of it. The vector is zero-initialised, so lps[0] needs no explicit store.
+  std::vector<int> lps(pat_len);
+  int len = 0;  // length of the prefix of pat currently matched
+  int i = 1;
+  while (i < pat_len) {
+    if (pat[i] == pat[len]) {
+      len++;
+      lps[i] = len;
+      i++;
+    } else if (len != 0) {
+      // Retry pat[i] against the next shorter border; i is deliberately not advanced.
+      len = lps[len - 1];
+    } else {
+      lps[i] = 0;
+      i++;
+    }
+  }
+
+  // Byte offsets in txt at which a match of pat starts, in ascending order.
+  std::vector<int> occ;
+  i = 0;      // index for txt[]
+  int j = 0;  // index for pat[]
+  while (i < txt_len) {
+    if (pat[j] == txt[i]) {
+      j++;
+      i++;
+    }
+    if (j == pat_len) {
+      occ.push_back(i - j);
+      // Resume from the longest border so that overlapping matches are found too.
+      j = lps[j - 1];
+    } else if (i < txt_len && pat[j] != txt[i]) {
+      // Mismatch after j matched bytes: fall back to the border recorded in lps, or
+      // step to the next byte of txt when nothing has been matched yet.
+      if (j != 0)
+        j = lps[j - 1];
+      else
+        i = i + 1;
+    }
+  }
+
+  // Widen before negating: abs(cnt) overflows for INT32_MIN, which is undefined
+  // behaviour and could let a negative index through the bounds check.
+  const int64_t nth = cnt < 0 ? -static_cast<int64_t>(cnt) : cnt;
+  if (nth > static_cast<int64_t>(occ.size())) {
+    // Fewer than |cnt| matches: the whole input is the result.
+    memcpy(out, txt, txt_len);
+    *out_len = txt_len;
+    return out;
+  }
+
+  const int32_t match_pos = occ[static_cast<size_t>(nth - 1)];
+  if (cnt > 0) {
+    // Keep everything before the selected match.
+    *out_len = match_pos;
+    memcpy(out, txt, *out_len);
+  } else {
+    // Keep everything after the selected match.
+    const int32_t tail_begin = match_pos + pat_len;
+    *out_len = txt_len - tail_begin;
+    memcpy(out, txt + tail_begin, *out_len);
+  }
+  return out;
+}
+
 // Any codepoint, except the ones for lowercase letters, uppercase letters,
 // titlecase letters, decimal digits and letter numbers categories will be
 // considered as word separators.
@@ -855,6 +943,21 @@ void ExportedStringFunctions::AddMappings(Engine* engine) const {
                                   types->i8_ptr_type() /*return_type*/, args,
                                   reinterpret_cast<void*>(gdv_fn_upper_utf8));
 
+  // gdv_fn_substring_index
+  args = {
+      types->i64_type(),      // context
+      types->i8_ptr_type(),   // txt
+      types->i32_type(),      // txt_len
+      types->i8_ptr_type(),   // pat
+      types->i32_type(),      // pat_len
+      types->i32_type(),      // cnt
+      types->i32_ptr_type(),  // out_len
+  };
+
+  engine->AddGlobalMappingForFunc("gdv_fn_substring_index",
+                                  types->i8_ptr_type() /*return_type*/, args,
+                                  reinterpret_cast<void*>(gdv_fn_substring_index));
+
   // gdv_fn_initcap_utf8
   args = {
       types->i64_type(),     // context
diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc
index 39b0df8b90..69a5ce7b2d 100644
--- a/cpp/src/gandiva/tests/projector_test.cc
+++ b/cpp/src/gandiva/tests/projector_test.cc
@@ -2992,6 +2992,49 @@ TEST_F(TestProjector, TestUCase) {
   EXPECT_ARROW_ARRAY_EQUALS(out_1, outputs.at(0));
 }
 
+TEST_F(TestProjector, TestSubstringIndex) {
+  auto field1 = field("f1", arrow::utf8());
+  auto field2 = field("f2", arrow::utf8());
+  auto field3 = field("f3", arrow::int32());
+  auto schema = arrow::schema({field1, field2, field3});
+
+  // Output field: the utf8 result of substring_index
+  auto substring_index = field("substring", arrow::utf8());
+
+  // Build expression: substring_index(f1, f2, f3)
+  auto substring_expr = TreeExprBuilder::MakeExpression(
+      "substring_index", {field1, field2, field3}, substring_index);
+
+  std::shared_ptr<Projector> projector;
+
+  auto status =
+      Projector::Make(schema, {substring_expr}, TestConfiguration(), &projector);
+
+  EXPECT_TRUE(status.ok());
+
+  // Create a row-batch with some sample data
+  int num_records = 3;
+  // f1: the text to cut
+  auto array1 = MakeArrowArrayUtf8({"www||mysql||com", "www||mysql||com", "S;DCGS;JO!L"},
+                                   {true, true, true});
+  // f2: the delimiter searched for in f1
+  auto array2 = MakeArrowArrayUtf8({"||", "||", ";"}, {true, true, true});
+  // f3: the occurrence count; one positive and two negative values
+  auto array3 = MakeArrowArrayInt32({2, -2, -1}, {true, true, true});
+
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array1, array2, array3});
+  // Expected result for each of the three rows
+  auto out_1 = MakeArrowArrayUtf8({"www||mysql", "com", "DCGS;JO!L"}, {true, true, true});
+
+  arrow::ArrayVector outputs;
+
+  // Evaluate expression
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  EXPECT_TRUE(status.ok());
+
+  EXPECT_ARROW_ARRAY_EQUALS(out_1, outputs.at(0));
+}
+
 TEST_F(TestProjector, TestLCase) {
   auto field0 = field("f0", arrow::utf8());
   auto schema = arrow::schema({field0});