This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 234217e439 feat:implement sql style 'substr_index' string function 
(#8272)
234217e439 is described below

commit 234217e439ccf598d704fd7645560f04f25e8a6f
Author: Syleechan <[email protected]>
AuthorDate: Sun Nov 26 20:00:20 2023 +0800

    feat:implement sql style 'substr_index' string function (#8272)
    
    * feat:implement sql style 'substr_index' string function
    
    * code format
    
    * code format
    
    * code format
    
    * fix index bound issue
    
    * code format
    
    * code format
    
    * add args len check
    
    * add sql tests
    
    * code format
    
    * doc format
---
 datafusion/expr/src/built_in_function.rs           | 15 +++++
 datafusion/expr/src/expr_fn.rs                     |  2 +
 datafusion/physical-expr/src/functions.rs          | 23 +++++++
 .../physical-expr/src/unicode_expressions.rs       | 65 +++++++++++++++++++
 datafusion/proto/proto/datafusion.proto            |  1 +
 datafusion/proto/src/generated/pbjson.rs           |  3 +
 datafusion/proto/src/generated/prost.rs            |  3 +
 datafusion/proto/src/logical_plan/from_proto.rs    | 12 +++-
 datafusion/proto/src/logical_plan/to_proto.rs      |  1 +
 datafusion/sqllogictest/test_files/functions.slt   | 75 ++++++++++++++++++++++
 docs/source/user-guide/sql/scalar_functions.md     | 18 ++++++
 11 files changed, 215 insertions(+), 3 deletions(-)

diff --git a/datafusion/expr/src/built_in_function.rs 
b/datafusion/expr/src/built_in_function.rs
index cbf5d400ba..d920675016 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -302,6 +302,8 @@ pub enum BuiltinScalarFunction {
     OverLay,
     /// levenshtein
     Levenshtein,
+    /// substr_index
+    SubstrIndex,
 }
 
 /// Maps the sql function name to `BuiltinScalarFunction`
@@ -470,6 +472,7 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
             BuiltinScalarFunction::OverLay => Volatility::Immutable,
             BuiltinScalarFunction::Levenshtein => Volatility::Immutable,
+            BuiltinScalarFunction::SubstrIndex => Volatility::Immutable,
 
             // Stable builtin functions
             BuiltinScalarFunction::Now => Volatility::Stable,
@@ -773,6 +776,9 @@ impl BuiltinScalarFunction {
                     return plan_err!("The to_hex function can only accept 
integers.");
                 }
             }),
+            BuiltinScalarFunction::SubstrIndex => {
+                utf8_to_str_type(&input_expr_types[0], "substr_index")
+            }
             BuiltinScalarFunction::ToTimestamp => Ok(match 
&input_expr_types[0] {
                 Int64 => Timestamp(Second, None),
                 _ => Timestamp(Nanosecond, None),
@@ -1235,6 +1241,14 @@ impl BuiltinScalarFunction {
                 self.volatility(),
             ),
 
+            BuiltinScalarFunction::SubstrIndex => Signature::one_of(
+                vec![
+                    Exact(vec![Utf8, Utf8, Int64]),
+                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
+                ],
+                self.volatility(),
+            ),
+
             BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate 
=> {
                 Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], 
self.volatility())
             }
@@ -1486,6 +1500,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static 
[&'static str] {
         BuiltinScalarFunction::Upper => &["upper"],
         BuiltinScalarFunction::Uuid => &["uuid"],
         BuiltinScalarFunction::Levenshtein => &["levenshtein"],
+        BuiltinScalarFunction::SubstrIndex => &["substr_index", 
"substring_index"],
 
         // regex functions
         BuiltinScalarFunction::RegexpMatch => &["regexp_match"],
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 4da6857594..d2c5e5cddb 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -916,6 +916,7 @@ scalar_expr!(
 
 scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type");
 scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the 
Levenshtein distance between the two given strings");
+scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the 
substring from str before count occurrences of the delimiter");
 
 scalar_expr!(
     Struct,
@@ -1205,6 +1206,7 @@ mod test {
         test_nary_scalar_expr!(OverLay, overlay, string, characters, position, 
len);
         test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
         test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
+        test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count);
     }
 
     #[test]
diff --git a/datafusion/physical-expr/src/functions.rs 
b/datafusion/physical-expr/src/functions.rs
index 5a1a68dd21..40b21347ed 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -862,6 +862,29 @@ pub fn create_physical_fun(
                 ))),
             })
         }
