You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/08/26 17:47:35 UTC

[arrow-rs] branch active_release updated: Implement `regexp_matches_utf8` (#706) (#717)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch active_release
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/active_release by this push:
     new 446a4b7  Implement `regexp_matches_utf8` (#706) (#717)
446a4b7 is described below

commit 446a4b7ae44eaa48eb34fa46b319a470fa3ec971
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Thu Aug 26 13:47:28 2021 -0400

    Implement `regexp_matches_utf8` (#706) (#717)
    
    * impl regexp_matches_utf8
    
    * fix clippy
    
    * add bench
    
    * optimize
    
    Co-authored-by: baishen <ba...@gmail.com>
---
 arrow/benches/comparison_kernels.rs     |  17 +++
 arrow/src/compute/kernels/comparison.rs | 244 ++++++++++++++++++++++++++++++++
 2 files changed, 261 insertions(+)

diff --git a/arrow/benches/comparison_kernels.rs b/arrow/benches/comparison_kernels.rs
index a3df556..bfee9b9 100644
--- a/arrow/benches/comparison_kernels.rs
+++ b/arrow/benches/comparison_kernels.rs
@@ -119,6 +119,15 @@ fn bench_nlike_utf8_scalar(arr_a: &StringArray, value_b: &str) {
         .unwrap();
 }
 
+fn bench_regexp_is_match_utf8_scalar(arr_a: &StringArray, value_b: &str) {
+    regexp_is_match_utf8_scalar(
+        criterion::black_box(arr_a),
+        criterion::black_box(value_b),
+        None,
+    )
+    .unwrap();
+}
+
 fn add_benchmark(c: &mut Criterion) {
     let size = 65536;
     let arr_a = create_primitive_array_with_seed::<Float32Type>(size, 0.0, 42);
@@ -195,6 +204,14 @@ fn add_benchmark(c: &mut Criterion) {
     c.bench_function("nlike_utf8 scalar complex", |b| {
         b.iter(|| bench_nlike_utf8_scalar(&arr_string, "%xx_xx%xxx"))
     });
+
+    c.bench_function("egexp_matches_utf8 scalar starts with", |b| {
+        b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "^xx"))
+    });
+
+    c.bench_function("egexp_matches_utf8 scalar ends with", |b| {
+        b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "xx$"))
+    });
 }
 
 criterion_group!(benches, add_benchmark);
diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index a899d5b..8b7718f 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -450,6 +450,136 @@ pub fn nlike_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
     Ok(BooleanArray::from(data))
 }
 
+/// Perform SQL `array ~ regex_array` operation on [`StringArray`] / [`LargeStringArray`].
+/// If `regex_array` element has an empty value, the corresponding result value is always true.
+///
+/// `flags_array` is an optional [`StringArray`] / [`LargeStringArray`] of flags that enable
+/// special search modes, such as case-insensitive and multi-line matching.
+/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
+/// for more information.
+pub fn regexp_is_match_utf8<OffsetSize: StringOffsetSizeTrait>(
+    array: &GenericStringArray<OffsetSize>,
+    regex_array: &GenericStringArray<OffsetSize>,
+    flags_array: Option<&GenericStringArray<OffsetSize>>,
+) -> Result<BooleanArray> {
+    if array.len() != regex_array.len() {
+        return Err(ArrowError::ComputeError(
+            "Cannot perform comparison operation on arrays of different length"
+                .to_string(),
+        ));
+    }
+    let null_bit_buffer =
+        combine_option_bitmap(array.data_ref(), regex_array.data_ref(), array.len())?;
+
+    let mut patterns: HashMap<String, Regex> = HashMap::new();
+    let mut result = BooleanBufferBuilder::new(array.len());
+
+    let complete_pattern = match flags_array {
+        Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map(
+            |(pattern, flags)| {
+                pattern.map(|pattern| match flags {
+                    Some(flag) => format!("(?{}){}", flag, pattern),
+                    None => pattern.to_string(),
+                })
+            },
+        )) as Box<dyn Iterator<Item = Option<String>>>,
+        None => Box::new(
+            regex_array
+                .iter()
+                .map(|pattern| pattern.map(|pattern| pattern.to_string())),
+        ),
+    };
+
+    array
+        .iter()
+        .zip(complete_pattern)
+        .map(|(value, pattern)| {
+            match (value, pattern) {
+                // Required for Postgres compatibility:
+                // SELECT 'foobarbequebaz' ~ ''; => true
+                (Some(_), Some(pattern)) if pattern == *"" => {
+                    result.append(true);
+                }
+                (Some(value), Some(pattern)) => {
+                    let existing_pattern = patterns.get(&pattern);
+                    let re = match existing_pattern {
+                        Some(re) => re.clone(),
+                        None => {
+                            let re = Regex::new(pattern.as_str()).map_err(|e| {
+                                ArrowError::ComputeError(format!(
+                                    "Regular expression did not compile: {:?}",
+                                    e
+                                ))
+                            })?;
+                            patterns.insert(pattern, re.clone());
+                            re
+                        }
+                    };
+                    result.append(re.is_match(value));
+                }
+                _ => result.append(false),
+            }
+            Ok(())
+        })
+        .collect::<Result<Vec<()>>>()?;
+
+    let data = ArrayData::new(
+        DataType::Boolean,
+        array.len(),
+        None,
+        null_bit_buffer,
+        0,
+        vec![result.finish()],
+        vec![],
+    );
+    Ok(BooleanArray::from(data))
+}
+
+/// Perform SQL `array ~ regex` operation on [`StringArray`] /
+/// [`LargeStringArray`] and a scalar regular expression.
+///
+/// See the documentation on [`regexp_is_match_utf8`] for more details.
+pub fn regexp_is_match_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
+    array: &GenericStringArray<OffsetSize>,
+    regex: &str,
+    flag: Option<&str>,
+) -> Result<BooleanArray> {
+    let null_bit_buffer = array.data().null_buffer().cloned();
+    let mut result = BooleanBufferBuilder::new(array.len());
+
+    let pattern = match flag {
+        Some(flag) => format!("(?{}){}", flag, regex),
+        None => regex.to_string(),
+    };
+    if pattern == *"" {
+        for _i in 0..array.len() {
+            result.append(true);
+        }
+    } else {
+        let re = Regex::new(pattern.as_str()).map_err(|e| {
+            ArrowError::ComputeError(format!(
+                "Regular expression did not compile: {:?}",
+                e
+            ))
+        })?;
+        for i in 0..array.len() {
+            let value = array.value(i);
+            result.append(re.is_match(value));
+        }
+    }
+
+    let data = ArrayData::new(
+        DataType::Boolean,
+        array.len(),
+        None,
+        null_bit_buffer,
+        0,
+        vec![result.finish()],
+        vec![],
+    );
+    Ok(BooleanArray::from(data))
+}
+
 /// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`].
 pub fn eq_utf8<OffsetSize: StringOffsetSizeTrait>(
     left: &GenericStringArray<OffsetSize>,
@@ -1438,6 +1568,82 @@ mod tests {
         };
     }
 
