You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/10/07 12:01:39 UTC

[GitHub] [arrow] jorisvandenbossche commented on a change in pull request #8271: ARROW-9991: [C++] split kernels for strings/binary

jorisvandenbossche commented on a change in pull request #8271:
URL: https://github.com/apache/arrow/pull/8271#discussion_r500954601



##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -809,6 +809,475 @@ struct IsUpperAscii : CharacterPredicateAscii<IsUpperAscii> {
   }
 };
 
+// splitting
+
+/// CRTP base for the kernels that split string/binary values into lists of strings.
+///
+/// The separator search and the output sizing are delegated to `Derived`, which must
+/// provide:
+///  - `static bool Find(begin, end, &separator_begin, &separator_end, options)`,
+///    used when cutting front to back;
+///  - `static bool FindReverse(begin, end, &separator_begin, &separator_end, options)`,
+///    used when `options.reverse` is set;
+///  - `static int64_t CalculateMaxSplits(input, options)` for both `ArrayType` and
+///    `ScalarType` inputs, returning an upper bound on the number of output strings;
+///  - optionally `static Status CheckOptions(options)` (the default accepts anything).
+/// Separators must be non-empty: an empty separator does not advance the split loop.
+template <typename Type, typename ListType, typename Options, typename Derived>
+struct SplitBaseTransform {
+  // TODO: assert offsets types are the same?
+  using offset_type = typename Type::offset_type;
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using ArrayListType = typename TypeTraits<ListType>::ArrayType;
+  using ListScalarType = typename TypeTraits<ListType>::ScalarType;
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  using Builder = typename TypeTraits<Type>::BuilderType;
+  using State = OptionsWrapper<Options>;
+
+  /// Split one input value, appending its pieces to the child string buffers.
+  ///
+  /// \param[in] input_string first byte of the value to split
+  /// \param[in] input_string_nbytes length of the value in bytes
+  /// \param[in,out] output_string_offsets write cursor into the child offsets buffer;
+  ///     one end offset is written (and the cursor advanced) per emitted piece
+  /// \param[in,out] string_output_count running count of emitted pieces
+  /// \param[in,out] string_output_offset running count of bytes written to the data
+  /// \param[in,out] output_string_data write cursor into the child data buffer
+  /// \param[in] options a negative `max_splits` means "no limit"; with `reverse` the
+  ///     (at most `max_splits`) cuts are taken from the end of the value, but the
+  ///     pieces are still emitted in front-to-back order
+  static void Split(const uint8_t* input_string, offset_type input_string_nbytes,
+                    offset_type** output_string_offsets, offset_type* string_output_count,
+                    offset_type* string_output_offset, uint8_t** output_string_data,
+                    const Options& options) {
+    const uint8_t* begin = input_string;
+    const uint8_t* end = begin + input_string_nbytes;
+
+    int64_t max_splits = options.max_splits;
+    // if there is no max splits, reversing does not make sense (and is probably less
+    // efficient), but is useful for testing
+    if (options.reverse) {
+      // note that i points 1 further than the 'current'
+      const uint8_t* i = end;
+      // we will record the parts in reverse order
+      std::vector<std::pair<const uint8_t*, const uint8_t*>> parts;
+      if (max_splits > -1) {
+        // Every (non-empty) separator consumes at least one byte, so there can never
+        // be more than `input_string_nbytes` cuts. Capping keeps a very large
+        // user-supplied max_splits from triggering a huge (or failing) reservation.
+        parts.reserve(std::min<int64_t>(max_splits, input_string_nbytes) + 1);
+      }
+      while (max_splits != 0) {
+        const uint8_t *separator_begin, *separator_end;
+        // find with whatever algo the part we will 'cut out'
+        if (Derived::FindReverse(begin, i, &separator_begin, &separator_end, options)) {
+          parts.emplace_back(separator_end, i);
+          i = separator_begin;
+          max_splits--;
+        } else {
+          // if we cannot find a separator, we're done
+          break;
+        }
+      }
+      parts.emplace_back(begin, i);
+      // now we do the copying, restoring front-to-back order
+      for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
+        auto part = *it;
+        // copy the string data
+        for (auto j = part.first; j < part.second; j++) {
+          *(*output_string_data)++ = *j;
+          (*string_output_offset)++;
+        }
+        // write out the string entry (offset)
+        *(*output_string_offsets)++ = *string_output_offset;
+        (*string_output_count)++;
+      }
+    } else {
+      const uint8_t* i = begin;
+      while (max_splits != 0) {
+        const uint8_t *separator_begin, *separator_end;
+        // find with whatever algo the part we will 'cut out'
+        if (Derived::Find(i, end, &separator_begin, &separator_end, options)) {
+          // copy the part till the beginning of the 'cut'
+          while (i < separator_begin) {
+            *(*output_string_data)++ = *i++;
+            (*string_output_offset)++;
+          }
+          // 'finish' the string by writing the offset
+          *(*output_string_offsets)++ = *string_output_offset;
+          (*string_output_count)++;
+          // jump over the part we cut out
+          i = separator_end;
+          max_splits--;
+        } else {
+          // if we cannot find a separator, we're done
+          break;
+        }
+      }
+      // copy bytes after the pattern
+      while (i < end) {
+        *(*output_string_data)++ = *i++;
+        (*string_output_offset)++;
+      }
+      // and write out the trailing part (can be an empty string)
+      *(*output_string_offsets)++ = *string_output_offset;
+      (*string_output_count)++;
+    }
+  }
+
+  /// Default option validation: accept everything. `Derived` may shadow this.
+  static Status CheckOptions(const Options& options) { return Status::OK(); }
+
+  /// Kernel entry point: split every value of the input array (or the input scalar)
+  /// and store the resulting list array / list scalar in `*out`.
+  ///
+  /// Sets a CapacityError on `ctx` when the worst-case number of output strings does
+  /// not fit in `offset_type`.
+  /// NOTE(review): null input slots are currently split like valid ones and the child
+  /// string array is emitted without a validity bitmap (see the commented-out
+  /// null_builder code below).
+  static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    EnsureLookupTablesFilled();  // only needed for unicode
+    // Bind by const reference: no need to copy the options for every batch.
+    const Options& options = State::Get(ctx);
+    KERNEL_RETURN_IF_ERROR(ctx, Derived::CheckOptions(options));
+
+    if (batch[0].kind() == Datum::ARRAY) {
+      const ArrayData& input = *batch[0].array();
+      ArrayType input_boxed(batch[0].array());
+      ArrayData* output_list_data = out->mutable_array();
+
+      offset_type input_nbytes = input_boxed.total_values_length();
+      offset_type input_nstrings = static_cast<offset_type>(input.length);
+
+      // Splitting only drops separator bytes, so the output never needs more data
+      // bytes than the input.
+      offset_type output_nbytes_max = input_nbytes;
+      int64_t output_nstrings_max = Derived::CalculateMaxSplits(input_boxed, options);
+      if (output_nstrings_max > std::numeric_limits<offset_type>::max()) {
+        ctx->SetStatus(
+            Status::CapacityError("Result might not fit in a 32bit list or string array, "
+                                  "convert to large_utf8"));
+        return;
+      }
+
+      // NOTE(review): the list offsets buffer is allocated here rather than by the
+      // kernel framework -- confirm whether preallocation could be used instead.
+      // The size is computed from the 64-bit length so `+ 1` cannot overflow.
+      KERNEL_ASSIGN_OR_RAISE(output_list_data->buffers[1], ctx,
+                             ctx->Allocate((input.length + 1) * sizeof(offset_type)));
+      offset_type* output_list_offsets =
+          output_list_data->GetMutableValues<offset_type>(1);
+
+      // allocate output string array data
+      // this is a bit low level, should we use a builder?
+      KERNEL_ASSIGN_OR_RAISE(auto buffer_output_string_data, ctx,
+                             ctx->Allocate(output_nbytes_max));
+      // An offsets buffer for N strings holds N + 1 entries (a leading 0 plus one end
+      // offset per string), so room for `output_nstrings_max + 1` offsets is needed.
+      KERNEL_ASSIGN_OR_RAISE(
+          auto buffer_output_string_offsets, ctx,
+          ctx->Allocate((output_nstrings_max + 1) * sizeof(offset_type)));
+
+      offset_type* output_string_offsets =
+          reinterpret_cast<offset_type*>(buffer_output_string_offsets->mutable_data());
+      uint8_t* output_string_data = buffer_output_string_data->mutable_data();
+      // TypedBufferBuilder<bool> null_builder;  // TODO: should we pass a mem pool?
+      // KERNEL_RETURN_IF_ERROR(ctx, null_builder.Reserve(output_nstrings_max));
+
+      // FirstTimeBitmapWriter bitmap_writer(buffer_output_string_bitmap->mutable_data(),
+      //                                     /*offset=*/0, output_nstrings_max);
+
+      // the output offset goes slightly slower than the input (due to skipping pattern)
+      offset_type output_string_offset = 0;
+      offset_type output_string_count = 0;
+
+      // we always start at the beginning
+      *output_list_offsets++ = output_string_count;
+      *output_string_offsets++ = output_string_offset;
+
+      for (int64_t i = 0; i < input_nstrings; i++) {
+        offset_type input_string_nbytes;
+        const uint8_t* input_string = input_boxed.GetValue(i, &input_string_nbytes);
+        // if (input_boxed.IsValid(i)) {
+        Split(input_string, input_string_nbytes, &output_string_offsets,
+              &output_string_count, &output_string_offset, &output_string_data, options);
+        //   null_builder.UnsafeAppend(true);
+        // } else {
+        //   null_builder.UnsafeAppend(false);
+        // }
+
+        *output_list_offsets++ = output_string_count;
+      }
+      // bitmap_writer.Finish();
+      // trim off extra memory usage
+      KERNEL_RETURN_IF_ERROR(ctx,
+                             buffer_output_string_data->Resize(output_string_offsets[-1],
+                                                               /*shrink_to_fit=*/true));
+      KERNEL_RETURN_IF_ERROR(ctx, buffer_output_string_offsets->Resize(
+                                      (output_string_count + 1) * sizeof(offset_type),
+                                      /*shrink_to_fit=*/true));
+      // TODO: how to truncate the bitmap?
+      auto output_string_null_count = 0;  // null_builder.false_count();
+      std::shared_ptr<Buffer> buffer_output_string_bitmap = nullptr;
+      // KERNEL_RETURN_IF_ERROR(ctx, null_builder.Finish(&buffer_output_string_bitmap));
+
+      std::shared_ptr<ArrayData> output_string_array =
+          ArrayData::Make(input.type, output_string_count,
+                          {buffer_output_string_bitmap, buffer_output_string_offsets,
+                           buffer_output_string_data},
+                          output_string_null_count);
+      output_list_data->child_data.push_back(output_string_array);
+
+    } else {
+      const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
+      auto result = checked_pointer_cast<ListScalarType>(MakeNullScalar(out->type()));
+      if (input.is_valid) {
+        result->is_valid = true;
+        offset_type input_nbytes = static_cast<offset_type>(input.value->size());
+
+        // Splitting only drops separator bytes, so the input size is an upper bound.
+        offset_type output_nbytes_max = input_nbytes;
+
+        int64_t output_nstrings_max = Derived::CalculateMaxSplits(input, options);
+        if (output_nstrings_max > std::numeric_limits<offset_type>::max()) {
+          ctx->SetStatus(Status::CapacityError(
+              "Result might not fit in a 32bit list or string array, "
+              "convert to large_utf8"));
+          return;
+        }
+
+        KERNEL_ASSIGN_OR_RAISE(auto buffer_output_string_data, ctx,
+                               ctx->Allocate(output_nbytes_max));
+        // N strings need N + 1 offsets (leading 0 plus one end offset per string).
+        KERNEL_ASSIGN_OR_RAISE(
+            auto buffer_output_string_offsets, ctx,
+            ctx->Allocate((output_nstrings_max + 1) * sizeof(offset_type)));
+        offset_type* output_string_offsets =
+            reinterpret_cast<offset_type*>(buffer_output_string_offsets->mutable_data());
+        uint8_t* output_string_data = buffer_output_string_data->mutable_data();
+
+        // the output offset goes slightly slower than the input (due to skipping pattern)
+        offset_type string_output_offset = 0;
+        offset_type string_output_count = 0;
+
+        // we always start at the beginning
+        *output_string_offsets++ = string_output_offset;
+
+        const uint8_t* input_string = input.value->data();
+        Split(input_string, input_nbytes, &output_string_offsets, &string_output_count,
+              &string_output_offset, &output_string_data, options);
+
+        // trim off extra memory usage
+        KERNEL_RETURN_IF_ERROR(
+            ctx, buffer_output_string_data->Resize(output_string_offsets[-1],
+                                                   /*shrink_to_fit=*/true));
+        KERNEL_RETURN_IF_ERROR(ctx, buffer_output_string_offsets->Resize(
+                                        (string_output_count + 1) * sizeof(offset_type),
+                                        /*shrink_to_fit=*/true));
+
+        std::shared_ptr<ArrayData> output_string_array = ArrayData::Make(
+            input.type, string_output_count,
+            {nullptr, buffer_output_string_offsets, buffer_output_string_data});
+        result->value = std::make_shared<ArrayType>(output_string_array);
+      }
+      out->value = result;
+    }
+  }
+};
+
+/// Split kernel that cuts values at every occurrence of a literal byte pattern
+/// (`SplitPatternOptions::pattern`), which must be non-empty.
+template <typename Type, typename ListType>
+struct SplitPatternTransform : SplitBaseTransform<Type, ListType, SplitPatternOptions,
+                                                  SplitPatternTransform<Type, ListType>> {
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  using offset_type = typename Type::offset_type;
+
+  /// Reject an empty pattern: it would never advance the split loop and would make
+  /// CalculateMaxSplits divide by zero.
+  static Status CheckOptions(const SplitPatternOptions& options) {
+    if (options.pattern.length() == 0) {
+      return Status::Invalid("Empty separator");
+    }
+    return Status::OK();
+  }
+
+  /// Find the first occurrence of the pattern in [begin, end).
+  ///
+  /// On success stores the occurrence bounds in `*separator_begin` /
+  /// `*separator_end` and returns true; returns false when there is none.
+  static bool Find(const uint8_t* begin, const uint8_t* end,
+                   const uint8_t** separator_begin, const uint8_t** separator_end,
+                   const SplitPatternOptions& options) {
+    const uint8_t* pattern = reinterpret_cast<const uint8_t*>(options.pattern.c_str());
+    const int64_t pattern_length = options.pattern.length();
+    // Compare lengths instead of computing `begin + pattern_length`: that pointer may
+    // lie past the end of the buffer, which is undefined behavior.
+    if (end - begin < pattern_length) {
+      return false;
+    }
+    // this is O(n*m) complexity, we could use the Knuth-Morris-Pratt algorithm used in
+    // the match kernel
+    const uint8_t* match = std::search(begin, end, pattern, pattern + pattern_length);
+    if (match == end) {
+      return false;
+    }
+    *separator_begin = match;
+    *separator_end = match + pattern_length;
+    return true;
+  }
+
+  /// Find the last occurrence of the pattern in [begin, end).
+  ///
+  /// On success stores the occurrence bounds in `*separator_begin` /
+  /// `*separator_end` and returns true; returns false when there is none.
+  static bool FindReverse(const uint8_t* begin, const uint8_t* end,
+                          const uint8_t** separator_begin, const uint8_t** separator_end,
+                          const SplitPatternOptions& options) {
+    const uint8_t* pattern = reinterpret_cast<const uint8_t*>(options.pattern.c_str());
+    const int64_t pattern_length = options.pattern.length();
+    // Same length guard as Find: avoids forming a pointer before `begin`.
+    if (end - begin < pattern_length) {
+      return false;
+    }
+    // this is O(n*m) complexity, we could use the Knuth-Morris-Pratt algorithm used in
+    // the match kernel.
+    // Searching the reversed pattern in the reversed input yields the last match.
+    std::reverse_iterator<const uint8_t*> rbegin(end);
+    std::reverse_iterator<const uint8_t*> rend(begin);
+    std::reverse_iterator<const uint8_t*> pattern_rbegin(pattern + pattern_length);
+    std::reverse_iterator<const uint8_t*> pattern_rend(pattern);
+    auto match = std::search(rbegin, rend, pattern_rbegin, pattern_rend);
+    if (match == rend) {
+      return false;
+    }
+    // `match.base()` points one past the last byte of the occurrence.
+    *separator_begin = match.base() - pattern_length;
+    *separator_end = match.base();
+    return true;
+  }
+
+  /// Upper bound on the number of output strings for an array input.
+  static int64_t CalculateMaxSplits(const ArrayType& input,
+                                    const SplitPatternOptions& options) {
+    // Worst case is e.g. ['  ', ' '] split by ' ' -> [['', '', ''], ['', '']]
+    // i.e. the length of each string divided by the pattern length + 1
+    // This can double the amount of strings, thus not fit into a (32bit) list or string
+    // anymore
+    const int64_t pattern_length = static_cast<int64_t>(options.pattern.length());
+    int64_t output_nstrings_max = 0;
+    // 64-bit index: the array length is 64-bit even when offset_type is 32-bit.
+    for (int64_t i = 0; i < input.length(); ++i) {
+      output_nstrings_max += 1 + input.value_length(i) / pattern_length;
+    }
+    return output_nstrings_max;
+  }
+
+  /// Upper bound on the number of output strings for a (valid) scalar input.
+  static int64_t CalculateMaxSplits(const ScalarType& input,
+                                    const SplitPatternOptions& options) {
+    // Worst case is e.g. '  ' split by ' ' -> ['', '', '']
+    // i.e. the length of the string divided by the pattern length + 1
+    const int64_t pattern_length = static_cast<int64_t>(options.pattern.length());
+    return 1 + static_cast<int64_t>(input.value->size()) / pattern_length;
+  }
+};
+
+void AddSplitPattern(FunctionRegistry* registry) {
+  auto func = std::make_shared<ScalarFunction>("split_pattern", Arity::Unary());

Review comment:
       Then `binary_split_pattern` seems "most correct"? Although I personally don't really like that this naming scheme "hides" useful string functionality behind the `binary_` prefix (the same is already true of `binary_length`, though)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org