This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new dd5f936  Add support for PostgreSQL regex match (#870)
dd5f936 is described below

commit dd5f9365ddefaba0aed9c3f16a679edf55be28d8
Author: baishen <[email protected]>
AuthorDate: Fri Sep 10 05:58:29 2021 -0500

    Add support for PostgreSQL regex match (#870)
---
 ballista-examples/Cargo.toml                       |   2 +-
 ballista/rust/core/Cargo.toml                      |   2 +-
 ballista/rust/executor/Cargo.toml                  |   4 +-
 datafusion-cli/Cargo.toml                          |   2 +-
 datafusion-examples/Cargo.toml                     |   2 +-
 datafusion/Cargo.toml                              |   4 +-
 datafusion/src/logical_plan/operators.rs           |  12 +
 datafusion/src/physical_plan/expressions/binary.rs | 246 ++++++++++++++++++++-
 datafusion/src/sql/planner.rs                      |   4 +
 datafusion/tests/sql.rs                            |  62 ++++++
 10 files changed, 325 insertions(+), 15 deletions(-)

diff --git a/ballista-examples/Cargo.toml b/ballista-examples/Cargo.toml
index 1b578bf..e0989b9 100644
--- a/ballista-examples/Cargo.toml
+++ b/ballista-examples/Cargo.toml
@@ -28,7 +28,7 @@ edition = "2018"
 publish = false
 
 [dependencies]
-arrow-flight = { version = "^5.2" }
+arrow-flight = { version = "^5.3" }
 datafusion = { path = "../datafusion" }
 ballista = { path = "../ballista/rust/client" }
 prost = "0.8"
diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml
index 74beb69..29a0f09 100644
--- a/ballista/rust/core/Cargo.toml
+++ b/ballista/rust/core/Cargo.toml
@@ -42,7 +42,7 @@ tokio = "1.0"
 tonic = "0.5"
 uuid = { version = "0.8", features = ["v4"] }
 
-arrow-flight = { version = "^5.2"  }
+arrow-flight = { version = "^5.3"  }
 
 datafusion = { path = "../../../datafusion", version = "5.1.0" }
 
diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml
index 795b8dd..c6e0ab8 100644
--- a/ballista/rust/executor/Cargo.toml
+++ b/ballista/rust/executor/Cargo.toml
@@ -29,8 +29,8 @@ edition = "2018"
 snmalloc = ["snmalloc-rs"]
 
 [dependencies]
-arrow = { version = "^5.2"  }
-arrow-flight = { version = "^5.2"  }
+arrow = { version = "^5.3"  }
+arrow-flight = { version = "^5.3"  }
 anyhow = "1"
 async-trait = "0.1.36"
 ballista-core = { path = "../core", version = "0.6.0" }
diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml
index 008aeec..3b9be67 100644
--- a/datafusion-cli/Cargo.toml
+++ b/datafusion-cli/Cargo.toml
@@ -31,5 +31,5 @@ clap = "2.33"
 rustyline = "8.0"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
 datafusion = { path = "../datafusion", version = "5.1.0" }
-arrow = { version = "^5.2"  }
+arrow = { version = "^5.3"  }
 ballista = { path = "../ballista/rust/client", version = "0.6.0" }
diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml
index 1f4f74d..9b859c6 100644
--- a/datafusion-examples/Cargo.toml
+++ b/datafusion-examples/Cargo.toml
@@ -29,7 +29,7 @@ publish = false
 
 
 [dev-dependencies]
-arrow-flight = { version = "^5.2" }
+arrow-flight = { version = "^5.3" }
 datafusion = { path = "../datafusion" }
 prost = "0.8"
 tonic = "0.5"
diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index 86eb64c..c9ab943 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -49,8 +49,8 @@ force_hash_collisions = []
 [dependencies]
 ahash = "0.7"
 hashbrown = { version = "0.11", features = ["raw"] }
-arrow = { version = "^5.2", features = ["prettyprint"] }
-parquet = { version = "^5.2", features = ["arrow"] }
+arrow = { version = "^5.3", features = ["prettyprint"] }
+parquet = { version = "^5.3", features = ["arrow"] }
 sqlparser = "0.10"
 paste = "^1.0"
 num_cpus = "1.13.0"
diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion/src/logical_plan/operators.rs
index 80020d8..b3ff72c 100644
--- a/datafusion/src/logical_plan/operators.rs
+++ b/datafusion/src/logical_plan/operators.rs
@@ -52,6 +52,14 @@ pub enum Operator {
     Like,
     /// Does not match a wildcard pattern
     NotLike,
+    /// Case sensitive regex match
+    RegexMatch,
+    /// Case insensitive regex match
+    RegexIMatch,
+    /// Case sensitive regex not match
+    RegexNotMatch,
+    /// Case insensitive regex not match
+    RegexNotIMatch,
 }
 
 impl fmt::Display for Operator {
@@ -72,6 +80,10 @@ impl fmt::Display for Operator {
             Operator::Or => "OR",
             Operator::Like => "LIKE",
             Operator::NotLike => "NOT LIKE",
+            Operator::RegexMatch => "~",
+            Operator::RegexIMatch => "~*",
+            Operator::RegexNotMatch => "!~",
+            Operator::RegexNotIMatch => "!~*",
         };
         write!(f, "{}", display)
     }
diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs
index 39999a1..e77b25c 100644
--- a/datafusion/src/physical_plan/expressions/binary.rs
+++ b/datafusion/src/physical_plan/expressions/binary.rs
@@ -22,18 +22,19 @@ use arrow::array::*;
 use arrow::compute::kernels::arithmetic::{
     add, divide, divide_scalar, modulus, modulus_scalar, multiply, subtract,
 };
-use arrow::compute::kernels::boolean::{and_kleene, or_kleene};
+use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
 use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
 use arrow::compute::kernels::comparison::{
     eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
 };
 use arrow::compute::kernels::comparison::{
-    eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, like_utf8_scalar, lt_eq_utf8, lt_utf8,
-    neq_utf8, nlike_utf8, nlike_utf8_scalar,
+    eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8, nlike_utf8,
+    regexp_is_match_utf8,
 };
 use arrow::compute::kernels::comparison::{
-    eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, lt_eq_utf8_scalar, lt_utf8_scalar,
-    neq_utf8_scalar,
+    eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, like_utf8_scalar,
+    lt_eq_utf8_scalar, lt_utf8_scalar, neq_utf8_scalar, nlike_utf8_scalar,
+    regexp_is_match_utf8_scalar,
 };
 use arrow::datatypes::{DataType, Schema, TimeUnit};
 use arrow::record_batch::RecordBatch;
@@ -44,7 +45,9 @@ use crate::physical_plan::expressions::try_cast;
 use crate::physical_plan::{ColumnarValue, PhysicalExpr};
 use crate::scalar::ScalarValue;
 
-use super::coercion::{eq_coercion, like_coercion, numerical_coercion, order_coercion};
+use super::coercion::{
+    eq_coercion, like_coercion, numerical_coercion, order_coercion, string_coercion,
+};
 
 /// Binary expression
 #[derive(Debug)]
@@ -339,6 +342,91 @@ macro_rules! boolean_op {
     }};
 }
 