+        BuiltinScalarFunction::SubstrIndex => {
+            Arc::new(|args| match args[0].data_type() {
+                DataType::Utf8 => {
+                    let func = invoke_if_unicode_expressions_feature_flag!(
+                        substr_index,
+                        i32,
+                        "substr_index"
+                    );
+                    make_scalar_function(func)(args)
+                }
+                DataType::LargeUtf8 => {
+                    let func = invoke_if_unicode_expressions_feature_flag!(
+                        substr_index,
+                        i64,
+                        "substr_index"
+                    );
+                    make_scalar_function(func)(args)
+                }
+                other => Err(DataFusionError::Internal(format!(
+                    "Unsupported data type {other:?} for function 
substr_index",
+                ))),
+            })
+        }
     })
 }
 
diff --git a/datafusion/physical-expr/src/unicode_expressions.rs 
b/datafusion/physical-expr/src/unicode_expressions.rs
index e28700a25c..f27b3c1577 100644
--- a/datafusion/physical-expr/src/unicode_expressions.rs
+++ b/datafusion/physical-expr/src/unicode_expressions.rs
@@ -455,3 +455,68 @@ pub fn translate<T: OffsetSizeTrait>(args: &[ArrayRef]) -> 
Result<ArrayRef> {
 
     Ok(Arc::new(result) as ArrayRef)
 }
+
+/// Returns the part of `str` that lies before (positive `count`) or after
+/// (negative `count`) the `count`-th occurrence of the delimiter `delim`.
+///
+/// * `count > 0`: occurrences are counted from the left and everything to the
+///   left of the final counted delimiter is returned.
+/// * `count < 0`: occurrences are counted from the right and everything to the
+///   right of the final counted delimiter is returned.
+/// * If the delimiter occurs fewer than `|count|` times, the whole string is
+///   returned.
+/// * If `count` is 0, or the string or the delimiter is empty, an empty string
+///   is returned.
+/// * If any of the three arguments is NULL, the result is NULL.
+///
+/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www
+/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache
+/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org
+/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org
+/// SUBSTRING_INDEX('a.bcd', '.', -1) = bcd
+///
+/// # Errors
+///
+/// Returns an internal error when `args` does not hold exactly three arrays,
+/// and propagates the downcast error when the arrays are not
+/// (string, string, int64).
+pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 3 {
+        return internal_err!(
+            "substr_index was called with {} arguments. It requires 3.",
+            args.len()
+        );
+    }
+
+    let string_array = as_generic_string_array::<T>(&args[0])?;
+    let delimiter_array = as_generic_string_array::<T>(&args[1])?;
+    let count_array = as_int64_array(&args[2])?;
+
+    let result = string_array
+        .iter()
+        .zip(delimiter_array.iter())
+        .zip(count_array.iter())
+        .map(|((string, delimiter), n)| match (string, delimiter, n) {
+            (Some(string), Some(delimiter), Some(n)) => {
+                if n == 0 || string.is_empty() || delimiter.is_empty() {
+                    // Nothing can be counted, so there is nothing to return.
+                    Some("")
+                } else {
+                    // Clamp rather than cast: `-n` overflows for i64::MIN and a
+                    // plain `as usize` could wrap to 0 on 32-bit targets.
+                    let occurrences =
+                        n.unsigned_abs().min(usize::MAX as u64) as usize;
+
+                    // Walk the pieces starting from the side `n` counts from.
+                    // Each piece is charged one trailing delimiter; `split` and
+                    // `rsplit` always yield at least one piece, so the sum is
+                    // never smaller than `delimiter.len()`.
+                    let covered: usize = if n > 0 {
+                        string
+                            .split(delimiter)
+                            .take(occurrences)
+                            .map(|piece| piece.len() + delimiter.len())
+                            .sum()
+                    } else {
+                        string
+                            .rsplit(delimiter)
+                            .take(occurrences)
+                            .map(|piece| piece.len() + delimiter.len())
+                            .sum()
+                    };
+                    // The last counted piece is not followed by a delimiter
+                    // inside the result. When every piece was taken this
+                    // equals `string.len()`, i.e. the whole string is kept.
+                    let kept = covered - delimiter.len();
+
+                    // `kept` always ends on a delimiter or string boundary, so
+                    // both slices fall on valid UTF-8 character boundaries.
+                    Some(if n > 0 {
+                        &string[..kept]
+                    } else {
+                        &string[string.len() - kept..]
+                    })
+                }
+            }
+            _ => None,
+        })
+        .collect::<GenericStringArray<T>>();
+
+    Ok(Arc::new(result) as ArrayRef)
+}
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index d43d19f858..5c33b10f13 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -641,6 +641,7 @@ enum ScalarFunction {
   ArrayExcept = 123;
   ArrayPopFront = 124;
   Levenshtein = 125;
+  SubstrIndex = 126;
 }
 
 message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index 133bbbee89..598719dc8a 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20863,6 +20863,7 @@ impl serde::Serialize for ScalarFunction {
             Self::ArrayExcept => "ArrayExcept",
             Self::ArrayPopFront => "ArrayPopFront",
             Self::Levenshtein => "Levenshtein",
+            Self::SubstrIndex => "SubstrIndex",
         };
         serializer.serialize_str(variant)
     }
