This is an automated email from the ASF dual-hosted git repository.

wayne pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 16303ada9b feat: implement substrait for LIKE/ILIKE expr (#6840)
16303ada9b is described below

commit 16303ada9bda90c89b28c0d6d3600782967da03a
Author: Ruihang Xia <[email protected]>
AuthorDate: Fri Jul 14 11:31:12 2023 +0800

    feat: implement substrait for LIKE/ILIKE expr (#6840)
    
    * feat: implement substrait for LIKE/ILIKE expr
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * fix clippy
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * Apply suggestions from code review
    
    Co-authored-by: Nuttiiya Seekhao <[email protected]>
    
    * Update datafusion/substrait/src/logical_plan/consumer.rs
    
    Co-authored-by: Nuttiiya Seekhao <[email protected]>
    
    * style: rename function
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * apply CR sugg.
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    ---------
    
    Signed-off-by: Ruihang Xia <[email protected]>
    Co-authored-by: Nuttiiya Seekhao <[email protected]>
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 279 +++++++++++----------
 datafusion/substrait/src/logical_plan/producer.rs  |  96 +++++++
 .../tests/cases/roundtrip_logical_plan.rs          |  10 +
 3 files changed, 249 insertions(+), 136 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index dc06b64a9e..7b54bea493 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -23,7 +23,7 @@ use datafusion::logical_expr::{
     BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
 };
 use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
-use datafusion::logical_expr::{Extension, LogicalPlanBuilder};
+use datafusion::logical_expr::{Extension, Like, LogicalPlanBuilder};
 use datafusion::prelude::JoinType;
 use datafusion::sql::TableReference;
 use datafusion::{
@@ -32,7 +32,7 @@ use datafusion::{
     prelude::{Column, SessionContext},
     scalar::ScalarValue,
 };
-use substrait::proto::expression::Literal;
+use substrait::proto::expression::{Literal, ScalarFunction};
 use substrait::proto::{
     aggregate_function::AggregationInvocation,
     expression::{
@@ -67,8 +67,12 @@ use crate::variation_const::{
 enum ScalarFunctionType {
     Builtin(BuiltinScalarFunction),
     Op(Operator),
-    // logical negation
+    /// [Expr::Not]
     Not,
+    /// [Expr::Like] Used for filtering rows based on the given wildcard 
pattern. Case sensitive
+    Like,
+    /// [Expr::ILike] Case insensitive operator counterpart of `Like`
+    ILike,
 }
 
 pub fn name_to_op(name: &str) -> Result<Operator> {
@@ -104,7 +108,7 @@ pub fn name_to_op(name: &str) -> Result<Operator> {
     }
 }
 
-fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
+fn scalar_function_type_from_str(name: &str) -> Result<ScalarFunctionType> {
     if let Ok(op) = name_to_op(name) {
         return Ok(ScalarFunctionType::Op(op));
     }
@@ -113,23 +117,14 @@ fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
         return Ok(ScalarFunctionType::Builtin(fun));
     }
 
-    Err(DataFusionError::NotImplemented(format!(
-        "Unsupported function name: {name:?}"
-    )))
-}
-
-fn scalar_function_or_not(name: &str) -> Result<ScalarFunctionType> {
-    if let Ok(fun) = BuiltinScalarFunction::from_str(name) {
-        return Ok(ScalarFunctionType::Builtin(fun));
-    }
-
-    if name == "not" {
-        return Ok(ScalarFunctionType::Not);
+    match name {
+        "not" => Ok(ScalarFunctionType::Not),
+        "like" => Ok(ScalarFunctionType::Like),
+        "ilike" => Ok(ScalarFunctionType::ILike),
+        others => Err(DataFusionError::NotImplemented(format!(
+            "Unsupported function name: {others:?}"
+        ))),
     }
-
-    Err(DataFusionError::NotImplemented(format!(
-        "Unsupported function name: {name:?}"
-    )))
 }
 
 /// Convert Substrait Plan to DataFusion DataFrame
@@ -790,20 +785,46 @@ pub async fn from_substrait_rex(
                 else_expr,
             })))
         }
-        Some(RexType::ScalarFunction(f)) => match f.arguments.len() {
-            // BinaryExpr or ScalarFunction
-            2 => match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) {
-                (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
-                    let op_or_fun = match 
extensions.get(&f.function_reference) {
-                        Some(fname) => name_to_op_or_scalar_function(fname),
-                        None => Err(DataFusionError::NotImplemented(format!(
-                            "Aggregated function not found: function reference 
= {:?}",
-                            f.function_reference
-                        ))),
-                    };
-                    match op_or_fun {
-                        Ok(ScalarFunctionType::Op(op)) => {
-                            return Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
+        Some(RexType::ScalarFunction(f)) => {
+            let fn_name = extensions.get(&f.function_reference).ok_or_else(|| {
+                DataFusionError::NotImplemented(format!(
+                    "Aggregated function not found: function reference = {:?}",
+                    f.function_reference
+                ))
+            })?;
+            let fn_type = scalar_function_type_from_str(fn_name)?;
+            match fn_type {
+                ScalarFunctionType::Builtin(fun) => {
+                    let mut args = Vec::with_capacity(f.arguments.len());
+                    for arg in &f.arguments {
+                        let arg_expr = match &arg.arg_type {
+                            Some(ArgType::Value(e)) => {
+                                from_substrait_rex(e, input_schema, 
extensions).await
+                            }
+                            _ => Err(DataFusionError::NotImplemented(
+                                "Aggregated function argument non-Value type 
not supported"
+                                    .to_string(),
+                            )),
+                        };
+                        args.push(arg_expr?.as_ref().clone());
+                    }
+                    Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
+                        fun,
+                        args,
+                    })))
+                }
+                ScalarFunctionType::Op(op) => {
+                    if f.arguments.len() != 2 {
+                        return Err(DataFusionError::NotImplemented(format!(
+                            "Expect two arguments for binary operator {op:?}",
+                        )));
+                    }
+                    let lhs = &f.arguments[0].arg_type;
+                    let rhs = &f.arguments[1].arg_type;
+
+                    match (lhs, rhs) {
+                        (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
+                            Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
                                 left: Box::new(
                                     from_substrait_rex(l, input_schema, 
extensions)
                                         .await?
@@ -819,116 +840,38 @@ pub async fn from_substrait_rex(
                                 ),
                             })))
                         }
-                        Ok(ScalarFunctionType::Builtin(fun)) => {
-                            
Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
-                                fun,
-                                args: vec![
-                                    from_substrait_rex(l, input_schema, 
extensions)
-                                        .await?
-                                        .as_ref()
-                                        .clone(),
-                                    from_substrait_rex(r, input_schema, 
extensions)
-                                        .await?
-                                        .as_ref()
-                                        .clone(),
-                                ],
-                            })))
-                        }
-                        Ok(ScalarFunctionType::Not) => {
-                            Err(DataFusionError::NotImplemented(
-                                "Not expected function type: Not".to_string(),
-                            ))
-                        }
-                        Err(e) => Err(e),
-                    }
-                }
-                (l, r) => Err(DataFusionError::NotImplemented(format!(
-                    "Invalid arguments for binary expression: {l:?} and {r:?}"
-                ))),
-            },
-            // ScalarFunction or Expr::Not
-            1 => {
-                let fun = match extensions.get(&f.function_reference) {
-                    Some(fname) => scalar_function_or_not(fname),
-                    None => Err(DataFusionError::NotImplemented(format!(
-                        "Function not found: function reference = {:?}",
-                        f.function_reference
-                    ))),
-                };
-
-                match fun {
-                    Ok(ScalarFunctionType::Op(_)) => {
-                        Err(DataFusionError::NotImplemented(
-                            "Not expected function type: Op".to_string(),
-                        ))
-                    }
-                    Ok(scalar_function_type) => {
-                        match &f.arguments.first().unwrap().arg_type {
-                            Some(ArgType::Value(e)) => {
-                                let expr =
-                                    from_substrait_rex(e, input_schema, 
extensions)
-                                        .await?
-                                        .as_ref()
-                                        .clone();
-                                match scalar_function_type {
-                                    ScalarFunctionType::Builtin(fun) => 
Ok(Arc::new(
-                                        
Expr::ScalarFunction(expr::ScalarFunction {
-                                            fun,
-                                            args: vec![expr],
-                                        }),
-                                    )),
-                                    ScalarFunctionType::Not => {
-                                        Ok(Arc::new(Expr::Not(Box::new(expr))))
-                                    }
-                                    _ => Err(DataFusionError::NotImplemented(
-                                        "Invalid arguments for Not expression"
-                                            .to_string(),
-                                    )),
-                                }
-                            }
-                            _ => Err(DataFusionError::NotImplemented(
-                                "Invalid arguments for Not 
expression".to_string(),
-                            )),
-                        }
+                        (l, r) => Err(DataFusionError::NotImplemented(format!(
+                            "Invalid arguments for binary expression: {l:?} 
and {r:?}"
+                        ))),
                     }
-                    Err(e) => Err(e),
                 }
-            }
-            // ScalarFunction
-            _ => {
-                let fun = match extensions.get(&f.function_reference) {
-                    Some(fname) => BuiltinScalarFunction::from_str(fname),
-                    None => Err(DataFusionError::NotImplemented(format!(
-                        "Aggregated function not found: function reference = 
{:?}",
-                        f.function_reference
-                    ))),
-                };
-
-                let mut args: Vec<Expr> = vec![];
-                for arg in f.arguments.iter() {
+                ScalarFunctionType::Not => {
+                    let arg = f.arguments.first().ok_or_else(|| {
+                        DataFusionError::Substrait(
+                            "expect one argument for `NOT` expr".to_string(),
+                        )
+                    })?;
                     match &arg.arg_type {
                         Some(ArgType::Value(e)) => {
-                            args.push(
-                                from_substrait_rex(e, input_schema, extensions)
-                                    .await?
-                                    .as_ref()
-                                    .clone(),
-                            );
-                        }
-                        e => {
-                            return Err(DataFusionError::NotImplemented(format!(
-                                "Invalid arguments for scalar function: {e:?}"
-                            )))
+                            let expr = from_substrait_rex(e, input_schema, 
extensions)
+                                .await?
+                                .as_ref()
+                                .clone();
+                            Ok(Arc::new(Expr::Not(Box::new(expr))))
                         }
+                        _ => Err(DataFusionError::NotImplemented(
+                            "Invalid arguments for Not expression".to_string(),
+                        )),
                     }
                 }
-
-                Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
-                    fun: fun?,
-                    args,
-                })))
+                ScalarFunctionType::Like => {
+                    make_datafusion_like(false, f, input_schema, 
extensions).await
+                }
+                ScalarFunctionType::ILike => {
+                    make_datafusion_like(true, f, input_schema, 
extensions).await
+                }
             }