+macro_rules! binary_string_array_flag_op {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
+        match $LEFT.data_type() {
+            DataType::Utf8 => {
+                compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG)
+            }
+            DataType::LargeUtf8 => {
+                compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG)
+            }
+            other => Err(DataFusionError::Internal(format!(
+                "Data type {:?} not supported for binary_string_array_flag_op operation on string array",
+                other
+            ))),
+        }
+    }};
+}
+
+/// Invoke a compute kernel on a pair of binary data arrays with flags
+macro_rules! compute_utf8_flag_op {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
+        let ll = $LEFT
+            .as_any()
+            .downcast_ref::<$ARRAYTYPE>()
+            .expect("compute_utf8_flag_op failed to downcast array");
+        let rr = $RIGHT
+            .as_any()
+            .downcast_ref::<$ARRAYTYPE>()
+            .expect("compute_utf8_flag_op failed to downcast array");
+
+        let flag = if $FLAG {
+            Some($ARRAYTYPE::from(vec!["i"; ll.len()]))
+        } else {
+            None
+        };
+        let mut array = paste::expr! {[<$OP _utf8>]}(&ll, &rr, flag.as_ref())?;
+        if $NOT {
+            array = not(&array).unwrap();
+        }
+        Ok(Arc::new(array))
+    }};
+}
+
+macro_rules! binary_string_array_flag_op_scalar {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
+        let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
+            DataType::Utf8 => {
+                compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG)
+            }
+            DataType::LargeUtf8 => {
+                compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG)
+            }
+            other => Err(DataFusionError::Internal(format!(
+                "Data type {:?} not supported for binary_string_array_flag_op_scalar operation on string array",
+                other
+            ))),
+        };
+        Some(result)
+    }};
+}
+
+/// Invoke a compute kernel on a data array and a scalar value with flag
+macro_rules! compute_utf8_flag_op_scalar {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
+        let ll = $LEFT
+            .as_any()
+            .downcast_ref::<$ARRAYTYPE>()
+            .expect("compute_utf8_flag_op_scalar failed to downcast array");
+
+        if let ScalarValue::Utf8(Some(string_value)) = $RIGHT {
+            let flag = if $FLAG { Some("i") } else { None };
+            let mut array =
+                paste::expr! {[<$OP _utf8_scalar>]}(&ll, &string_value, flag)?;
+            if $NOT {
+                array = not(&array).unwrap();
+            }
+            Ok(Arc::new(array))
+        } else {
+            Err(DataFusionError::Internal(format!(
+                "compute_utf8_flag_op_scalar failed to cast literal value {}",
+                $RIGHT
+            )))
+        }
+    }};
+}
+
 /// Coercion rules for all binary operators. Returns the output type
 /// of applying `op` to an argument of `lhs_type` and `rhs_type`.
 fn common_binary_type(
@@ -368,6 +456,10 @@ fn common_binary_type(
         | Operator::Modulo
         | Operator::Divide
         | Operator::Multiply => numerical_coercion(lhs_type, rhs_type),
+        Operator::RegexMatch
+        | Operator::RegexIMatch
+        | Operator::RegexNotMatch
+        | Operator::RegexNotIMatch => string_coercion(lhs_type, rhs_type),
     };
 
     // re-write the error message of failed coercions to include the operator's information
@@ -406,7 +498,11 @@ pub fn binary_operator_data_type(
         | Operator::Lt
         | Operator::Gt
         | Operator::GtEq
-        | Operator::LtEq => Ok(DataType::Boolean),
+        | Operator::LtEq
+        | Operator::RegexMatch
+        | Operator::RegexIMatch
+        | Operator::RegexNotMatch
+        | Operator::RegexNotIMatch => Ok(DataType::Boolean),
         // math operations return the same value as the common coerced type
         Operator::Plus
         | Operator::Minus
@@ -475,6 +571,34 @@ impl PhysicalExpr for BinaryExpr {
                     Operator::Modulo => {
                         binary_primitive_array_op_scalar!(array, scalar.clone(), modulus)
                     }
+                    Operator::RegexMatch => binary_string_array_flag_op_scalar!(
+                        array,
+                        scalar.clone(),
+                        regexp_is_match,
+                        false,
+                        false
+                    ),
+                    Operator::RegexIMatch => binary_string_array_flag_op_scalar!(
+                        array,
+                        scalar.clone(),
+                        regexp_is_match,
+                        false,
+                        true
+                    ),
+                    Operator::RegexNotMatch => binary_string_array_flag_op_scalar!(
+                        array,
+                        scalar.clone(),
+                        regexp_is_match,
+                        true,
+                        false
+                    ),
+                    Operator::RegexNotIMatch => binary_string_array_flag_op_scalar!(
+                        array,
+                        scalar.clone(),
+                        regexp_is_match,
+                        true,
+                        true
+                    ),
                     // if scalar operation is not supported - fallback to array implementation
                     _ => None,
                 }
@@ -547,6 +671,18 @@ impl PhysicalExpr for BinaryExpr {
                     )));
                 }
             }
