This is an automated email from the ASF dual-hosted git repository.
wayne pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 16303ada9b feat: implement substrait for LIKE/ILIKE expr (#6840)
16303ada9b is described below
commit 16303ada9bda90c89b28c0d6d3600782967da03a
Author: Ruihang Xia <[email protected]>
AuthorDate: Fri Jul 14 11:31:12 2023 +0800
feat: implement substrait for LIKE/ILIKE expr (#6840)
* feat: implement substrait for LIKE/ILIKE expr
Signed-off-by: Ruihang Xia <[email protected]>
* fix clippy
Signed-off-by: Ruihang Xia <[email protected]>
* Apply suggestions from code review
Co-authored-by: Nuttiiya Seekhao <[email protected]>
* Update datafusion/substrait/src/logical_plan/consumer.rs
Co-authored-by: Nuttiiya Seekhao <[email protected]>
* style: rename function
Signed-off-by: Ruihang Xia <[email protected]>
* apply CR sugg.
Signed-off-by: Ruihang Xia <[email protected]>
---------
Signed-off-by: Ruihang Xia <[email protected]>
Co-authored-by: Nuttiiya Seekhao <[email protected]>
---
datafusion/substrait/src/logical_plan/consumer.rs | 279 +++++++++++----------
datafusion/substrait/src/logical_plan/producer.rs | 96 +++++++
.../tests/cases/roundtrip_logical_plan.rs | 10 +
3 files changed, 249 insertions(+), 136 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index dc06b64a9e..7b54bea493 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -23,7 +23,7 @@ use datafusion::logical_expr::{
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
};
use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
-use datafusion::logical_expr::{Extension, LogicalPlanBuilder};
+use datafusion::logical_expr::{Extension, Like, LogicalPlanBuilder};
use datafusion::prelude::JoinType;
use datafusion::sql::TableReference;
use datafusion::{
@@ -32,7 +32,7 @@ use datafusion::{
prelude::{Column, SessionContext},
scalar::ScalarValue,
};
-use substrait::proto::expression::Literal;
+use substrait::proto::expression::{Literal, ScalarFunction};
use substrait::proto::{
aggregate_function::AggregationInvocation,
expression::{
@@ -67,8 +67,12 @@ use crate::variation_const::{
+/// How a Substrait scalar function name is mapped onto a DataFusion expression.
enum ScalarFunctionType {
+ /// A [BuiltinScalarFunction] call, converted to [Expr::ScalarFunction]
Builtin(BuiltinScalarFunction),
+ /// A binary [Operator], converted to [Expr::BinaryExpr] (expects two arguments)
Op(Operator),
- // logical negation
+ /// [Expr::Not]
Not,
+ /// [Expr::Like] Used for filtering rows based on the given wildcard pattern. Case-sensitive.
+ Like,
+ /// [Expr::ILike] Case-insensitive counterpart of `Like`
+ ILike,
}
pub fn name_to_op(name: &str) -> Result<Operator> {
@@ -104,7 +108,7 @@ pub fn name_to_op(name: &str) -> Result<Operator> {
}
}
-fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
+fn scalar_function_type_from_str(name: &str) -> Result<ScalarFunctionType> {
if let Ok(op) = name_to_op(name) {
return Ok(ScalarFunctionType::Op(op));
}
@@ -113,23 +117,14 @@ fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
return Ok(ScalarFunctionType::Builtin(fun));
}
- Err(DataFusionError::NotImplemented(format!(
- "Unsupported function name: {name:?}"
- )))
-}
-
-fn scalar_function_or_not(name: &str) -> Result<ScalarFunctionType> {
- if let Ok(fun) = BuiltinScalarFunction::from_str(name) {
- return Ok(ScalarFunctionType::Builtin(fun));
- }
-
- if name == "not" {
- return Ok(ScalarFunctionType::Not);
+ match name {
+ "not" => Ok(ScalarFunctionType::Not),
+ "like" => Ok(ScalarFunctionType::Like),
+ "ilike" => Ok(ScalarFunctionType::ILike),
+ others => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported function name: {others:?}"
+ ))),
}
-
- Err(DataFusionError::NotImplemented(format!(
- "Unsupported function name: {name:?}"
- )))
}
/// Convert Substrait Plan to DataFusion DataFrame
@@ -790,20 +785,46 @@ pub async fn from_substrait_rex(
else_expr,
})))
}
- Some(RexType::ScalarFunction(f)) => match f.arguments.len() {
- // BinaryExpr or ScalarFunction
- 2 => match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) {
- (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
- let op_or_fun = match
extensions.get(&f.function_reference) {
- Some(fname) => name_to_op_or_scalar_function(fname),
- None => Err(DataFusionError::NotImplemented(format!(
- "Aggregated function not found: function reference
= {:?}",
- f.function_reference
- ))),
- };
- match op_or_fun {
- Ok(ScalarFunctionType::Op(op)) => {
- return Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
+ Some(RexType::ScalarFunction(f)) => {
+ let fn_name = extensions.get(&f.function_reference).ok_or_else(|| {
+ DataFusionError::NotImplemented(format!(
+ "Aggregated function not found: function reference = {:?}",
+ f.function_reference
+ ))
+ })?;
+ let fn_type = scalar_function_type_from_str(fn_name)?;
+ match fn_type {
+ ScalarFunctionType::Builtin(fun) => {
+ let mut args = Vec::with_capacity(f.arguments.len());
+ for arg in &f.arguments {
+ let arg_expr = match &arg.arg_type {
+ Some(ArgType::Value(e)) => {
+ from_substrait_rex(e, input_schema,
extensions).await
+ }
+ _ => Err(DataFusionError::NotImplemented(
+ "Aggregated function argument non-Value type
not supported"
+ .to_string(),
+ )),
+ };
+ args.push(arg_expr?.as_ref().clone());
+ }
+ Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
+ fun,
+ args,
+ })))
+ }
+ ScalarFunctionType::Op(op) => {
+ if f.arguments.len() != 2 {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Expect two arguments for binary operator {op:?}",
+ )));
+ }
+ let lhs = &f.arguments[0].arg_type;
+ let rhs = &f.arguments[1].arg_type;
+
+ match (lhs, rhs) {
+ (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
+ Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
left: Box::new(
from_substrait_rex(l, input_schema,
extensions)
.await?
@@ -819,116 +840,38 @@ pub async fn from_substrait_rex(
),
})))
}
- Ok(ScalarFunctionType::Builtin(fun)) => {
-
Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
- fun,
- args: vec![
- from_substrait_rex(l, input_schema,
extensions)
- .await?
- .as_ref()
- .clone(),
- from_substrait_rex(r, input_schema,
extensions)
- .await?
- .as_ref()
- .clone(),
- ],
- })))
- }
- Ok(ScalarFunctionType::Not) => {
- Err(DataFusionError::NotImplemented(
- "Not expected function type: Not".to_string(),
- ))
- }
- Err(e) => Err(e),
- }
- }
- (l, r) => Err(DataFusionError::NotImplemented(format!(
- "Invalid arguments for binary expression: {l:?} and {r:?}"
- ))),
- },
- // ScalarFunction or Expr::Not
- 1 => {
- let fun = match extensions.get(&f.function_reference) {
- Some(fname) => scalar_function_or_not(fname),
- None => Err(DataFusionError::NotImplemented(format!(
- "Function not found: function reference = {:?}",
- f.function_reference
- ))),
- };
-
- match fun {
- Ok(ScalarFunctionType::Op(_)) => {
- Err(DataFusionError::NotImplemented(
- "Not expected function type: Op".to_string(),
- ))
- }
- Ok(scalar_function_type) => {
- match &f.arguments.first().unwrap().arg_type {
- Some(ArgType::Value(e)) => {
- let expr =
- from_substrait_rex(e, input_schema,
extensions)
- .await?
- .as_ref()
- .clone();
- match scalar_function_type {
- ScalarFunctionType::Builtin(fun) =>
Ok(Arc::new(
-
Expr::ScalarFunction(expr::ScalarFunction {
- fun,
- args: vec![expr],
- }),
- )),
- ScalarFunctionType::Not => {
- Ok(Arc::new(Expr::Not(Box::new(expr))))
- }
- _ => Err(DataFusionError::NotImplemented(
- "Invalid arguments for Not expression"
- .to_string(),
- )),
- }
- }
- _ => Err(DataFusionError::NotImplemented(
- "Invalid arguments for Not
expression".to_string(),
- )),
- }
+ (l, r) => Err(DataFusionError::NotImplemented(format!(
+ "Invalid arguments for binary expression: {l:?}
and {r:?}"
+ ))),
}
- Err(e) => Err(e),
}
- }
- // ScalarFunction
- _ => {
- let fun = match extensions.get(&f.function_reference) {
- Some(fname) => BuiltinScalarFunction::from_str(fname),
- None => Err(DataFusionError::NotImplemented(format!(
- "Aggregated function not found: function reference =
{:?}",
- f.function_reference
- ))),
- };
-
- let mut args: Vec<Expr> = vec![];
- for arg in f.arguments.iter() {
+ ScalarFunctionType::Not => {
+ let arg = f.arguments.first().ok_or_else(|| {
+ DataFusionError::Substrait(
+ "expect one argument for `NOT` expr".to_string(),
+ )
+ })?;
match &arg.arg_type {
Some(ArgType::Value(e)) => {
- args.push(
- from_substrait_rex(e, input_schema, extensions)
- .await?
- .as_ref()
- .clone(),
- );
- }
- e => {
- return Err(DataFusionError::NotImplemented(format!(
- "Invalid arguments for scalar function: {e:?}"
- )))
+ let expr = from_substrait_rex(e, input_schema,
extensions)
+ .await?
+ .as_ref()
+ .clone();
+ Ok(Arc::new(Expr::Not(Box::new(expr))))
}
+ _ => Err(DataFusionError::NotImplemented(
+ "Invalid arguments for Not expression".to_string(),
+ )),
}
}
-
- Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
- fun: fun?,
- args,
- })))
+ ScalarFunctionType::Like => {
+ make_datafusion_like(false, f, input_schema,
extensions).await
+ }
+ ScalarFunctionType::ILike => {
+ make_datafusion_like(true, f, input_schema,
extensions).await
+ }
}
- },
+ }
Some(RexType::Literal(lit)) => {
let scalar_value = from_substrait_literal(lit)?;
Ok(Arc::new(Expr::Literal(scalar_value)))
@@ -1342,3 +1285,67 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
))
}
}
+
+/// Builds a DataFusion [`Expr::Like`] / [`Expr::ILike`] from a Substrait
+/// `like` / `ilike` scalar function call.
+///
+/// The call must carry exactly three value arguments, in order: the input
+/// expression, the pattern, and the escape character. The escape character
+/// must decode to a `Utf8` literal that is either null (no escape character)
+/// or a string of exactly one character.
+///
+/// The returned expression always has `negated: false`; the producer encodes a
+/// negated LIKE as a separate `not` call wrapping the `like` call.
+///
+/// # Errors
+/// * `NotImplemented` if the argument count is not three or an argument is not
+///   a value expression.
+/// * `Substrait` if the escape argument is not a `Utf8` literal, or is a
+///   non-null string that is not exactly one character long.
+async fn make_datafusion_like(
+    case_insensitive: bool,
+    f: &ScalarFunction,
+    input_schema: &DFSchema,
+    extensions: &HashMap<u32, &String>,
+) -> Result<Arc<Expr>> {
+    let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
+    if f.arguments.len() != 3 {
+        return Err(DataFusionError::NotImplemented(format!(
+            "Expect three arguments for `{fn_name}` expr"
+        )));
+    }
+
+    let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
+        return Err(DataFusionError::NotImplemented(format!(
+            "Invalid arguments type for `{fn_name}` expr"
+        )));
+    };
+    let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
+        .await?
+        .as_ref()
+        .clone();
+
+    let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else {
+        return Err(DataFusionError::NotImplemented(format!(
+            "Invalid arguments type for `{fn_name}` expr"
+        )));
+    };
+    let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions)
+        .await?
+        .as_ref()
+        .clone();
+
+    let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else {
+        return Err(DataFusionError::NotImplemented(format!(
+            "Invalid arguments type for `{fn_name}` expr"
+        )));
+    };
+    let escape_char_expr =
+        from_substrait_rex(escape_char_substrait, input_schema, extensions)
+            .await?
+            .as_ref()
+            .clone();
+    let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else {
+        return Err(DataFusionError::Substrait(format!(
+            "Expect Utf8 literal for escape char, but found {escape_char_expr:?}",
+        )));
+    };
+
+    // The plan is external input: an empty escape string must not panic, and a
+    // longer string must not be silently truncated to its first character,
+    // because `Like::escape_char` can only hold a single `char`.
+    let escape_char = escape_char
+        .map(|s| {
+            let mut chars = s.chars();
+            match (chars.next(), chars.next()) {
+                (Some(c), None) => Ok(c),
+                _ => Err(DataFusionError::Substrait(format!(
+                    "Expect a single character for `{fn_name}` escape char, but found {s:?}"
+                ))),
+            }
+        })
+        .transpose()?;
+
+    let like = Like {
+        negated: false,
+        expr: Box::new(expr),
+        pattern: Box::new(pattern),
+        escape_char,
+    };
+    Ok(Arc::new(if case_insensitive {
+        Expr::ILike(like)
+    } else {
+        Expr::Like(like)
+    }))
+}
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 5e7ee267c4..ece1651683 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -19,6 +19,7 @@ use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;
+use datafusion::logical_expr::Like;
use datafusion::{
arrow::datatypes::{DataType, TimeUnit},
error::{DataFusionError, Result},
@@ -913,6 +914,36 @@ pub fn to_substrait_rex(
bounds,
))
}
+ Expr::Like(Like {
+ negated,
+ expr,
+ pattern,
+ escape_char,
+ }) => make_substrait_like_expr(
+ false,
+ *negated,
+ expr,
+ pattern,
+ *escape_char,
+ schema,
+ col_ref_offset,
+ extension_info,
+ ),
+ Expr::ILike(Like {
+ negated,
+ expr,
+ pattern,
+ escape_char,
+ }) => make_substrait_like_expr(
+ true,
+ *negated,
+ expr,
+ pattern,
+ *escape_char,
+ schema,
+ col_ref_offset,
+ extension_info,
+ ),
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported expression: {expr:?}"
))),
@@ -1130,6 +1161,71 @@ fn make_substrait_window_function(
}
}
+/// Builds a Substrait `like` / `ilike` scalar function call from the parts of
+/// a DataFusion [`Like`] expression.
+///
+/// The emitted call has three value arguments: `expr`, `pattern`, and the
+/// escape character encoded as a `Utf8` literal (a null literal when
+/// `escape_char` is `None`). The emitted `like` call has no negation field, so
+/// when `negated` is true the call is wrapped in a `not` function call.
+///
+/// The `like`/`ilike` function is registered before the operands are
+/// converted, and `not` (if needed) after them, so extension anchors are
+/// assigned in that order.
+// NOTE(review): `deprecated` is presumably allowed for the `args` field of
+// `ScalarFunction`; `too_many_arguments` because the producer state (schema,
+// column offset, extension registry) is threaded through next to the `Like`
+// fields.
+#[allow(deprecated)]
+#[allow(clippy::too_many_arguments)]
+fn make_substrait_like_expr(
+    ignore_case: bool,
+    negated: bool,
+    expr: &Expr,
+    pattern: &Expr,
+    escape_char: Option<char>,
+    schema: &DFSchemaRef,
+    col_ref_offset: usize,
+    extension_info: &mut (
+        Vec<extensions::SimpleExtensionDeclaration>,
+        HashMap<String, u32>,
+    ),
+) -> Result<Expression> {
+    let function_name = if ignore_case { "ilike" } else { "like" };
+    let function_anchor = _register_function(function_name.to_string(), extension_info);
+    let expr = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
+    let pattern = to_substrait_rex(pattern, schema, col_ref_offset, extension_info)?;
+    let escape_char =
+        to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?;
+    let arguments = vec![
+        FunctionArgument {
+            arg_type: Some(ArgType::Value(expr)),
+        },
+        FunctionArgument {
+            arg_type: Some(ArgType::Value(pattern)),
+        },
+        FunctionArgument {
+            arg_type: Some(ArgType::Value(escape_char)),
+        },
+    ];
+
+    let substrait_like = Expression {
+        rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+            function_reference: function_anchor,
+            arguments,
+            output_type: None,
+            args: vec![],
+            options: vec![],
+        })),
+    };
+
+    if negated {
+        let function_anchor = _register_function("not".to_string(), extension_info);
+
+        Ok(Expression {
+            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+                function_reference: function_anchor,
+                arguments: vec![FunctionArgument {
+                    arg_type: Some(ArgType::Value(substrait_like)),
+                }],
+                output_type: None,
+                args: vec![],
+                options: vec![],
+            })),
+        })
+    } else {
+        Ok(substrait_like)
+    }
+}
+
fn to_substrait_bound(bound: &WindowFrameBound) -> Bound {
match bound {
WindowFrameBound::CurrentRow => Bound {
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 1d1efb2e8d..b4a3b2cf32 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -410,6 +410,16 @@ async fn roundtrip_outer_join() -> Result<()> {
roundtrip("SELECT data.a FROM data FULL OUTER JOIN data2 ON data.a =
data2.a").await
}
+/// Round-trips a case-sensitive `LIKE` filter through the `roundtrip` helper
+/// (presumably DataFusion -> Substrait -> DataFusion; see its definition).
+#[tokio::test]
+async fn roundtrip_like() -> Result<()> {
+ roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await
+}
+
+/// Round-trips a case-insensitive `ILIKE` filter through the `roundtrip` helper.
+// NOTE(review): neither test uses `NOT LIKE` / `NOT ILIKE` or an `ESCAPE`
+// character, so the `negated` and `escape_char` conversion paths are untested.
+#[tokio::test]
+async fn roundtrip_ilike() -> Result<()> {
+ roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await
+}
+
#[tokio::test]
async fn simple_intersect() -> Result<()> {
assert_expected_plan(