-        },
+        }
         Some(RexType::Literal(lit)) => {
             let scalar_value = from_substrait_literal(lit)?;
             Ok(Arc::new(Expr::Literal(scalar_value)))
@@ -1342,3 +1285,67 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
         ))
     }
 }
+
+async fn make_datafusion_like(
+    case_insensitive: bool,
+    f: &ScalarFunction,
+    input_schema: &DFSchema,
+    extensions: &HashMap<u32, &String>,
+) -> Result<Arc<Expr>> {
+    let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
+    if f.arguments.len() != 3 {
+        return Err(DataFusionError::NotImplemented(format!(
+            "Expect three arguments for `{fn_name}` expr"
+        )));
+    }
+
+    let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
+        return Err(DataFusionError::NotImplemented(
+            format!("Invalid arguments type for `{fn_name}` expr")
+        ))
+    };
+    let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
+        .await?
+        .as_ref()
+        .clone();
+    let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type 
else {
+        return Err(DataFusionError::NotImplemented(
+            format!("Invalid arguments type for `{fn_name}` expr")
+        ))
+    };
+    let pattern = from_substrait_rex(pattern_substrait, input_schema, 
extensions)
+        .await?
+        .as_ref()
+        .clone();
+    let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type 
else {
+        return Err(DataFusionError::NotImplemented(
+            format!("Invalid arguments type for `{fn_name}` expr")
+        ))
+    };
+    let escape_char_expr =
+        from_substrait_rex(escape_char_substrait, input_schema, extensions)
+            .await?
+            .as_ref()
+            .clone();
+    let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else {
+        return Err(DataFusionError::Substrait(format!(
+            "Expect Utf8 literal for escape char, but found 
{escape_char_expr:?}",
+        )))
+    };
+
+    if case_insensitive {
+        Ok(Arc::new(Expr::ILike(Like {
+            negated: false,
+            expr: Box::new(expr),
+            pattern: Box::new(pattern),
+            escape_char: escape_char.map(|c| c.chars().next().unwrap()),
+        })))
+    } else {
+        Ok(Arc::new(Expr::Like(Like {
+            negated: false,
+            expr: Box::new(expr),
+            pattern: Box::new(pattern),
+            escape_char: escape_char.map(|c| c.chars().next().unwrap()),
+        })))
+    }
+}
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 5e7ee267c4..ece1651683 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -19,6 +19,7 @@ use std::collections::HashMap;
 use std::ops::Deref;
 use std::sync::Arc;
 
