You are viewing a plain-text version of this content. The canonical link to it was not preserved in this copy.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/06/26 11:35:30 UTC

[arrow-rs] branch master updated: unify substring for binary&utf8 (#4442)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e65b5803 unify substring for binary&utf8 (#4442)
8e65b5803 is described below

commit 8e65b5803dd6c457e18b24c13dac9a13bcc4d4cf
Author: jakevin <ja...@gmail.com>
AuthorDate: Mon Jun 26 19:35:24 2023 +0800

    unify substring for binary&utf8 (#4442)
---
 arrow-string/Cargo.toml       |   1 +
 arrow-string/src/substring.rs | 171 ++++++++++++++++--------------------------
 2 files changed, 67 insertions(+), 105 deletions(-)

diff --git a/arrow-string/Cargo.toml b/arrow-string/Cargo.toml
index 6e16e0163..0f88ffbac 100644
--- a/arrow-string/Cargo.toml
+++ b/arrow-string/Cargo.toml
@@ -41,6 +41,7 @@ arrow-array = { workspace = true }
 arrow-select = { workspace = true }
 regex = { version = "1.7.0", default-features = false, features = ["std", "unicode", "perf"] }
 regex-syntax = { version = "0.7.1", default-features = false, features = ["unicode"] }
+num = { version = "0.4", default-features = false, features = ["std"] }
 
 [package.metadata.docs.rs]
 all-features = true
diff --git a/arrow-string/src/substring.rs b/arrow-string/src/substring.rs
index a8250c75d..1075d1069 100644
--- a/arrow-string/src/substring.rs
+++ b/arrow-string/src/substring.rs
@@ -25,6 +25,7 @@ use arrow_array::*;
 use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
 use arrow_data::ArrayData;
 use arrow_schema::{ArrowError, DataType};
+use num::Zero;
 use std::cmp::Ordering;
 use std::sync::Arc;
 
@@ -106,7 +107,7 @@ pub fn substring(
                 UInt64: UInt64Type
             )
         }
-        DataType::LargeBinary => binary_substring(
+        DataType::LargeBinary => byte_substring(
             array
                 .as_any()
                 .downcast_ref::<LargeBinaryArray>()
@@ -114,7 +115,7 @@ pub fn substring(
             start,
             length.map(|e| e as i64),
         ),
-        DataType::Binary => binary_substring(
+        DataType::Binary => byte_substring(
             array
                 .as_any()
                 .downcast_ref::<BinaryArray>()
@@ -131,7 +132,7 @@ pub fn substring(
             start as i32,
             length.map(|e| e as i32),
         ),
-        DataType::LargeUtf8 => utf8_substring(
+        DataType::LargeUtf8 => byte_substring(
             array
                 .as_any()
                 .downcast_ref::<LargeStringArray>()
@@ -139,7 +140,7 @@ pub fn substring(
             start,
             length.map(|e| e as i64),
         ),
-        DataType::Utf8 => utf8_substring(
+        DataType::Utf8 => byte_substring(
             array
                 .as_any()
                 .downcast_ref::<StringArray>()
@@ -246,36 +247,61 @@ fn get_start_end_offset(
     (start_offset, end_offset)
 }
 
-fn binary_substring<OffsetSize: OffsetSizeTrait>(
-    array: &GenericBinaryArray<OffsetSize>,
-    start: OffsetSize,
-    length: Option<OffsetSize>,
-) -> Result<ArrayRef, ArrowError> {
+fn byte_substring<T: ByteArrayType>(
+    array: &GenericByteArray<T>,
+    start: T::Offset,
+    length: Option<T::Offset>,
+) -> Result<ArrayRef, ArrowError>
+where
+    <T as ByteArrayType>::Native: PartialEq,
+{
     let offsets = array.value_offsets();
     let data = array.value_data();
-    let zero = OffsetSize::zero();
+    let zero = <T::Offset as Zero>::zero();
+
+    // When array is [Large]StringArray, we will check whether `offset` is at a valid char boundary.
+    let check_char_boundary = {
+        |offset: T::Offset| {
+            if !matches!(T::DATA_TYPE, DataType::Utf8 | DataType::LargeUtf8) {
+                return Ok(offset);
+            }
+            // Safety: a StringArray must contain valid UTF8 data
+            let data_str = unsafe { std::str::from_utf8_unchecked(data) };
+            let offset_usize = offset.as_usize();
+            if data_str.is_char_boundary(offset_usize) {
+                Ok(offset)
+            } else {
+                Err(ArrowError::ComputeError(format!(
+                    "The offset {offset_usize} is at an invalid utf-8 boundary."
+                )))
+            }
+        }
+    };
 
     // start and end offsets of all substrings
-    let mut new_starts_ends: Vec<(OffsetSize, OffsetSize)> =
+    let mut new_starts_ends: Vec<(T::Offset, T::Offset)> =
         Vec::with_capacity(array.len());
-    let mut new_offsets: Vec<OffsetSize> = Vec::with_capacity(array.len() + 1);
+    let mut new_offsets: Vec<T::Offset> = Vec::with_capacity(array.len() + 1);
     let mut len_so_far = zero;
     new_offsets.push(zero);
 
-    offsets.windows(2).for_each(|pair| {
-        let new_start = match start.cmp(&zero) {
-            Ordering::Greater => (pair[0] + start).min(pair[1]),
-            Ordering::Equal => pair[0],
-            Ordering::Less => (pair[1] + start).max(pair[0]),
-        };
-        let new_end = match length {
-            Some(length) => (length + new_start).min(pair[1]),
-            None => pair[1],
-        };
-        len_so_far += new_end - new_start;
-        new_starts_ends.push((new_start, new_end));
-        new_offsets.push(len_so_far);
-    });
+    offsets
+        .windows(2)
+        .try_for_each(|pair| -> Result<(), ArrowError> {
+            let new_start = match start.cmp(&zero) {
+                Ordering::Greater => check_char_boundary((pair[0] + start).min(pair[1]))?,
+                Ordering::Equal => pair[0],
+                Ordering::Less => check_char_boundary((pair[1] + start).max(pair[0]))?,
+            };
+            let new_end = match length {
+                Some(length) => check_char_boundary((length + new_start).min(pair[1]))?,
+                None => pair[1],
+            };
+            len_so_far += new_end - new_start;
+            new_starts_ends.push((new_start, new_end));
+            new_offsets.push(len_so_far);
+            Ok(())
+        })?;
 
     // concatenate substrings into a buffer
     let mut new_values = MutableBuffer::new(new_offsets.last().unwrap().as_usize());
@@ -291,7 +317,7 @@ fn binary_substring<OffsetSize: OffsetSizeTrait>(
 
     let data = unsafe {
         ArrayData::new_unchecked(
-            GenericBinaryArray::<OffsetSize>::DATA_TYPE,
+            GenericByteArray::<T>::DATA_TYPE,
             array.len(),
             None,
             array.nulls().map(|b| b.inner().sliced()),
@@ -349,84 +375,6 @@ fn fixed_size_binary_substring(
     Ok(make_array(array_data))
 }
 
-/// substring by byte
-fn utf8_substring<OffsetSize: OffsetSizeTrait>(
-    array: &GenericStringArray<OffsetSize>,
-    start: OffsetSize,
-    length: Option<OffsetSize>,
-) -> Result<ArrayRef, ArrowError> {
-    let offsets = array.value_offsets();
-    let data = array.value_data();
-    let zero = OffsetSize::zero();
-
-    // Check if `offset` is at a valid char boundary.
-    // If yes, return `offset`, else return error
-    let check_char_boundary = {
-        // Safety: a StringArray must contain valid UTF8 data
-        let data_str = unsafe { std::str::from_utf8_unchecked(data) };
-        |offset: OffsetSize| {
-            let offset_usize = offset.as_usize();
-            if data_str.is_char_boundary(offset_usize) {
-                Ok(offset)
-            } else {
-                Err(ArrowError::ComputeError(format!(
-                    "The offset {offset_usize} is at an invalid utf-8 boundary."
-                )))
-            }
-        }
-    };
-
-    // start and end offsets of all substrings
-    let mut new_starts_ends: Vec<(OffsetSize, OffsetSize)> =
-        Vec::with_capacity(array.len());
-    let mut new_offsets: Vec<OffsetSize> = Vec::with_capacity(array.len() + 1);
-    let mut len_so_far = zero;
-    new_offsets.push(zero);
-
-    offsets
-        .windows(2)
-        .try_for_each(|pair| -> Result<(), ArrowError> {
-            let new_start = match start.cmp(&zero) {
-                Ordering::Greater => check_char_boundary((pair[0] + start).min(pair[1]))?,
-                Ordering::Equal => pair[0],
-                Ordering::Less => check_char_boundary((pair[1] + start).max(pair[0]))?,
-            };
-            let new_end = match length {
-                Some(length) => check_char_boundary((length + new_start).min(pair[1]))?,
-                None => pair[1],
-            };
-            len_so_far += new_end - new_start;
-            new_starts_ends.push((new_start, new_end));
-            new_offsets.push(len_so_far);
-            Ok(())
-        })?;
-
-    // concatenate substrings into a buffer
-    let mut new_values = MutableBuffer::new(new_offsets.last().unwrap().as_usize());
-
-    new_starts_ends
-        .iter()
-        .map(|(start, end)| {
-            let start = start.as_usize();
-            let end = end.as_usize();
-            &data[start..end]
-        })
-        .for_each(|slice| new_values.extend_from_slice(slice));
-
-    let data = unsafe {
-        ArrayData::new_unchecked(
-            GenericStringArray::<OffsetSize>::DATA_TYPE,
-            array.len(),
-            None,
-            array.nulls().map(|b| b.inner().sliced()),
-            0,
-            vec![Buffer::from_vec(new_offsets), new_values.into()],
-            vec![],
-        )
-    };
-    Ok(make_array(data))
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1020,4 +968,17 @@ mod tests {
         let err = substring(&array, 0, Some(5)).unwrap_err().to_string();
         assert!(err.contains("invalid utf-8 boundary"));
     }
+
+    #[test]
+    fn non_utf8_bytes() {
+        // non-utf8 bytes
+        let bytes: &[u8] = &[0xE4, 0xBD, 0xA0, 0xE5, 0xA5, 0xBD, 0xE8, 0xAF, 0xAD];
+        let array = BinaryArray::from(vec![Some(bytes)]);
+        let arr = substring(&array, 0, Some(5)).unwrap();
+        let actual = arr.as_any().downcast_ref::<BinaryArray>().unwrap();
+
+        let expected_bytes: &[u8] = &[0xE4, 0xBD, 0xA0, 0xE5, 0xA5];
+        let expected = BinaryArray::from(vec![Some(expected_bytes)]);
+        assert_eq!(expected, *actual);
+    }
 }