+            Operator::RegexMatch => {
+                binary_string_array_flag_op!(left, right, regexp_is_match, false, false)
+            }
+            Operator::RegexIMatch => {
+                binary_string_array_flag_op!(left, right, regexp_is_match, false, true)
+            }
+            Operator::RegexNotMatch => {
+                binary_string_array_flag_op!(left, right, regexp_is_match, true, false)
+            }
+            Operator::RegexNotIMatch => {
+                binary_string_array_flag_op!(left, right, regexp_is_match, true, true)
+            }
         };
         result.map(|a| ColumnarValue::Array(a))
     }
@@ -822,6 +958,102 @@ mod tests {
             DataType::Boolean,
             vec![true, false]
         );
+        test_coercion!(
+            StringArray,
+            DataType::Utf8,
+            vec!["abc"; 5],
+            StringArray,
+            DataType::Utf8,
+            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+            Operator::RegexMatch,
+            BooleanArray,
+            DataType::Boolean,
+            vec![true, false, true, false, false]
+        );
+        test_coercion!(
+            StringArray,
+            DataType::Utf8,
+            vec!["abc"; 5],
+            StringArray,
+            DataType::Utf8,
+            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+            Operator::RegexIMatch,
+            BooleanArray,
+            DataType::Boolean,
+            vec![true, true, true, true, false]
+        );
+        test_coercion!(
+            StringArray,
+            DataType::Utf8,
+            vec!["abc"; 5],
+            StringArray,
+            DataType::Utf8,
+            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+            Operator::RegexNotMatch,
+            BooleanArray,
+            DataType::Boolean,
+            vec![false, true, false, true, true]
+        );
+        test_coercion!(
+            StringArray,
+            DataType::Utf8,
+            vec!["abc"; 5],
+            StringArray,
+            DataType::Utf8,
+            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+            Operator::RegexNotIMatch,
+            BooleanArray,
+            DataType::Boolean,
+            vec![false, false, false, false, true]
+        );
+        test_coercion!(
+            LargeStringArray,
+            DataType::LargeUtf8,
+            vec!["abc"; 5],
+            LargeStringArray,
+            DataType::LargeUtf8,
+            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+            Operator::RegexMatch,
+            BooleanArray,
+            DataType::Boolean,
+            vec![true, false, true, false, false]
+        );
+        test_coercion!(
+            LargeStringArray,
+            DataType::LargeUtf8,
+            vec!["abc"; 5],
+            LargeStringArray,
+            DataType::LargeUtf8,
+            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+            Operator::RegexIMatch,
+            BooleanArray,
+            DataType::Boolean,
+            vec![true, true, true, true, false]
+        );
+        test_coercion!(
+            LargeStringArray,
+            DataType::LargeUtf8,
+            vec!["abc"; 5],
+            LargeStringArray,
+            DataType::LargeUtf8,
+            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+            Operator::RegexNotMatch,
+            BooleanArray,
+            DataType::Boolean,
+            vec![false, true, false, true, true]
+        );
+        test_coercion!(
+            LargeStringArray,
+            DataType::LargeUtf8,
+            vec!["abc"; 5],
+            LargeStringArray,
+            DataType::LargeUtf8,
+            vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+            Operator::RegexNotIMatch,
+            BooleanArray,
+            DataType::Boolean,
+            vec![false, false, false, false, true]
+        );
         Ok(())
     }
 
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 6695d89..e613ff3 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -1261,6 +1261,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     BinaryOperator::Or => Ok(Operator::Or),
                     BinaryOperator::Like => Ok(Operator::Like),
                     BinaryOperator::NotLike => Ok(Operator::NotLike),
