This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new dd5f936 Add support for PostgreSQL regex match (#870)
dd5f936 is described below
commit dd5f9365ddefaba0aed9c3f16a679edf55be28d8
Author: baishen <[email protected]>
AuthorDate: Fri Sep 10 05:58:29 2021 -0500
Add support for PostgreSQL regex match (#870)
---
ballista-examples/Cargo.toml | 2 +-
ballista/rust/core/Cargo.toml | 2 +-
ballista/rust/executor/Cargo.toml | 4 +-
datafusion-cli/Cargo.toml | 2 +-
datafusion-examples/Cargo.toml | 2 +-
datafusion/Cargo.toml | 4 +-
datafusion/src/logical_plan/operators.rs | 12 +
datafusion/src/physical_plan/expressions/binary.rs | 246 ++++++++++++++++++++-
datafusion/src/sql/planner.rs | 4 +
datafusion/tests/sql.rs | 62 ++++++
10 files changed, 325 insertions(+), 15 deletions(-)
diff --git a/ballista-examples/Cargo.toml b/ballista-examples/Cargo.toml
index 1b578bf..e0989b9 100644
--- a/ballista-examples/Cargo.toml
+++ b/ballista-examples/Cargo.toml
@@ -28,7 +28,7 @@ edition = "2018"
publish = false
[dependencies]
-arrow-flight = { version = "^5.2" }
+arrow-flight = { version = "^5.3" }
datafusion = { path = "../datafusion" }
ballista = { path = "../ballista/rust/client" }
prost = "0.8"
diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml
index 74beb69..29a0f09 100644
--- a/ballista/rust/core/Cargo.toml
+++ b/ballista/rust/core/Cargo.toml
@@ -42,7 +42,7 @@ tokio = "1.0"
tonic = "0.5"
uuid = { version = "0.8", features = ["v4"] }
-arrow-flight = { version = "^5.2" }
+arrow-flight = { version = "^5.3" }
datafusion = { path = "../../../datafusion", version = "5.1.0" }
diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml
index 795b8dd..c6e0ab8 100644
--- a/ballista/rust/executor/Cargo.toml
+++ b/ballista/rust/executor/Cargo.toml
@@ -29,8 +29,8 @@ edition = "2018"
snmalloc = ["snmalloc-rs"]
[dependencies]
-arrow = { version = "^5.2" }
-arrow-flight = { version = "^5.2" }
+arrow = { version = "^5.3" }
+arrow-flight = { version = "^5.3" }
anyhow = "1"
async-trait = "0.1.36"
ballista-core = { path = "../core", version = "0.6.0" }
diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml
index 008aeec..3b9be67 100644
--- a/datafusion-cli/Cargo.toml
+++ b/datafusion-cli/Cargo.toml
@@ -31,5 +31,5 @@ clap = "2.33"
rustyline = "8.0"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread",
"sync"] }
datafusion = { path = "../datafusion", version = "5.1.0" }
-arrow = { version = "^5.2" }
+arrow = { version = "^5.3" }
ballista = { path = "../ballista/rust/client", version = "0.6.0" }
diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml
index 1f4f74d..9b859c6 100644
--- a/datafusion-examples/Cargo.toml
+++ b/datafusion-examples/Cargo.toml
@@ -29,7 +29,7 @@ publish = false
[dev-dependencies]
-arrow-flight = { version = "^5.2" }
+arrow-flight = { version = "^5.3" }
datafusion = { path = "../datafusion" }
prost = "0.8"
tonic = "0.5"
diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index 86eb64c..c9ab943 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -49,8 +49,8 @@ force_hash_collisions = []
[dependencies]
ahash = "0.7"
hashbrown = { version = "0.11", features = ["raw"] }
-arrow = { version = "^5.2", features = ["prettyprint"] }
-parquet = { version = "^5.2", features = ["arrow"] }
+arrow = { version = "^5.3", features = ["prettyprint"] }
+parquet = { version = "^5.3", features = ["arrow"] }
sqlparser = "0.10"
paste = "^1.0"
num_cpus = "1.13.0"
diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion/src/logical_plan/operators.rs
index 80020d8..b3ff72c 100644
--- a/datafusion/src/logical_plan/operators.rs
+++ b/datafusion/src/logical_plan/operators.rs
@@ -52,6 +52,14 @@ pub enum Operator {
Like,
/// Does not match a wildcard pattern
NotLike,
+ /// Case sensitive regex match
+ RegexMatch,
+ /// Case insensitive regex match
+ RegexIMatch,
+ /// Case sensitive regex not match
+ RegexNotMatch,
+ /// Case insensitive regex not match
+ RegexNotIMatch,
}
impl fmt::Display for Operator {
@@ -72,6 +80,10 @@ impl fmt::Display for Operator {
Operator::Or => "OR",
Operator::Like => "LIKE",
Operator::NotLike => "NOT LIKE",
+ Operator::RegexMatch => "~",
+ Operator::RegexIMatch => "~*",
+ Operator::RegexNotMatch => "!~",
+ Operator::RegexNotIMatch => "!~*",
};
write!(f, "{}", display)
}
diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs
index 39999a1..e77b25c 100644
--- a/datafusion/src/physical_plan/expressions/binary.rs
+++ b/datafusion/src/physical_plan/expressions/binary.rs
@@ -22,18 +22,19 @@ use arrow::array::*;
use arrow::compute::kernels::arithmetic::{
add, divide, divide_scalar, modulus, modulus_scalar, multiply, subtract,
};
-use arrow::compute::kernels::boolean::{and_kleene, or_kleene};
+use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
use arrow::compute::kernels::comparison::{
eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
};
use arrow::compute::kernels::comparison::{
- eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, like_utf8_scalar, lt_eq_utf8,
lt_utf8,
- neq_utf8, nlike_utf8, nlike_utf8_scalar,
+ eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8,
nlike_utf8,
+ regexp_is_match_utf8,
};
use arrow::compute::kernels::comparison::{
- eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, lt_eq_utf8_scalar,
lt_utf8_scalar,
- neq_utf8_scalar,
+ eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, like_utf8_scalar,
+ lt_eq_utf8_scalar, lt_utf8_scalar, neq_utf8_scalar, nlike_utf8_scalar,
+ regexp_is_match_utf8_scalar,
};
use arrow::datatypes::{DataType, Schema, TimeUnit};
use arrow::record_batch::RecordBatch;
@@ -44,7 +45,9 @@ use crate::physical_plan::expressions::try_cast;
use crate::physical_plan::{ColumnarValue, PhysicalExpr};
use crate::scalar::ScalarValue;
-use super::coercion::{eq_coercion, like_coercion, numerical_coercion,
order_coercion};
+use super::coercion::{
+ eq_coercion, like_coercion, numerical_coercion, order_coercion,
string_coercion,
+};
/// Binary expression
#[derive(Debug)]
@@ -339,6 +342,91 @@ macro_rules! boolean_op {
}};
}
+macro_rules! binary_string_array_flag_op {
+ ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
+ match $LEFT.data_type() {
+ DataType::Utf8 => {
+ compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringArray, $NOT,
$FLAG)
+ }
+ DataType::LargeUtf8 => {
+ compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray,
$NOT, $FLAG)
+ }
+ other => Err(DataFusionError::Internal(format!(
+ "Data type {:?} not supported for binary_string_array_flag_op
operation on string array",
+ other
+ ))),
+ }
+ }};
+}
+
+/// Invoke a compute kernel on a pair of binary data arrays with flags
+macro_rules! compute_utf8_flag_op {
+ ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr,
$FLAG:expr) => {{
+ let ll = $LEFT
+ .as_any()
+ .downcast_ref::<$ARRAYTYPE>()
+ .expect("compute_utf8_flag_op failed to downcast array");
+ let rr = $RIGHT
+ .as_any()
+ .downcast_ref::<$ARRAYTYPE>()
+ .expect("compute_utf8_flag_op failed to downcast array");
+
+ let flag = if $FLAG {
+ Some($ARRAYTYPE::from(vec!["i"; ll.len()]))
+ } else {
+ None
+ };
+ let mut array = paste::expr! {[<$OP _utf8>]}(&ll, &rr, flag.as_ref())?;
+ if $NOT {
+ array = not(&array).unwrap();
+ }
+ Ok(Arc::new(array))
+ }};
+}
+
+macro_rules! binary_string_array_flag_op_scalar {
+ ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
+ let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
+ DataType::Utf8 => {
+ compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray,
$NOT, $FLAG)
+ }
+ DataType::LargeUtf8 => {
+ compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP,
LargeStringArray, $NOT, $FLAG)
+ }
+ other => Err(DataFusionError::Internal(format!(
+ "Data type {:?} not supported for
binary_string_array_flag_op_scalar operation on string array",
+ other
+ ))),
+ };
+ Some(result)
+ }};
+}
+
+/// Invoke a compute kernel on a data array and a scalar value with flag
+macro_rules! compute_utf8_flag_op_scalar {
+ ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr,
$FLAG:expr) => {{
+ let ll = $LEFT
+ .as_any()
+ .downcast_ref::<$ARRAYTYPE>()
+ .expect("compute_utf8_flag_op_scalar failed to downcast array");
+
+ if let ScalarValue::Utf8(Some(string_value)) = $RIGHT {
+ let flag = if $FLAG { Some("i") } else { None };
+ let mut array =
+ paste::expr! {[<$OP _utf8_scalar>]}(&ll, &string_value, flag)?;
+ if $NOT {
+ array = not(&array).unwrap();
+ }
+ Ok(Arc::new(array))
+ } else {
+ Err(DataFusionError::Internal(format!(
+ "compute_utf8_flag_op_scalar failed to cast literal value {}",
+ $RIGHT
+ )))
+ }
+ }};
+}
+
/// Coercion rules for all binary operators. Returns the output type
/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
fn common_binary_type(
@@ -368,6 +456,10 @@ fn common_binary_type(
| Operator::Modulo
| Operator::Divide
| Operator::Multiply => numerical_coercion(lhs_type, rhs_type),
+ Operator::RegexMatch
+ | Operator::RegexIMatch
+ | Operator::RegexNotMatch
+ | Operator::RegexNotIMatch => string_coercion(lhs_type, rhs_type),
};
// re-write the error message of failed coercions to include the operator's information
@@ -406,7 +498,11 @@ pub fn binary_operator_data_type(
| Operator::Lt
| Operator::Gt
| Operator::GtEq
- | Operator::LtEq => Ok(DataType::Boolean),
+ | Operator::LtEq
+ | Operator::RegexMatch
+ | Operator::RegexIMatch
+ | Operator::RegexNotMatch
+ | Operator::RegexNotIMatch => Ok(DataType::Boolean),
// math operations return the same value as the common coerced type
Operator::Plus
| Operator::Minus
@@ -475,6 +571,34 @@ impl PhysicalExpr for BinaryExpr {
Operator::Modulo => {
binary_primitive_array_op_scalar!(array,
scalar.clone(), modulus)
}
+ Operator::RegexMatch =>
binary_string_array_flag_op_scalar!(
+ array,
+ scalar.clone(),
+ regexp_is_match,
+ false,
+ false
+ ),
+ Operator::RegexIMatch =>
binary_string_array_flag_op_scalar!(
+ array,
+ scalar.clone(),
+ regexp_is_match,
+ false,
+ true
+ ),
+ Operator::RegexNotMatch =>
binary_string_array_flag_op_scalar!(
+ array,
+ scalar.clone(),
+ regexp_is_match,
+ true,
+ false
+ ),
+ Operator::RegexNotIMatch =>
binary_string_array_flag_op_scalar!(
+ array,
+ scalar.clone(),
+ regexp_is_match,
+ true,
+ true
+ ),
// if scalar operation is not supported - fallback to array implementation
_ => None,
}
@@ -547,6 +671,18 @@ impl PhysicalExpr for BinaryExpr {
)));
}
}
+ Operator::RegexMatch => {
+ binary_string_array_flag_op!(left, right, regexp_is_match,
false, false)
+ }
+ Operator::RegexIMatch => {
+ binary_string_array_flag_op!(left, right, regexp_is_match,
false, true)
+ }
+ Operator::RegexNotMatch => {
+ binary_string_array_flag_op!(left, right, regexp_is_match,
true, false)
+ }
+ Operator::RegexNotIMatch => {
+ binary_string_array_flag_op!(left, right, regexp_is_match,
true, true)
+ }
};
result.map(|a| ColumnarValue::Array(a))
}
@@ -822,6 +958,102 @@ mod tests {
DataType::Boolean,
vec![true, false]
);
+ test_coercion!(
+ StringArray,
+ DataType::Utf8,
+ vec!["abc"; 5],
+ StringArray,
+ DataType::Utf8,
+ vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+ Operator::RegexMatch,
+ BooleanArray,
+ DataType::Boolean,
+ vec![true, false, true, false, false]
+ );
+ test_coercion!(
+ StringArray,
+ DataType::Utf8,
+ vec!["abc"; 5],
+ StringArray,
+ DataType::Utf8,
+ vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+ Operator::RegexIMatch,
+ BooleanArray,
+ DataType::Boolean,
+ vec![true, true, true, true, false]
+ );
+ test_coercion!(
+ StringArray,
+ DataType::Utf8,
+ vec!["abc"; 5],
+ StringArray,
+ DataType::Utf8,
+ vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+ Operator::RegexNotMatch,
+ BooleanArray,
+ DataType::Boolean,
+ vec![false, true, false, true, true]
+ );
+ test_coercion!(
+ StringArray,
+ DataType::Utf8,
+ vec!["abc"; 5],
+ StringArray,
+ DataType::Utf8,
+ vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+ Operator::RegexNotIMatch,
+ BooleanArray,
+ DataType::Boolean,
+ vec![false, false, false, false, true]
+ );
+ test_coercion!(
+ LargeStringArray,
+ DataType::LargeUtf8,
+ vec!["abc"; 5],
+ LargeStringArray,
+ DataType::LargeUtf8,
+ vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+ Operator::RegexMatch,
+ BooleanArray,
+ DataType::Boolean,
+ vec![true, false, true, false, false]
+ );
+ test_coercion!(
+ LargeStringArray,
+ DataType::LargeUtf8,
+ vec!["abc"; 5],
+ LargeStringArray,
+ DataType::LargeUtf8,
+ vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+ Operator::RegexIMatch,
+ BooleanArray,
+ DataType::Boolean,
+ vec![true, true, true, true, false]
+ );
+ test_coercion!(
+ LargeStringArray,
+ DataType::LargeUtf8,
+ vec!["abc"; 5],
+ LargeStringArray,
+ DataType::LargeUtf8,
+ vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+ Operator::RegexNotMatch,
+ BooleanArray,
+ DataType::Boolean,
+ vec![false, true, false, true, true]
+ );
+ test_coercion!(
+ LargeStringArray,
+ DataType::LargeUtf8,
+ vec!["abc"; 5],
+ LargeStringArray,
+ DataType::LargeUtf8,
+ vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
+ Operator::RegexNotIMatch,
+ BooleanArray,
+ DataType::Boolean,
+ vec![false, false, false, false, true]
+ );
Ok(())
}
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 6695d89..e613ff3 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -1261,6 +1261,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
BinaryOperator::Or => Ok(Operator::Or),
BinaryOperator::Like => Ok(Operator::Like),
BinaryOperator::NotLike => Ok(Operator::NotLike),
+ BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch),
+ BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch),
+ BinaryOperator::PGRegexNotMatch =>
Ok(Operator::RegexNotMatch),
+ BinaryOperator::PGRegexNotIMatch =>
Ok(Operator::RegexNotIMatch),
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported SQL binary operator {:?}",
op
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 40cd38f..804ae7e 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -4483,3 +4483,65 @@ async fn like_on_string_dictionaries() -> Result<()> {
assert_batches_eq!(expected, &actual);
Ok(())
}
+
+#[tokio::test]
+async fn test_regexp_is_match() -> Result<()> {
+ let input = vec![Some("foo"), Some("Barrr"), Some("Bazzz"), Some("ZZZZZ")]
+ .into_iter()
+ .collect::<StringArray>();
+
+ let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as
_)]).unwrap();
+
+ let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
+ let mut ctx = ExecutionContext::new();
+ ctx.register_table("test", Arc::new(table))?;
+
+ let sql = "SELECT * FROM test WHERE c1 ~ 'z'";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+-------+",
+ "| c1 |",
+ "+-------+",
+ "| Bazzz |",
+ "+-------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 ~* 'z'";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+-------+",
+ "| c1 |",
+ "+-------+",
+ "| Bazzz |",
+ "| ZZZZZ |",
+ "+-------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 !~ 'z'";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+-------+",
+ "| c1 |",
+ "+-------+",
+ "| foo |",
+ "| Barrr |",
+ "| ZZZZZ |",
+ "+-------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 !~* 'z'";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+-------+",
+ "| c1 |",
+ "+-------+",
+ "| foo |",
+ "| Barrr |",
+ "+-------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}