You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2021/06/07 18:46:58 UTC
[arrow] branch master updated: ARROW-12950: [C++] Add
count_substring kernel
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new d77e272 ARROW-12950: [C++] Add count_substring kernel
d77e272 is described below
commit d77e272c81d409c77a2ebe4127572f2ef44c7632
Author: David Li <li...@gmail.com>
AuthorDate: Mon Jun 7 20:45:36 2021 +0200
ARROW-12950: [C++] Add count_substring kernel
Depends on ARROW-12969. ignore_case is not included here; I'll include it with the regex variant in ARROW-12952.
Closes #10454 from lidavidm/arrow-12950
Authored-by: David Li <li...@gmail.com>
Signed-off-by: Antoine Pitrou <an...@python.org>
---
cpp/src/arrow/compute/kernels/scalar_string.cc | 70 +++++++++++++++++++++-
.../arrow/compute/kernels/scalar_string_test.cc | 19 ++++++
docs/source/cpp/compute.rst | 30 ++++++----
docs/source/python/api/compute.rst | 1 +
python/pyarrow/compute.py | 19 ++++++
python/pyarrow/tests/test_compute.py | 13 ++++
6 files changed, 139 insertions(+), 13 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index 154b57d..df3a399 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -741,8 +741,12 @@ template <typename InputType>
struct FindSubstringExec {
using OffsetType = typename TypeTraits<InputType>::OffsetType;
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // Fetch the MatchSubstringOptions bound to this kernel invocation once and
+ // reuse it for both the ignore_case check and the matcher construction.
+ const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
+ // This kernel only performs exact (case-sensitive) matching; reject
+ // ignore_case explicitly rather than silently disregarding the flag.
+ if (options.ignore_case) {
+ return Status::NotImplemented("find_substring with ignore_case");
+ }
applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, FindSubstring> kernel{
- FindSubstring(PlainSubstringMatcher(MatchSubstringState::Get(ctx)))};
+ FindSubstring(PlainSubstringMatcher(options))};
return kernel.Exec(ctx, batch, out);
}
};
@@ -771,6 +775,69 @@ void AddFindSubstring(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}
+// Substring count
+
+// Per-value functor for "count_substring": counts how many times the matcher's
+// pattern occurs in a single input value.
+struct CountSubstring {
+  const PlainSubstringMatcher matcher_;
+
+  explicit CountSubstring(PlainSubstringMatcher matcher) : matcher_(std::move(matcher)) {}
+
+  // Returns the number of non-overlapping occurrences of the pattern in `val`.
+  // After each hit the search resumes just past the matched bytes, so with
+  // pattern "aba" the value "ababa" counts 1, not 2.
+  template <typename OutValue, typename... Ignored>
+  OutValue Call(KernelContext*, util::string_view val, Status*) const {
+    // An empty pattern matches at every offset (including the end of the
+    // value), so always advance by at least one byte to guarantee progress.
+    const uint64_t step = std::max<uint64_t>(1, matcher_.options_.pattern.size());
+    OutValue matches = 0;
+    for (uint64_t pos = 0; pos <= val.size();) {
+      const int64_t hit = matcher_.Find(val.substr(pos));
+      if (hit < 0) {
+        break;
+      }
+      ++matches;
+      pos += hit + step;
+    }
+    return matches;
+  }
+};
+
+// Kernel entry point for "count_substring". InputType is one of the base-binary
+// types registered in AddCountSubstring; each value's count is emitted using
+// that type's offset integer type (int32 or int64).
+template <typename InputType>
+struct CountSubstringExec {
+  using OffsetType = typename TypeTraits<InputType>::OffsetType;
+  using KernelType =
+      applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, CountSubstring>;
+
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const MatchSubstringOptions& opts = MatchSubstringState::Get(ctx);
+    if (!opts.ignore_case) {
+      KernelType kernel{CountSubstring(PlainSubstringMatcher(opts))};
+      return kernel.Exec(ctx, batch, out);
+    }
+    // Only exact (case-sensitive) counting is implemented so far.
+    return Status::NotImplemented("count_substring with ignore_case");
+  }
+};
+
+// User-facing documentation registered for "count_substring". Matches are
+// counted without overlap (CountSubstring::Call resumes searching after each
+// match), so the description states that explicitly: with pattern "aba",
+// the value "ababa" yields 1.
+const FunctionDoc count_substring_doc(
+    "Count occurrences of substring",
+    ("For each string in `strings`, emit the number of non-overlapping "
+     "occurrences of the given pattern.\n"
+     "Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
+    {"strings"}, "MatchSubstringOptions");
+
+// Registers the "count_substring" function with one kernel per base-binary
+// input type. The output type follows the input's offset width: int64 for
+// LargeBinary/LargeString, int32 for everything else.
+void AddCountSubstring(FunctionRegistry* registry) {
+  auto func = std::make_shared<ScalarFunction>("count_substring", Arity::Unary(),
+                                               &count_substring_doc);
+  for (const auto& ty : BaseBinaryTypes()) {
+    const auto type_id = ty->id();
+    const bool is_large =
+        type_id == Type::type::LARGE_BINARY || type_id == Type::type::LARGE_STRING;
+    const std::shared_ptr<DataType> out_type = is_large ? int64() : int32();
+    DCHECK_OK(func->AddKernel({ty}, out_type,
+                              GenerateTypeAgnosticVarBinaryBase<CountSubstringExec>(ty),
+                              MatchSubstringState::Init));
+  }
+  DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
// Slicing
template <typename Type, typename Derived>
@@ -3213,6 +3280,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
AddUtf8Length(registry);
AddMatchSubstring(registry);
AddFindSubstring(registry);
+ AddCountSubstring(registry);
MakeUnaryStringBatchKernelWithState<ReplaceSubStringPlain>(
"replace_substring", registry, &replace_substring_doc,
MemAllocation::NO_PREALLOCATE);
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index bd9dba2..9b4cef4 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -103,6 +103,25 @@ TYPED_TEST(TestBinaryKernels, FindSubstring) {
"[0, 0, null]", &options_empty);
}
+TYPED_TEST(TestBinaryKernels, CountSubstring) {
+  const auto out_type = this->offset_type();
+
+  // Pattern "aba": occurrences are counted without overlap ("ababa" -> 1,
+  // "abaaba" -> 2) and matching is case-sensitive ("ABA" -> 0).
+  MatchSubstringOptions aba_pattern{"aba"};
+  this->CheckUnary("count_substring", "[]", out_type, "[]", &aba_pattern);
+  this->CheckUnary(
+      "count_substring",
+      R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
+      out_type, "[0, null, 0, 1, 1, 1, 2, 2, 0]", &aba_pattern);
+
+  // Empty pattern: matches at every offset, including the end of the value.
+  MatchSubstringOptions empty_pattern{""};
+  this->CheckUnary("count_substring", R"(["", null, "abc"])", out_type,
+                   "[1, null, 4]", &empty_pattern);
+
+  // Repeated-character pattern, plus a non-ASCII value that must not match.
+  MatchSubstringOptions aaa_pattern{"aaa"};
+  this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaĆ”"])",
+                   out_type, "[0, 1, 1, 2, 0]", &aaa_pattern);
+
+  // TODO: case-insensitive matching (ignore_case) is not implemented yet.
+}
+
template <typename TestType>
class TestStringKernels : public BaseTestStringKernels<TestType> {};
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 02c8fb3..434d4a2 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -561,45 +561,51 @@ Containment tests
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
| Function name | Arity | Input types | Output type | Options class |
+===========================+============+====================================+====================+========================================+
-| find_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` |
+| count_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
-| match_like | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` |
+| find_substring | Unary | String-like | Int32 or Int64 (2) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
-| match_substring | Unary | String-like | Boolean (3) | :struct:`MatchSubstringOptions` |
+| match_like | Unary | String-like | Boolean (3) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
-| match_substring_regex | Unary | String-like | Boolean (4) | :struct:`MatchSubstringOptions` |
+| match_substring | Unary | String-like | Boolean (4) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
-| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (5) | :struct:`SetLookupOptions` |
+| match_substring_regex | Unary | String-like | Boolean (5) | :struct:`MatchSubstringOptions` |
++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
+| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (6) | :struct:`SetLookupOptions` |
| | | Binary- and String-like | | |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
-| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (6) | :struct:`SetLookupOptions` |
+| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (7) | :struct:`SetLookupOptions` |
| | | Binary- and String-like | | |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
+* \(1) Output is the number of non-overlapping occurrences of
+ :member:`MatchSubstringOptions::pattern` in the corresponding input
+ string. Output type is Int32 for Binary/String, Int64
+ for LargeBinary/LargeString.
-* \(1) Output is the index of the first occurrence of
+* \(2) Output is the index of the first occurrence of
:member:`MatchSubstringOptions::pattern` in the corresponding input
string, otherwise -1. Output type is Int32 for Binary/String, Int64
for LargeBinary/LargeString.
-* \(2) Output is true iff the SQL-style LIKE pattern
+* \(3) Output is true iff the SQL-style LIKE pattern
:member:`MatchSubstringOptions::pattern` fully matches the
corresponding input element. That is, ``%`` will match any number of
characters, ``_`` will match exactly one character, and any other
character matches itself. To match a literal percent sign or
underscore, precede the character with a backslash.
-* \(3) Output is true iff :member:`MatchSubstringOptions::pattern`
+* \(4) Output is true iff :member:`MatchSubstringOptions::pattern`
is a substring of the corresponding input element.
-* \(4) Output is true iff :member:`MatchSubstringOptions::pattern`
+* \(5) Output is true iff :member:`MatchSubstringOptions::pattern`
matches the corresponding input element at any position.
-* \(5) Output is the index of the corresponding input element in
+* \(6) Output is the index of the corresponding input element in
:member:`SetLookupOptions::value_set`, if found there. Otherwise,
output is null.
-* \(6) Output is true iff the corresponding input element is equal to one
+* \(7) Output is true iff the corresponding input element is equal to one
of the elements in :member:`SetLookupOptions::value_set`.
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index ccd5300..a586f90 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -178,6 +178,7 @@ Containment tests
.. autosummary::
:toctree: ../generated/
+ count_substring
find_substring
index_in
is_in
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index eb66f44..8dc7181 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -291,6 +291,25 @@ def cast(arr, target_type, safe=True):
return call_function("cast", [arr], options)
+def count_substring(array, pattern):
+ """
+ Count the occurrences of substring *pattern* in each value of a
+ string array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ pattern to search for exact matches
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+ """
+ return call_function("count_substring", [array],
+ MatchSubstringOptions(pattern))
+
+
def find_substring(array, pattern):
"""
Find the index of the first occurrence of substring *pattern* in each
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index a78be20..64d5ad0 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -285,6 +285,19 @@ def test_variance():
assert pc.variance(data, ddof=1).as_py() == 6.0
+def test_count_substring():
+ arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None])
+ result = pc.count_substring(arr, "ab")
+ expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int32())
+ assert expected.equals(result)
+
+ arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None],
+ type=pa.large_string())
+ result = pc.count_substring(arr, "ab")
+ expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int64())
+ assert expected.equals(result)
+
+
def test_find_substring():
arr = pa.array(["ab", "cab", "ba", None])
result = pc.find_substring(arr, "ab")