You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/04/19 16:26:41 UTC

[arrow-rs] branch master updated: Add utf-8 validation checking for `substring` (#1577)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 43f4e16e5 Add utf-8 validation checking for `substring` (#1577)
43f4e16e5 is described below

commit 43f4e16e5f9543e1bcf0c625b322785320993fc9
Author: Remzi Yang <59...@users.noreply.github.com>
AuthorDate: Wed Apr 20 00:26:34 2022 +0800

    Add utf-8 validation checking for `substring` (#1577)
    
    * add UTF-8 validation checking
    update doc
    add a test for an invalid array type
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * add tests
    clean up
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * test the worst case
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * update doc and tests
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * update doc
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * use the std method `str::is_char_boundary`
    update doc
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * add two substring benchmarks
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * replace boxed `dyn Fn` closures with inline `match` expressions inside the loop (relying on loop unswitching)
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * Update arrow/src/compute/kernels/substring.rs
    
    Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
    
    Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
---
 arrow/benches/string_kernels.rs        |  13 ++--
 arrow/src/compute/kernels/substring.rs | 107 +++++++++++++++++++++------------
 2 files changed, 77 insertions(+), 43 deletions(-)

diff --git a/arrow/benches/string_kernels.rs b/arrow/benches/string_kernels.rs
index c91801b15..37d1f3f89 100644
--- a/arrow/benches/string_kernels.rs
+++ b/arrow/benches/string_kernels.rs
@@ -25,8 +25,8 @@ use arrow::array::*;
 use arrow::compute::kernels::substring::substring;
 use arrow::util::bench_util::*;
 
-fn bench_substring(arr: &StringArray, start: i64, length: usize) {
-    substring(criterion::black_box(arr), start, Some(length as u64)).unwrap();
+fn bench_substring(arr: &StringArray, start: i64, length: Option<u64>) {
+    substring(criterion::black_box(arr), start, length).unwrap();
 }
 
 fn add_benchmark(c: &mut Criterion) {
@@ -34,10 +34,13 @@ fn add_benchmark(c: &mut Criterion) {
     let str_len = 1000;
 
     let arr_string = create_string_array_with_len::<i32>(size, 0.0, str_len);
-    let start = 0;
 
-    c.bench_function("substring", |b| {
-        b.iter(|| bench_substring(&arr_string, start, str_len))
+    c.bench_function("substring (start = 0, length = None)", |b| {
+        b.iter(|| bench_substring(&arr_string, 0, None))
+    });
+
+    c.bench_function("substring (start = 1, length = str_len - 1)", |b| {
+        b.iter(|| bench_substring(&arr_string, 1, Some((str_len - 1) as u64)))
     });
 }
 
diff --git a/arrow/src/compute/kernels/substring.rs b/arrow/src/compute/kernels/substring.rs
index 647491c72..df05a73c0 100644
--- a/arrow/src/compute/kernels/substring.rs
+++ b/arrow/src/compute/kernels/substring.rs
@@ -23,6 +23,7 @@ use crate::{
     datatypes::DataType,
     error::{ArrowError, Result},
 };
+use std::cmp::Ordering;
 
 fn generic_substring<OffsetSize: StringOffsetSizeTrait>(
     array: &GenericStringArray<OffsetSize>,
@@ -35,37 +36,46 @@ fn generic_substring<OffsetSize: StringOffsetSizeTrait>(
     let data = values.as_slice();
     let zero = OffsetSize::zero();
 
-    let cal_new_start: Box<dyn Fn(OffsetSize, OffsetSize) -> OffsetSize> = if start
-        >= zero
-    {
-        // count from the start of string
-        Box::new(|old_start: OffsetSize, end: OffsetSize| (old_start + start).min(end))
-    } else {
-        // count from the end of string
-        Box::new(|old_start: OffsetSize, end: OffsetSize| (end + start).max(old_start))
+    // Check if `offset` is at a valid char boundary.
+    // If yes, return `offset`, else return error
+    let check_char_boundary = {
+        // Safety: a StringArray must contain valid UTF8 data
+        let data_str = unsafe { std::str::from_utf8_unchecked(data) };
+        |offset: OffsetSize| {
+            let offset_usize = offset.to_usize().unwrap();
+            if data_str.is_char_boundary(offset_usize) {
+                Ok(offset)
+            } else {
+                Err(ArrowError::ComputeError(format!(
+                    "The offset {} is at an invalid utf-8 boundary.",
+                    offset_usize
+                )))
+            }
+        }
     };
 
-    let cal_new_length: Box<dyn Fn(OffsetSize, OffsetSize) -> OffsetSize> =
-        if let Some(length) = length {
-            Box::new(move |start: OffsetSize, end: OffsetSize| length.min(end - start))
-        } else {
-            Box::new(move |start: OffsetSize, end: OffsetSize| end - start)
-        };
-
-    // start and end offsets for each substring
+    // start and end offsets of all substrings
     let mut new_starts_ends: Vec<(OffsetSize, OffsetSize)> =
         Vec::with_capacity(array.len());
     let mut new_offsets: Vec<OffsetSize> = Vec::with_capacity(array.len() + 1);
     let mut len_so_far = zero;
     new_offsets.push(zero);
 
-    offsets.windows(2).for_each(|pair| {
-        let new_start = cal_new_start(pair[0], pair[1]);
-        let new_length = cal_new_length(new_start, pair[1]);
-        len_so_far += new_length;
-        new_starts_ends.push((new_start, new_start + new_length));
+    offsets.windows(2).try_for_each(|pair| -> Result<()> {
+        let new_start = match start.cmp(&zero) {
+            Ordering::Greater => check_char_boundary((pair[0] + start).min(pair[1]))?,
+            Ordering::Equal => pair[0],
+            Ordering::Less => check_char_boundary((pair[1] + start).max(pair[0]))?,
+        };
+        let new_end = match length {
+            Some(length) => check_char_boundary((length + new_start).min(pair[1]))?,
+            None => pair[1],
+        };
+        len_so_far += new_end - new_start;
+        new_starts_ends.push((new_start, new_end));
         new_offsets.push(len_so_far);
-    });
+        Ok(())
+    })?;
 
     // concatenate substrings into a buffer
     let mut new_values =
@@ -107,29 +117,28 @@ fn generic_substring<OffsetSize: StringOffsetSizeTrait>(
 ///
 /// Attention: Both `start` and `length` are counted by byte, not by char.
 ///
-/// # Warning
-///
-/// This function **might** return in invalid utf-8 format if the
-/// character length falls on a non-utf8 boundary, which we
-/// [hope to fix](https://github.com/apache/arrow-rs/issues/1531)
-/// in a future release.
-///
-/// ## Example of getting an invalid substring
+/// # Basic usage
 /// ```
-/// # // Doesn't pass due to  https://github.com/apache/arrow-rs/issues/1531
-/// # #[cfg(not(feature = "force_validate"))]
-/// # {
 /// # use arrow::array::StringArray;
 /// # use arrow::compute::kernels::substring::substring;
-/// let array = StringArray::from(vec![Some("E=mc²")]);
-/// let result = substring(&array, -1, None).unwrap();
+/// let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]);
+/// let result = substring(&array, 1, Some(4)).unwrap();
 /// let result = result.as_any().downcast_ref::<StringArray>().unwrap();
-/// assert_eq!(result.value(0).as_bytes(), &[0x00B2]); // invalid utf-8 format
-/// # }
+/// assert_eq!(result, &StringArray::from(vec![Some("rrow"), None, Some("ust")]));
 /// ```
 ///
 /// # Error
-/// this function errors when the passed array is not a \[Large\]String array.
+/// - The function errors when the passed array is not a \[Large\]String array.
+/// - The function errors if the offset of a substring in the input array is at invalid char boundary.
+///
+/// ## Example of trying to get an invalid utf-8 format substring
+/// ```
+/// # use arrow::array::StringArray;
+/// # use arrow::compute::kernels::substring::substring;
+/// let array = StringArray::from(vec![Some("E=mc²")]);
+/// let error = substring(&array, 0, Some(5)).unwrap_err().to_string();
+/// assert!(error.contains("invalid utf-8 boundary"));
+/// ```
 pub fn substring(array: &dyn Array, start: i64, length: Option<u64>) -> Result<ArrayRef> {
     match array.data_type() {
         DataType::LargeUtf8 => generic_substring(
@@ -304,4 +313,26 @@ mod tests {
     fn without_nulls_large_string() -> Result<()> {
         without_nulls::<LargeStringArray>()
     }
+
+    #[test]
+    fn check_invalid_array_type() {
+        let array = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
+        let err = substring(&array, 0, None).unwrap_err().to_string();
+        assert!(err.contains("substring does not support type"));
+    }
+
+    // tests for the utf-8 validation checking
+    #[test]
+    fn check_start_index() {
+        let array = StringArray::from(vec![Some("E=mc²"), Some("ascii")]);
+        let err = substring(&array, -1, None).unwrap_err().to_string();
+        assert!(err.contains("invalid utf-8 boundary"));
+    }
+
+    #[test]
+    fn check_length() {
+        let array = StringArray::from(vec![Some("E=mc²"), Some("ascii")]);
+        let err = substring(&array, 0, Some(5)).unwrap_err().to_string();
+        assert!(err.contains("invalid utf-8 boundary"));
+    }
 }