+use datafusion::logical_expr::Like;
 use datafusion::{
     arrow::datatypes::{DataType, TimeUnit},
     error::{DataFusionError, Result},
@@ -913,6 +914,36 @@ pub fn to_substrait_rex(
                 bounds,
             ))
         }
+        Expr::Like(Like {
+            negated,
+            expr,
+            pattern,
+            escape_char,
+        }) => make_substrait_like_expr(
+            false,
+            *negated,
+            expr,
+            pattern,
+            *escape_char,
+            schema,
+            col_ref_offset,
+            extension_info,
+        ),
+        Expr::ILike(Like {
+            negated,
+            expr,
+            pattern,
+            escape_char,
+        }) => make_substrait_like_expr(
+            true,
+            *negated,
+            expr,
+            pattern,
+            *escape_char,
+            schema,
+            col_ref_offset,
+            extension_info,
+        ),
         _ => Err(DataFusionError::NotImplemented(format!(
             "Unsupported expression: {expr:?}"
         ))),
@@ -1130,6 +1161,71 @@ fn make_substrait_window_function(
     }
 }
 
+#[allow(deprecated)]
+#[allow(clippy::too_many_arguments)]
+fn make_substrait_like_expr(
+    ignore_case: bool,
+    negated: bool,
+    expr: &Expr,
+    pattern: &Expr,
+    escape_char: Option<char>,
+    schema: &DFSchemaRef,
+    col_ref_offset: usize,
+    extension_info: &mut (
+        Vec<extensions::SimpleExtensionDeclaration>,
+        HashMap<String, u32>,
+    ),
+) -> Result<Expression> {
+    let function_anchor = if ignore_case {
+        _register_function("ilike".to_string(), extension_info)
+    } else {
+        _register_function("like".to_string(), extension_info)
+    };
+    let expr = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
+    let pattern = to_substrait_rex(pattern, schema, col_ref_offset, 
extension_info)?;
+    let escape_char =
+        to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| 
c.to_string())))?;
+    let arguments = vec![
+        FunctionArgument {
+            arg_type: Some(ArgType::Value(expr)),
+        },
+        FunctionArgument {
+            arg_type: Some(ArgType::Value(pattern)),
+        },
+        FunctionArgument {
+            arg_type: Some(ArgType::Value(escape_char)),
+        },
+    ];
+
+    let substrait_like = Expression {
+        rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+            function_reference: function_anchor,
+            arguments,
+            output_type: None,
+            args: vec![],
+            options: vec![],
+        })),
+    };
+
+    if negated {
+        let function_anchor = _register_function("not".to_string(), 
extension_info);
+
+        Ok(Expression {
+            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+                function_reference: function_anchor,
+                arguments: vec![FunctionArgument {
+                    arg_type: Some(ArgType::Value(substrait_like)),
+                }],
+                output_type: None,
+                args: vec![],
+                options: vec![],
+            })),
+        })
+    } else {
+        Ok(substrait_like)
+    }
+}
+
 fn to_substrait_bound(bound: &WindowFrameBound) -> Bound {
     match bound {
         WindowFrameBound::CurrentRow => Bound {
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 1d1efb2e8d..b4a3b2cf32 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -410,6 +410,16 @@ async fn roundtrip_outer_join() -> Result<()> {
     roundtrip("SELECT data.a FROM data FULL OUTER JOIN data2 ON data.a = 
data2.a").await
 }
 
+#[tokio::test]
+async fn roundtrip_like() -> Result<()> {
+    roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await
+}
+
+#[tokio::test]
+async fn roundtrip_ilike() -> Result<()> {
+    roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await
+}
+
 #[tokio::test]
 async fn simple_intersect() -> Result<()> {
     assert_expected_plan(

Reply via email to