You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ra...@apache.org on 2022/07/20 09:27:19 UTC
[arrow] branch master updated: ARROW-17067: [C++][Gandiva] Implement Substring_Index Function. (#13600)
This is an automated email from the ASF dual-hosted git repository.
ravindra pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 188efb7bda ARROW-17067: [C++][Gandiva] Implement Substring_Index Function. (#13600)
188efb7bda is described below
commit 188efb7bdafd7fe1541393eddfddcf5f1cb61fa9
Author: Sahaj Gupta <10...@users.noreply.github.com>
AuthorDate: Wed Jul 20 14:57:13 2022 +0530
ARROW-17067: [C++][Gandiva] Implement Substring_Index Function. (#13600)
Adding Substring_Index Function.
Authored-by: SG011 <sa...@dremio.com>
Signed-off-by: Pindikura Ravindra <ra...@dremio.com>
---
cpp/src/gandiva/function_registry_string.cc | 5 +-
cpp/src/gandiva/gdv_function_stubs.h | 5 ++
cpp/src/gandiva/gdv_function_stubs_test.cc | 71 ++++++++++++++++++
cpp/src/gandiva/gdv_string_function_stubs.cc | 103 +++++++++++++++++++++++++++
cpp/src/gandiva/tests/projector_test.cc | 43 +++++++++++
5 files changed, 226 insertions(+), 1 deletion(-)
diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc
index 21681775c6..c7f4c62e49 100644
--- a/cpp/src/gandiva/function_registry_string.cc
+++ b/cpp/src/gandiva/function_registry_string.cc
@@ -515,8 +515,11 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
NativeFunction("translate", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(),
kResultNullIfNull, "translate_utf8_utf8_utf8",
- NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)};
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+ NativeFunction("substring_index", {}, DataTypeVector{utf8(), utf8(), int32()},
+ utf8(), kResultNullIfNull, "gdv_fn_substring_index",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)};
return string_fn_registry_;
}
diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h
index c89720aa8c..42d360cfd5 100644
--- a/cpp/src/gandiva/gdv_function_stubs.h
+++ b/cpp/src/gandiva/gdv_function_stubs.h
@@ -346,4 +346,9 @@ gdv_timestamp to_utc_timezone_timestamp(int64_t context, gdv_timestamp time_mili
GANDIVA_EXPORT
gdv_timestamp from_utc_timezone_timestamp(int64_t context, gdv_timestamp time_miliseconds,
const char* timezone, int32_t length);
+
+// Returns the part of `txt` before (cnt > 0) or after (cnt < 0) the |cnt|-th
+// occurrence of the delimiter `pat`, or a copy of `txt` when there are fewer than
+// |cnt| occurrences. An empty `txt`, an empty `pat` or cnt == 0 yields "".
+// The byte length of the result is written to `*out_len`. A non-empty result is
+// held in memory obtained through `context`; if that allocation fails an error
+// message is set on the context and "" is returned.
+GANDIVA_EXPORT
+const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt_len,
+ const char* pat, int32_t pat_len, int32_t cnt,
+ int32_t* out_len);
}
diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc
index f35e25d5ad..af6c8d4092 100644
--- a/cpp/src/gandiva/gdv_function_stubs_test.cc
+++ b/cpp/src/gandiva/gdv_function_stubs_test.cc
@@ -453,6 +453,77 @@ TEST(TestGdvFnStubs, TestCastVARCHARFromDouble) {
EXPECT_FALSE(ctx.has_error());
}
+// Exercises gdv_fn_substring_index directly: positive and negative counts,
+// single- and multi-byte delimiters, missing delimiters and empty arguments.
+TEST(TestGdvFnStubs, TestSubstringIndex) {
+  gandiva::ExecutionContext ctx;
+  uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+
+  // Runs one call and checks both the returned bytes and that no error was
+  // recorded on the shared execution context. Lengths are byte lengths.
+  const auto expect_result = [&](const char* txt, gdv_int32 txt_len, const char* pat,
+                                 gdv_int32 pat_len, gdv_int32 cnt,
+                                 const std::string& expected) {
+    gdv_int32 out_len = 0;
+    const char* out_str =
+        gdv_fn_substring_index(ctx_ptr, txt, txt_len, pat, pat_len, cnt, &out_len);
+    EXPECT_EQ(std::string(out_str, out_len), expected);
+    EXPECT_FALSE(ctx.has_error());
+  };
+
+  // single-byte delimiter
+  expect_result("Abc.DE.fGh", 10, ".", 1, 2, "Abc.DE");
+  expect_result("Abc.DE.fGh", 10, ".", 1, -2, "fGh");
+  expect_result("S;DCGS;JO!L", 11, ";", 1, 1, "S");
+  expect_result("S;DCGS;JO!L", 11, ";", 1, -1, "DCGS;JO!L");
+
+  // delimiter not present: the whole text comes back
+  expect_result("www.mysql.com", 13, "Q", 1, 1, "www.mysql.com");
+
+  // two-byte delimiter
+  expect_result("www||mysql||com", 15, "||", 2, 2, "www||mysql");
+
+  // empty text, empty delimiter and zero count all give an empty result
+  expect_result("", 0, ".", 1, 1, "");
+  expect_result("www||mysql||com", 15, "", 0, 1, "");
+  expect_result("www||mysql||com", 15, "||", 2, 0, "");
+
+  expect_result("www||mysql||com", 15, "||", 2, -2, "com");
+
+  // multi-byte UTF-8 delimiters
+  expect_result("MÜNCHEN", 8, "Ü", 2, 1, "M");
+  expect_result("MÜNCHEN", 8, "Ü", 2, -1, "NCHEN");
+  expect_result("citroën", 8, "ë", 2, -1, "n");
+  expect_result("citroën", 8, "ë", 2, 1, "citro");
+
+  // multi-byte UTF-8 text with a single-byte delimiter
+  expect_result("路学\\L", 8, "\\", 1, 1, "路学");
+  expect_result("路学\\L", 8, "\\", 1, -1, "L");
+}
+
TEST(TestGdvFnStubs, TestUpper) {
gandiva::ExecutionContext ctx;
uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc b/cpp/src/gandiva/gdv_string_function_stubs.cc
index f4bc5f8462..0c963f4417 100644
--- a/cpp/src/gandiva/gdv_string_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_string_function_stubs.cc
@@ -336,6 +336,94 @@ const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_le
return out;
}
+// Substring_index
+//
+// Returns the part of `txt` that lies before (cnt > 0) or after (cnt < 0) the
+// |cnt|-th occurrence of the delimiter `pat`.
+//
+// Occurrences are found with a Knuth-Morris-Pratt scan over the raw bytes and are
+// always numbered from the start of `txt`; overlapping occurrences each count.
+//   - cnt > 0: the bytes preceding the cnt-th occurrence.
+//   - cnt < 0: the bytes following the |cnt|-th occurrence.
+//   - fewer than |cnt| occurrences: a copy of the whole of `txt`.
+//   - empty `txt`, empty `pat` or cnt == 0: the empty string.
+// NOTE(review): MySQL's SUBSTRING_INDEX numbers occurrences from the right when the
+// count is negative; this function numbers them from the left. Confirm which
+// behaviour is intended before changing it, the unit tests pin the current one.
+//
+// `*out_len` always receives the byte length of the result. A non-empty result is
+// copied into memory from gdv_fn_context_arena_malloc; if that allocation fails an
+// error message is set on the context and an empty string is returned.
+GDV_FORCE_INLINE
+const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt_len,
+                                   const char* pat, int32_t pat_len, int32_t cnt,
+                                   int32_t* out_len) {
+  *out_len = 0;
+  // Non-positive lengths are treated as empty so that a negative pat_len can never
+  // be used as a vector size below.
+  if (txt_len <= 0 || pat_len <= 0 || cnt == 0) {
+    return "";
+  }
+
+  // Widen before negating: abs() of the minimum int32_t value is undefined and
+  // would otherwise lead to an out-of-range occurrence index.
+  const int64_t wanted = cnt < 0 ? -static_cast<int64_t>(cnt) : cnt;
+
+  // lps[k] is the length of the longest proper prefix of pat[0..k] that is also a
+  // suffix of it; it tells the scan how far to fall back after a mismatch.
+  std::vector<int32_t> lps(pat_len, 0);
+  for (int32_t k = 1, border = 0; k < pat_len; ++k) {
+    while (border > 0 && pat[k] != pat[border]) {
+      border = lps[border - 1];
+    }
+    if (pat[k] == pat[border]) {
+      ++border;
+    }
+    lps[k] = border;
+  }
+
+  // Only the |cnt|-th occurrence is ever needed, so count matches and stop there
+  // instead of collecting every position.
+  int32_t hit = -1;     // start offset of the |cnt|-th occurrence, -1 if absent
+  int64_t seen = 0;     // occurrences found so far
+  int32_t matched = 0;  // number of pattern bytes currently matched
+  for (int32_t pos = 0; pos < txt_len; ++pos) {
+    while (matched > 0 && pat[matched] != txt[pos]) {
+      matched = lps[matched - 1];
+    }
+    if (pat[matched] == txt[pos]) {
+      ++matched;
+    }
+    if (matched == pat_len) {
+      if (++seen == wanted) {
+        hit = pos + 1 - pat_len;
+        break;
+      }
+      // keep the longest border so that overlapping occurrences are also counted
+      matched = lps[matched - 1];
+    }
+  }
+
+  // Default to the whole text; narrow it when the requested occurrence exists.
+  const char* src = txt;
+  int32_t len = txt_len;
+  if (hit >= 0) {
+    if (cnt > 0) {
+      len = hit;
+    } else {
+      src = txt + hit + pat_len;
+      len = txt_len - hit - pat_len;
+    }
+  }
+
+  if (len == 0) {
+    // nothing to copy, so no arena allocation is needed
+    return "";
+  }
+
+  char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, len));
+  if (out == nullptr) {
+    gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+    return "";
+  }
+
+  memcpy(out, src, static_cast<size_t>(len));
+  *out_len = len;
+  return out;
+}
+
// Any codepoint, except the ones for lowercase letters, uppercase letters,
// titlecase letters, decimal digits and letter numbers categories will be
// considered as word separators.
@@ -855,6 +943,21 @@ void ExportedStringFunctions::AddMappings(Engine* engine) const {
types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_upper_utf8));
+ // gdv_fn_substring_index
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // txt
+ types->i32_type(), // txt_len
+ types->i8_ptr_type(), // pat
+ types->i32_type(), // pat_len
+ types->i32_type(), // cnt
+ types->i32_ptr_type(), // out_len
+ };
+
+ engine->AddGlobalMappingForFunc("gdv_fn_substring_index",
+ types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_substring_index));
+
// gdv_fn_initcap_utf8
args = {
types->i64_type(), // context
diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc
index 39b0df8b90..69a5ce7b2d 100644
--- a/cpp/src/gandiva/tests/projector_test.cc
+++ b/cpp/src/gandiva/tests/projector_test.cc
@@ -2992,6 +2992,49 @@ TEST_F(TestProjector, TestUCase) {
EXPECT_ARROW_ARRAY_EQUALS(out_1, outputs.at(0));
}
+// End-to-end check that "substring_index" is registered and evaluates through a
+// Projector for positive and negative counts.
+TEST_F(TestProjector, TestSubstringIndex) {
+  // input schema: text, delimiter, occurrence count
+  auto field1 = field("f1", arrow::utf8());
+  auto field2 = field("f2", arrow::utf8());
+  auto field3 = field("f3", arrow::int32());
+  auto schema = arrow::schema({field1, field2, field3});
+
+  // output fields
+  auto substring_index = field("substring", arrow::utf8());
+
+  // Build expression
+  auto substring_expr = TreeExprBuilder::MakeExpression(
+      "substring_index", {field1, field2, field3}, substring_index);
+
+  std::shared_ptr<Projector> projector;
+
+  auto status =
+      Projector::Make(schema, {substring_expr}, TestConfiguration(), &projector);
+
+  // Fatal on failure: `projector` is still null if Make did not assign it, and the
+  // Evaluate() call below would dereference it.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data
+  int num_records = 3;
+
+  auto array1 = MakeArrowArrayUtf8({"www||mysql||com", "www||mysql||com", "S;DCGS;JO!L"},
+                                   {true, true, true});
+
+  auto array2 = MakeArrowArrayUtf8({"||", "||", ";"}, {true, true, true});
+
+  auto array3 = MakeArrowArrayInt32({2, -2, -1}, {true, true, true});
+
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array1, array2, array3});
+
+  auto out_1 = MakeArrowArrayUtf8({"www||mysql", "com", "DCGS;JO!L"}, {true, true, true});
+
+  arrow::ArrayVector outputs;
+
+  // Evaluate expression
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  // Fatal on failure: `outputs` may be empty, and outputs.at(0) must not be reached.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  EXPECT_ARROW_ARRAY_EQUALS(out_1, outputs.at(0));
+}
+
TEST_F(TestProjector, TestLCase) {
auto field0 = field("f0", arrow::utf8());
auto schema = arrow::schema({field0});