This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new e6395e21d9 Make regexp_match take scalar pattern and flag (#5245)
e6395e21d9 is described below

commit e6395e21d923caf5b7cd643fc5d4418642d3bb3a
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Jan 1 03:21:53 2024 -0800

    Make regexp_match take scalar pattern and flag (#5245)
    
    * Make regexp_match take Datum pattern input
    
    * Add more tests
    
    * More
    
    * Update benchmark
    
    * Fix clippy
    
    * For review
    
    * Fix clippy
    
    * Don't expose utility function
---
 arrow-string/src/regexp.rs      | 242 ++++++++++++++++++++++++++++++++++++----
 arrow/benches/regexp_kernels.rs |   9 +-
 2 files changed, 227 insertions(+), 24 deletions(-)

diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs
index 25c712d20f..5e539b91b4 100644
--- a/arrow-string/src/regexp.rs
+++ b/arrow-string/src/regexp.rs
@@ -19,10 +19,11 @@
 //! expression of a \[Large\]StringArray
 
 use arrow_array::builder::{BooleanBufferBuilder, GenericStringBuilder, ListBuilder};
+use arrow_array::cast::AsArray;
 use arrow_array::*;
 use arrow_buffer::NullBuffer;
 use arrow_data::{ArrayData, ArrayDataBuilder};
-use arrow_schema::{ArrowError, DataType};
+use arrow_schema::{ArrowError, DataType, Field};
 use regex::Regex;
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -152,28 +153,7 @@ pub fn regexp_is_match_utf8_scalar<OffsetSize: OffsetSizeTrait>(
     Ok(BooleanArray::from(data))
 }
 
-/// Extract all groups matched by a regular expression for a given String array.
-///
-/// Modelled after the Postgres [regexp_match].
-///
-/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
-/// match of the corresponding index in `regex_array` to string in `array`
-///
-/// If there is no match, the list element is NULL.
-///
-/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
-/// then the list element is a single-element [`GenericStringArray`] containing the substring
-/// matching the whole pattern.
-///
-/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
-/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
-/// the n'th capturing parenthesized subexpression of the pattern.
-///
-/// The flags parameter is an optional text string containing zero or more single-letter flags
-/// that change the function's behavior.
-///
-/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
-pub fn regexp_match<OffsetSize: OffsetSizeTrait>(
+fn regexp_array_match<OffsetSize: OffsetSizeTrait>(
     array: &GenericStringArray<OffsetSize>,
     regex_array: &GenericStringArray<OffsetSize>,
     flags_array: Option<&GenericStringArray<OffsetSize>>,
@@ -248,6 +228,179 @@ pub fn regexp_match<OffsetSize: OffsetSizeTrait>(
     Ok(Arc::new(list_builder.finish()))
 }
 
+fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(
+    regex_array: &'a dyn Array,
+    flag_array: Option<&'a dyn Array>,
+) -> (Option<&'a str>, Option<&'a str>) {
+    let regex = regex_array.as_string::<OffsetSize>();
+    let regex = regex.is_valid(0).then(|| regex.value(0));
+
+    if let Some(flag_array) = flag_array {
+        let flag = flag_array.as_string::<OffsetSize>();
+        (regex, flag.is_valid(0).then(|| flag.value(0)))
+    } else {
+        (regex, None)
+    }
+}
+
+fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
+    array: &GenericStringArray<OffsetSize>,
+    regex: &Regex,
+) -> Result<ArrayRef, ArrowError> {
+    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
+    let mut list_builder = ListBuilder::new(builder);
+
+    array
+        .iter()
+        .map(|value| {
+            match value {
+                // Required for Postgres compatibility:
+                // SELECT regexp_match('foobarbequebaz', ''); = {""}
+                Some(_) if regex.as_str() == "" => {
+                    list_builder.values().append_value("");
+                    list_builder.append(true);
+                }
+                Some(value) => match regex.captures(value) {
+                    Some(caps) => {
+                        let mut iter = caps.iter();
+                        if caps.len() > 1 {
+                            iter.next();
+                        }
+                        for m in iter.flatten() {
+                            list_builder.values().append_value(m.as_str());
+                        }
+
+                        list_builder.append(true);
+                    }
+                    None => list_builder.append(false),
+                },
+                _ => list_builder.append(false),
+            }
+            Ok(())
+        })
+        .collect::<Result<Vec<()>, ArrowError>>()?;
+
+    Ok(Arc::new(list_builder.finish()))
+}
+
+/// Extract all groups matched by a regular expression for a given String array.
+///
+/// Modelled after the Postgres [regexp_match].
+///
+/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
+/// match of the corresponding index in `regex_array` to string in `array`
+///
+/// If there is no match, the list element is NULL.
+///
+/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
+/// then the list element is a single-element [`GenericStringArray`] containing the substring
+/// matching the whole pattern.
+///
+/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
+/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
+/// the n'th capturing parenthesized subexpression of the pattern.
+///
+/// The flags parameter is an optional text string containing zero or more single-letter flags
+/// that change the function's behavior.
+///
+/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
+pub fn regexp_match(
+    array: &dyn Array,
+    regex_array: &dyn Datum,
+    flags_array: Option<&dyn Datum>,
+) -> Result<ArrayRef, ArrowError> {
+    let (rhs, is_rhs_scalar) = regex_array.get();
+
+    if array.data_type() != rhs.data_type() {
+        return Err(ArrowError::ComputeError(
+            "regexp_match() requires both array and pattern to be either Utf8 or LargeUtf8"
+                .to_string(),
+        ));
+    }
+
+    let (flags, is_flags_scalar) = match flags_array {
+        Some(flags) => {
+            let (flags, is_flags_scalar) = flags.get();
+            (Some(flags), Some(is_flags_scalar))
+        }
+        None => (None, None),
+    };
+
+    if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() {
+        return Err(ArrowError::ComputeError(
+            "regexp_match() requires both pattern and flags to be either scalar or array"
+                .to_string(),
+        ));
+    }
+
+    if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() {
+        return Err(ArrowError::ComputeError(
+            "regexp_match() requires both pattern and flags to be either string or largestring"
+                .to_string(),
+        ));
+    }
+
+    if is_rhs_scalar {
+        // Regex and flag is scalars
+        let (regex, flag) = match rhs.data_type() {
+            DataType::Utf8 => get_scalar_pattern_flag::<i32>(rhs, flags),
+            DataType::LargeUtf8 => get_scalar_pattern_flag::<i64>(rhs, flags),
+            _ => {
+                return Err(ArrowError::ComputeError(
+                    "regexp_match() requires pattern to be either Utf8 or LargeUtf8".to_string(),
+                ));
+            }
+        };
+
+        if regex.is_none() {
+            return Ok(new_null_array(
+                &DataType::List(Arc::new(Field::new(
+                    "item",
+                    array.data_type().clone(),
+                    true,
+                ))),
+                array.len(),
+            ));
+        }
+
+        let regex = regex.unwrap();
+
+        let pattern = if let Some(flag) = flag {
+            format!("(?{flag}){regex}")
+        } else {
+            regex.to_string()
+        };
+
+        let re = Regex::new(pattern.as_str()).map_err(|e| {
+            ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
+        })?;
+
+        match array.data_type() {
+            DataType::Utf8 => regexp_scalar_match(array.as_string::<i32>(), &re),
+            DataType::LargeUtf8 => regexp_scalar_match(array.as_string::<i64>(), &re),
+            _ => Err(ArrowError::ComputeError(
+                "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(),
+            )),
+        }
+    } else {
+        match array.data_type() {
+            DataType::Utf8 => {
+                let regex_array = rhs.as_string();
+                let flags_array = flags.map(|flags| flags.as_string());
+                regexp_array_match(array.as_string::<i32>(), regex_array, flags_array)
+            }
+            DataType::LargeUtf8 => {
+                let regex_array = rhs.as_string();
+                let flags_array = flags.map(|flags| flags.as_string());
+                regexp_array_match(array.as_string::<i64>(), regex_array, flags_array)
+            }
+            _ => Err(ArrowError::ComputeError(
+                "regexp_match() requires array to be either Utf8 or LargeUtf8".to_string(),
+            )),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -304,6 +457,49 @@ mod tests {
         assert_eq!(&expected, result);
     }
 
+    #[test]
+    fn match_scalar_pattern() {
+        let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
+        let array = StringArray::from(values);
+        let pattern = Scalar::new(StringArray::from(vec![r"x.*-(\d*)-.*"; 1]));
+        let flags = Scalar::new(StringArray::from(vec!["i"; 1]));
+        let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap();
+        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::with_capacity(0, 0);
+        let mut expected_builder = ListBuilder::new(elem_builder);
+        expected_builder.append(false);
+        expected_builder.values().append_value("7");
+        expected_builder.append(true);
+        expected_builder.append(false);
+        expected_builder.append(false);
+        let expected = expected_builder.finish();
+        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(&expected, result);
+
+        // No flag
+        let values = vec![Some("abc-005-def"), Some("x-7-5"), Some("X545"), None];
+        let array = StringArray::from(values);
+        let actual = regexp_match(&array, &pattern, None).unwrap();
+        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(&expected, result);
+    }
+
+    #[test]
+    fn match_scalar_no_pattern() {
+        let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
+        let array = StringArray::from(values);
+        let pattern = Scalar::new(new_null_array(&DataType::Utf8, 1));
+        let actual = regexp_match(&array, &pattern, None).unwrap();
+        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::with_capacity(0, 0);
+        let mut expected_builder = ListBuilder::new(elem_builder);
+        expected_builder.append(false);
+        expected_builder.append(false);
+        expected_builder.append(false);
+        expected_builder.append(false);
+        let expected = expected_builder.finish();
+        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(&expected, result);
+    }
+
     #[test]
     fn test_single_group_not_skip_match() {
         let array = StringArray::from(vec![Some("foo"), Some("bar")]);
diff --git a/arrow/benches/regexp_kernels.rs b/arrow/benches/regexp_kernels.rs
index eb38ba6783..d5ffbcb997 100644
--- a/arrow/benches/regexp_kernels.rs
+++ b/arrow/benches/regexp_kernels.rs
@@ -25,7 +25,7 @@ use arrow::array::*;
 use arrow::compute::kernels::regexp::*;
 use arrow::util::bench_util::*;
 
-fn bench_regexp(arr: &GenericStringArray<i32>, regex_array: &GenericStringArray<i32>) {
+fn bench_regexp(arr: &GenericStringArray<i32>, regex_array: &dyn Datum) {
     regexp_match(criterion::black_box(arr), regex_array, None).unwrap();
 }
 
@@ -38,6 +38,13 @@ fn add_benchmark(c: &mut Criterion) {
     let pattern = GenericStringArray::<i32>::from(pattern_values);
 
     c.bench_function("regexp", |b| b.iter(|| bench_regexp(&arr_string, &pattern)));
+
+    let pattern_values = vec![r".*-(\d*)-.*"];
+    let pattern = Scalar::new(GenericStringArray::<i32>::from(pattern_values));
+
+    c.bench_function("regexp scalar", |b| {
+        b.iter(|| bench_regexp(&arr_string, &pattern))
+    });
 }
 
 criterion_group!(benches, add_benchmark);

Reply via email to