+    macro_rules! test_flag_utf8 {
+        ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
+            #[test]
+            fn $test_name() {
+                let left = StringArray::from($left);
+                let right = StringArray::from($right);
+                let res = $op(&left, &right, None).unwrap();
+                let expected = $expected;
+                assert_eq!(expected.len(), res.len());
+                for i in 0..res.len() {
+                    let v = res.value(i);
+                    assert_eq!(v, expected[i]);
+                }
+            }
+        };
+        ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
+            #[test]
+            fn $test_name() {
+                let left = StringArray::from($left);
+                let right = StringArray::from($right);
+                let flag = Some(StringArray::from($flag));
+                let res = $op(&left, &right, flag.as_ref()).unwrap();
+                let expected = $expected;
+                assert_eq!(expected.len(), res.len());
+                for i in 0..res.len() {
+                    let v = res.value(i);
+                    assert_eq!(v, expected[i]);
+                }
+            }
+        };
+    }
+
+    macro_rules! test_flag_utf8_scalar {
+        ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
+            #[test]
+            fn $test_name() {
+                let left = StringArray::from($left);
+                let res = $op(&left, $right, None).unwrap();
+                let expected = $expected;
+                assert_eq!(expected.len(), res.len());
+                for i in 0..res.len() {
+                    let v = res.value(i);
+                    assert_eq!(
+                        v,
+                        expected[i],
+                        "unexpected result when comparing {} at position {} to {} ",
+                        left.value(i),
+                        i,
+                        $right
+                    );
+                }
+            }
+        };
+        ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
+            #[test]
+            fn $test_name() {
+                let left = StringArray::from($left);
+                let flag = Some($flag);
+                let res = $op(&left, $right, flag).unwrap();
+                let expected = $expected;
+                assert_eq!(expected.len(), res.len());
+                for i in 0..res.len() {
+                    let v = res.value(i);
+                    assert_eq!(
+                        v,
+                        expected[i],
+                        "unexpected result when comparing {} at position {} to {} ",
+                        left.value(i),
+                        i,
+                        $right
+                    );
+                }
+            }
+        };
+    }
+
     test_utf8!(
         test_utf8_array_like,
         vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow"],
@@ -1621,4 +1827,42 @@ mod tests {
         gt_eq_utf8_scalar,
         vec![false, false, true, true]
     );
+    test_flag_utf8!(
+        test_utf8_array_regexp_is_match,
+        vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"],
+        vec!["^ar", "^AR", "ow$", "OW$", "foo", ""],
+        regexp_is_match_utf8,
+        vec![true, false, true, false, false, true]
+    );
+    test_flag_utf8!(
+        test_utf8_array_regexp_is_match_insensitive,
+        vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"],
+        vec!["^ar", "^AR", "ow$", "OW$", "foo", ""],
+        vec!["i"; 6],
+        regexp_is_match_utf8,
+        vec![true, true, true, true, false, true]
+    );
+
+    test_flag_utf8_scalar!(
+        test_utf8_array_regexp_is_match_scalar,
+        vec!["arrow", "ARROW", "parquet", "PARQUET"],
+        "^ar",
+        regexp_is_match_utf8_scalar,
+        vec![true, false, false, false]
+    );
+    test_flag_utf8_scalar!(
+        test_utf8_array_regexp_is_match_empty_scalar,
+        vec!["arrow", "ARROW", "parquet", "PARQUET"],
+        "",
+        regexp_is_match_utf8_scalar,
+        vec![true, true, true, true]
+    );
+    test_flag_utf8_scalar!(
+        test_utf8_array_regexp_is_match_insensitive_scalar,
+        vec!["arrow", "ARROW", "parquet", "PARQUET"],
+        "^ar",
+        "i",
+        regexp_is_match_utf8_scalar,
+        vec![true, true, false, false]
+    );
 }