@@ -21000,6 +21001,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
             "ArrayExcept",
             "ArrayPopFront",
             "Levenshtein",
+            "SubstrIndex",
         ];
 
         struct GeneratedVisitor;
@@ -21166,6 +21168,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
                     "ArrayExcept" => Ok(ScalarFunction::ArrayExcept),
                     "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront),
                     "Levenshtein" => Ok(ScalarFunction::Levenshtein),
+                    "SubstrIndex" => Ok(ScalarFunction::SubstrIndex),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index 503c4b6c73..e79a17fc5c 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2594,6 +2594,7 @@ pub enum ScalarFunction {
     ArrayExcept = 123,
     ArrayPopFront = 124,
     Levenshtein = 125,
+    SubstrIndex = 126,
 }
 impl ScalarFunction {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -2728,6 +2729,7 @@ impl ScalarFunction {
             ScalarFunction::ArrayExcept => "ArrayExcept",
             ScalarFunction::ArrayPopFront => "ArrayPopFront",
             ScalarFunction::Levenshtein => "Levenshtein",
+            ScalarFunction::SubstrIndex => "SubstrIndex",
         }
     }
     /// Creates an enum from field names used in the ProtoBuf definition.
@@ -2859,6 +2861,7 @@ impl ScalarFunction {
             "ArrayExcept" => Some(Self::ArrayExcept),
             "ArrayPopFront" => Some(Self::ArrayPopFront),
             "Levenshtein" => Some(Self::Levenshtein),
+            "SubstrIndex" => Some(Self::SubstrIndex),
             _ => None,
         }
     }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs
index d4a64287b0..b2455d5a0d 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -55,9 +55,9 @@ use datafusion_expr::{
     lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, 
power,
     radians, random, regexp_match, regexp_replace, repeat, replace, reverse, 
right,
     round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, 
split_part,
-    sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, 
tan, tanh,
-    to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos,
-    to_timestamp_seconds, translate, trim, trunc, upper, uuid,
+    sqrt, starts_with, string_to_array, strpos, struct_fun, substr, 
substr_index,
+    substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis,
+    to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, 
uuid,
     window_frame::regularize,
     AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, 
BuiltinScalarFunction,
     Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet,
@@ -551,6 +551,7 @@ impl From<&protobuf::ScalarFunction> for 
BuiltinScalarFunction {
             ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
             ScalarFunction::OverLay => Self::OverLay,
             ScalarFunction::Levenshtein => Self::Levenshtein,
+            ScalarFunction::SubstrIndex => Self::SubstrIndex,
         }
     }
 }