+                    BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch),
+                    BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch),
+                    BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch),
+                    BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch),
                     _ => Err(DataFusionError::NotImplemented(format!(
                         "Unsupported SQL binary operator {:?}",
                         op
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 40cd38f..804ae7e 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -4483,3 +4483,65 @@ async fn like_on_string_dictionaries() -> Result<()> {
     assert_batches_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn test_regexp_is_match() -> Result<()> {
+    let input = vec![Some("foo"), Some("Barrr"), Some("Bazzz"), Some("ZZZZZ")]
+        .into_iter()
+        .collect::<StringArray>();
+
+    let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();
+
+    let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
+    let mut ctx = ExecutionContext::new();
+    ctx.register_table("test", Arc::new(table))?;
+
+    let sql = "SELECT * FROM test WHERE c1 ~ 'z'";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+-------+",
+        "| c1    |",
+        "+-------+",
+        "| Bazzz |",
+        "+-------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 ~* 'z'";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+-------+",
+        "| c1    |",
+        "+-------+",
+        "| Bazzz |",
+        "| ZZZZZ |",
+        "+-------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 !~ 'z'";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+-------+",
+        "| c1    |",
+        "+-------+",
+        "| foo   |",
+        "| Barrr |",
+        "| ZZZZZ |",
+        "+-------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 !~* 'z'";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+-------+",
+        "| c1    |",
+        "+-------+",
+        "| foo   |",
+        "| Barrr |",
+        "+-------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}

Reply via email to