@@ -1716,6 +1717,11 @@ pub fn parse_expr(
                         .map(|expr| parse_expr(expr, registry))
                         .collect::<Result<Vec<_>, _>>()?,
                 )),
+                ScalarFunction::SubstrIndex => Ok(substr_index(
+                    parse_expr(&args[0], registry)?,
+                    parse_expr(&args[1], registry)?,
+                    parse_expr(&args[2], registry)?,
+                )),
                 ScalarFunction::StructFun => {
                     Ok(struct_fun(parse_expr(&args[0], registry)?))
                 }
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs 
b/datafusion/proto/src/logical_plan/to_proto.rs
index 508cde98ae..9be4a532bb 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1583,6 +1583,7 @@ impl TryFrom<&BuiltinScalarFunction> for 
protobuf::ScalarFunction {
             BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
             BuiltinScalarFunction::OverLay => Self::OverLay,
             BuiltinScalarFunction::Levenshtein => Self::Levenshtein,
+            BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex,
         };
 
         Ok(scalar_function)
diff --git a/datafusion/sqllogictest/test_files/functions.slt 
b/datafusion/sqllogictest/test_files/functions.slt
index 9c8bb2c5f8..91072a49cd 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -877,3 +877,78 @@ query ?
 SELECT levenshtein(NULL, NULL)
 ----
 NULL
+
+query T
+SELECT substr_index('www.apache.org', '.', 1)
+----
+www
+
+query T
+SELECT substr_index('www.apache.org', '.', 2)
+----
+www.apache
+
+query T
+SELECT substr_index('www.apache.org', '.', -1)
+----
+org
+
+query T
+SELECT substr_index('www.apache.org', '.', -2)
+----
+apache.org
+
+query T
+SELECT substr_index('www.apache.org', 'ac', 1)
+----
+www.ap
+
+query T
+SELECT substr_index('www.apache.org', 'ac', -1)
+----
+he.org
+
+query T
+SELECT substr_index('www.apache.org', 'ac', 2)
+----
+www.apache.org
+
+query T
+SELECT substr_index('www.apache.org', 'ac', -2)
+----
+www.apache.org
+
+query ?
+SELECT substr_index(NULL, 'ac', 1)
+----
+NULL
+
+query T
+SELECT substr_index('www.apache.org', NULL, 1)
+----
+NULL
+
+query T
+SELECT substr_index('www.apache.org', 'ac', NULL)
+----
+NULL
+
+query T
+SELECT substr_index('', 'ac', 1)
+----
+(empty)
+
+query T
+SELECT substr_index('www.apache.org', '', 1)
+----
+(empty)
+
+query T
+SELECT substr_index('www.apache.org', 'ac', 0)
+----
+(empty)
+
+query ?
+SELECT substr_index(NULL, NULL, NULL)
+----
+NULL
diff --git a/docs/source/user-guide/sql/scalar_functions.md 
b/docs/source/user-guide/sql/scalar_functions.md
index eda46ef8a7..e7ebbc9f1f 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -637,6 +637,7 @@ nullif(expression1, expression2)
 - [uuid](#uuid)
 - [overlay](#overlay)
 - [levenshtein](#levenshtein)
+- [substr_index](#substr_index)
 
 ### `ascii`
 
@@ -1152,6 +1153,23 @@ levenshtein(str1, str2)
 - **str1**: String expression to compute Levenshtein distance with str2.
 - **str2**: String expression to compute Levenshtein distance with str1.
 
+### `substr_index`
+
+Returns the substring from str before count occurrences of the delimiter delim.
+If count is positive, everything to the left of the final delimiter (counting 
from the left) is returned.
+If count is negative, everything to the right of the final delimiter (counting 
from the right) is returned.
+For example, `substr_index('www.apache.org', '.', 1) = www`, 
`substr_index('www.apache.org', '.', -1) = org`
+
+```
+substr_index(str, delim, count)
+```
+
+#### Arguments
+
+- **str**: String expression to operate on.
+- **delim**: The delimiter string to search for in str.
+- **count**: The number of times to search for the delimiter. Can be either a 
positive or a negative number.
+
 ## Binary String Functions
 
 - [decode](#decode)

Reply via email to