This is an automated email from the ASF dual-hosted git repository.

yangjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4535551120 feat:implement postgres style 'overlay' string function 
(#8117)
4535551120 is described below

commit 4535551120dd4c31c160a35851b9e4a33514a44f
Author: Syleechan <[email protected]>
AuthorDate: Tue Nov 14 15:31:25 2023 +0800

    feat:implement postgres style 'overlay' string function (#8117)
    
    * feat:implement postgres style 'overlay' string function
    
    * code format
    
    * code format
    
    * code format
    
    * code format
    
    * add sql slt test
    
    * fix modify other case issue
    
    * add test expr
    
    * add annotation
    
    * add overlay function sql reference doc
    
    * add sql case and format doc
---
 datafusion/expr/src/built_in_function.rs           |  18 +++-
 datafusion/expr/src/expr_fn.rs                     |   7 ++
 datafusion/physical-expr/src/functions.rs          |  11 +++
 datafusion/physical-expr/src/string_expressions.rs | 108 +++++++++++++++++++++
 datafusion/proto/proto/datafusion.proto            |   1 +
 datafusion/proto/src/generated/pbjson.rs           |   3 +
 datafusion/proto/src/generated/prost.rs            |   3 +
 datafusion/proto/src/logical_plan/from_proto.rs    |  15 ++-
 datafusion/proto/src/logical_plan/to_proto.rs      |   1 +
 datafusion/sql/src/expr/mod.rs                     |  40 +++++++-
 datafusion/sqllogictest/test_files/functions.slt   |  42 ++++++++
 docs/source/user-guide/sql/scalar_functions.md     |  17 ++++
 12 files changed, 260 insertions(+), 6 deletions(-)

diff --git a/datafusion/expr/src/built_in_function.rs 
b/datafusion/expr/src/built_in_function.rs
index 0d2c1f2e3c..77c64128e1 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -292,6 +292,8 @@ pub enum BuiltinScalarFunction {
     RegexpMatch,
     /// arrow_typeof
     ArrowTypeof,
+    /// overlay
+    OverLay,
 }
 
 /// Maps the sql function name to `BuiltinScalarFunction`
@@ -455,6 +457,7 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::Struct => Volatility::Immutable,
             BuiltinScalarFunction::FromUnixtime => Volatility::Immutable,
             BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
+            BuiltinScalarFunction::OverLay => Volatility::Immutable,
 
             // Stable builtin functions
             BuiltinScalarFunction::Now => Volatility::Stable,
@@ -812,6 +815,10 @@ impl BuiltinScalarFunction {
 
             BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()),
 
+            BuiltinScalarFunction::OverLay => {
+                utf8_to_str_type(&input_expr_types[0], "overlay")
+            }
+
             BuiltinScalarFunction::Acos
             | BuiltinScalarFunction::Asin
             | BuiltinScalarFunction::Atan
@@ -1258,7 +1265,15 @@ impl BuiltinScalarFunction {
             }
             BuiltinScalarFunction::ArrowTypeof => Signature::any(1, 
self.volatility()),
             BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()),
-
+            BuiltinScalarFunction::OverLay => Signature::one_of(
+                vec![
+                    Exact(vec![Utf8, Utf8, Int64, Int64]),
+                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
+                    Exact(vec![Utf8, Utf8, Int64]),
+                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
+                ],
+                self.volatility(),
+            ),
             BuiltinScalarFunction::Acos
             | BuiltinScalarFunction::Asin
             | BuiltinScalarFunction::Atan
@@ -1517,6 +1532,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static 
[&'static str] {
         BuiltinScalarFunction::Cardinality => &["cardinality"],
         BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
         BuiltinScalarFunction::ArrayIntersect => &["array_intersect", 
"list_intersect"],
+        BuiltinScalarFunction::OverLay => &["overlay"],
 
         // struct functions
         BuiltinScalarFunction::Struct => &["struct"],
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 0d920beb41..91674cc092 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -838,6 +838,11 @@ nary_scalar_expr!(
     "concatenates several strings, placing a seperator between each one"
 );
 nary_scalar_expr!(Concat, concat_expr, "concatenates several strings");
+nary_scalar_expr!(
+    OverLay,
+    overlay,
+    "replace the substring of string that starts at the start'th character and 
extends for count characters with new substring"
+);
 
 // date functions
 scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the 
date");
@@ -1174,6 +1179,8 @@ mod test {
         test_nary_scalar_expr!(MakeArray, array, input);
 
         test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
+        test_nary_scalar_expr!(OverLay, overlay, string, characters, position, 
len);
+        test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
     }
 
     #[test]
diff --git a/datafusion/physical-expr/src/functions.rs 
b/datafusion/physical-expr/src/functions.rs
index 80c0eaf054..7f8921e86c 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -829,6 +829,17 @@ pub fn create_physical_fun(
                 "{input_data_type}"
             )))))
         }),
+        BuiltinScalarFunction::OverLay => Arc::new(|args| match 
args[0].data_type() {
+            DataType::Utf8 => {
+                make_scalar_function(string_expressions::overlay::<i32>)(args)
+            }
+            DataType::LargeUtf8 => {
+                make_scalar_function(string_expressions::overlay::<i64>)(args)
+            }
+            other => Err(DataFusionError::Internal(format!(
+                "Unsupported data type {other:?} for function overlay",
+            ))),
+        }),
     })
 }
 
diff --git a/datafusion/physical-expr/src/string_expressions.rs 
b/datafusion/physical-expr/src/string_expressions.rs
index e6a3d5c331..7e954fdcfd 100644
--- a/datafusion/physical-expr/src/string_expressions.rs
+++ b/datafusion/physical-expr/src/string_expressions.rs
@@ -553,11 +553,102 @@ pub fn uuid(args: &[ColumnarValue]) -> 
Result<ColumnarValue> {
     Ok(ColumnarValue::Array(Arc::new(array)))
 }
 
+/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2)
+/// Replaces a substring of string1 with string2, starting at 1-based position `integer`.
+/// PostgreSQL example: overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas
+/// overlay('Txxxxas' placing 'hom' from 2) → Thomxas; without FOR, string2's length is used.
+pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args.len() {
+        3 => {
+            let string_array = as_generic_string_array::<T>(&args[0])?;
+            let characters_array = as_generic_string_array::<T>(&args[1])?;
+            let pos_num = as_int64_array(&args[2])?;
+
+            let result = string_array
+                .iter()
+                .zip(characters_array.iter())
+                .zip(pos_num.iter())
+                .map(|((string, characters), start_pos)| {
+                    match (string, characters, start_pos) {
+                        (Some(string), Some(characters), Some(start_pos)) => {
+                            let string_len = string.chars().count();
+                            let characters_len = characters.chars().count();
+                            let replace_len = characters_len as i64;
+                            let mut res =
+                                
String::with_capacity(string_len.max(characters_len));
+
+                            //as sql replace index start from 1 while string 
index start from 0
+                            if start_pos > 1 && start_pos - 1 < string_len as 
i64 {
+                                let start = (start_pos - 1) as usize;
+                                res.push_str(&string[..start]);
+                            }
+                            res.push_str(characters);
+                            // if start + replace_len - 1 >= string_length, 
just to string end
+                            if start_pos + replace_len - 1 < string_len as i64 
{
+                                let end = (start_pos + replace_len - 1) as 
usize;
+                                res.push_str(&string[end..]);
+                            }
+                            Ok(Some(res))
+                        }
+                        _ => Ok(None),
+                    }
+                })
+                .collect::<Result<GenericStringArray<T>>>()?;
+            Ok(Arc::new(result) as ArrayRef)
+        }
+        4 => {
+            let string_array = as_generic_string_array::<T>(&args[0])?;
+            let characters_array = as_generic_string_array::<T>(&args[1])?;
+            let pos_num = as_int64_array(&args[2])?;
+            let len_num = as_int64_array(&args[3])?;
+
+            let result = string_array
+                .iter()
+                .zip(characters_array.iter())
+                .zip(pos_num.iter())
+                .zip(len_num.iter())
+                .map(|(((string, characters), start_pos), len)| {
+                    match (string, characters, start_pos, len) {
+                        (Some(string), Some(characters), Some(start_pos), 
Some(len)) => {
+                            let string_len = string.chars().count();
+                            let characters_len = characters.chars().count();
+                            let replace_len = len.min(string_len as i64);
+                            let mut res =
+                                
String::with_capacity(string_len.max(characters_len));
+
+                            //as sql replace index start from 1 while string 
index start from 0
+                            if start_pos > 1 && start_pos - 1 < string_len as 
i64 {
+                                let start = (start_pos - 1) as usize;
+                                res.push_str(&string[..start]);
+                            }
+                            res.push_str(characters);
+                            // if start + replace_len - 1 >= string_length, 
just to string end
+                            if start_pos + replace_len - 1 < string_len as i64 
{
+                                let end = (start_pos + replace_len - 1) as 
usize;
+                                res.push_str(&string[end..]);
+                            }
+                            Ok(Some(res))
+                        }
+                        _ => Ok(None),
+                    }
+                })
+                .collect::<Result<GenericStringArray<T>>>()?;
+            Ok(Arc::new(result) as ArrayRef)
+        }
+        other => {
+            internal_err!(
+                "overlay was called with {other} arguments. It requires 3 or 
4."
+            )
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
 
     use crate::string_expressions;
     use arrow::{array::Int32Array, datatypes::Int32Type};
+    use arrow_array::Int64Array;
 
     use super::*;
 
@@ -599,4 +690,21 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn to_overlay() -> Result<()> {
+        let string =
+            Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", 
"Txxxxas"]));
+        let replace_string =
+            Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", 
"hom"]));
+        let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start
+        let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len
+
+        let res = overlay::<i32>(&[string, replace_string, start, 
end]).unwrap();
+        let result = as_generic_string_array::<i32>(&res).unwrap();
+        let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", 
"Thomas"]);
+        assert_eq!(&expected, result);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index 5d7c570bc1..d85678a76b 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -636,6 +636,7 @@ enum ScalarFunction {
   ToTimestampNanos = 118;
   ArrayIntersect = 119;
   ArrayUnion = 120;
+  OverLay = 121;
 }
 
 message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index 12fa73205d..64db9137d6 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20935,6 +20935,7 @@ impl serde::Serialize for ScalarFunction {
             Self::ToTimestampNanos => "ToTimestampNanos",
             Self::ArrayIntersect => "ArrayIntersect",
             Self::ArrayUnion => "ArrayUnion",
+            Self::OverLay => "OverLay",
         };
         serializer.serialize_str(variant)
     }
@@ -21067,6 +21068,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
             "ToTimestampNanos",
             "ArrayIntersect",
             "ArrayUnion",
+            "OverLay",
         ];
 
         struct GeneratedVisitor;
@@ -21228,6 +21230,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
                     "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos),
                     "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect),
                     "ArrayUnion" => Ok(ScalarFunction::ArrayUnion),
+                    "OverLay" => Ok(ScalarFunction::OverLay),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index 23be5d9088..131ca11993 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2567,6 +2567,7 @@ pub enum ScalarFunction {
     ToTimestampNanos = 118,
     ArrayIntersect = 119,
     ArrayUnion = 120,
+    OverLay = 121,
 }
 impl ScalarFunction {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -2696,6 +2697,7 @@ impl ScalarFunction {
             ScalarFunction::ToTimestampNanos => "ToTimestampNanos",
             ScalarFunction::ArrayIntersect => "ArrayIntersect",
             ScalarFunction::ArrayUnion => "ArrayUnion",
+            ScalarFunction::OverLay => "OverLay",
         }
     }
     /// Creates an enum from field names used in the ProtoBuf definition.
@@ -2822,6 +2824,7 @@ impl ScalarFunction {
             "ToTimestampNanos" => Some(Self::ToTimestampNanos),
             "ArrayIntersect" => Some(Self::ArrayIntersect),
             "ArrayUnion" => Some(Self::ArrayUnion),
+            "OverLay" => Some(Self::OverLay),
             _ => None,
         }
     }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs
index 0ecbe05e79..9ca7bb0e89 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -52,10 +52,10 @@ use datafusion_expr::{
     factorial, flatten, floor, from_unixtime, gcd, isnan, iszero, lcm, left, 
ln, log,
     log10, log2,
     logical_plan::{PlanType, StringifiedPlan},
-    lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, 
radians,
-    random, regexp_match, regexp_replace, repeat, replace, reverse, right, 
round, rpad,
-    rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt,
-    starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, 
tanh,
+    lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, 
power,
+    radians, random, regexp_match, regexp_replace, repeat, replace, reverse, 
right,
+    round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, 
split_part,
+    sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, 
tan, tanh,
     to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos,
     to_timestamp_seconds, translate, trim, trunc, upper, uuid,
     window_frame::regularize,
@@ -546,6 +546,7 @@ impl From<&protobuf::ScalarFunction> for 
BuiltinScalarFunction {
             ScalarFunction::Isnan => Self::Isnan,
             ScalarFunction::Iszero => Self::Iszero,
             ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
+            ScalarFunction::OverLay => Self::OverLay,
         }
     }
 }
@@ -1680,6 +1681,12 @@ pub fn parse_expr(
                     parse_expr(&args[1], registry)?,
                     parse_expr(&args[2], registry)?,
                 )),
+                ScalarFunction::OverLay => Ok(overlay(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, registry))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
                 ScalarFunction::StructFun => {
                     Ok(struct_fun(parse_expr(&args[0], registry)?))
                 }
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs 
b/datafusion/proto/src/logical_plan/to_proto.rs
index 4c81ab954a..974d6c5aab 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1553,6 +1553,7 @@ impl TryFrom<&BuiltinScalarFunction> for 
protobuf::ScalarFunction {
             BuiltinScalarFunction::Isnan => Self::Isnan,
             BuiltinScalarFunction::Iszero => Self::Iszero,
             BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
+            BuiltinScalarFunction::OverLay => Self::OverLay,
         };
 
         Ok(scalar_function)
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 1cf0fc133f..7fa16ced39 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -459,7 +459,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 schema,
                 planner_context,
             ),
-
+            SQLExpr::Overlay {
+                expr,
+                overlay_what,
+                overlay_from,
+                overlay_for,
+            } => self.sql_overlay_to_expr(
+                *expr,
+                *overlay_what,
+                *overlay_from,
+                overlay_for,
+                schema,
+                planner_context,
+            ),
             SQLExpr::Nested(e) => {
                 self.sql_expr_to_logical_expr(*e, schema, planner_context)
             }
@@ -645,6 +657,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)))
     }
 
+    fn sql_overlay_to_expr(
+        &self,
+        expr: SQLExpr,
+        overlay_what: SQLExpr,
+        overlay_from: SQLExpr,
+        overlay_for: Option<Box<SQLExpr>>,
+        schema: &DFSchema,
+        planner_context: &mut PlannerContext,
+    ) -> Result<Expr> {
+        let fun = BuiltinScalarFunction::OverLay;
+        let arg = self.sql_expr_to_logical_expr(expr, schema, 
planner_context)?;
+        let what_arg =
+            self.sql_expr_to_logical_expr(overlay_what, schema, 
planner_context)?;
+        let from_arg =
+            self.sql_expr_to_logical_expr(overlay_from, schema, 
planner_context)?;
+        let args = match overlay_for {
+            Some(for_expr) => {
+                let for_expr =
+                    self.sql_expr_to_logical_expr(*for_expr, schema, 
planner_context)?;
+                vec![arg, what_arg, from_arg, for_expr]
+            }
+            None => vec![arg, what_arg, from_arg],
+        };
+        Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)))
+    }
+
     fn sql_agg_with_filter_to_expr(
         &self,
         expr: SQLExpr,
diff --git a/datafusion/sqllogictest/test_files/functions.slt 
b/datafusion/sqllogictest/test_files/functions.slt
index 2054752cc5..8f42304384 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -815,3 +815,45 @@ SELECT products.* REPLACE (price*2 AS price, 
product_id+1000 AS product_id) FROM
 1002 OldBrand Product 2 59.98
 1003 OldBrand Product 3 79.98
 1004 OldBrand Product 4 99.98
+
+#overlay tests
+statement ok
+CREATE TABLE over_test(
+  str TEXT,
+  characters TEXT,
+  pos INT,
+  len INT
+) as VALUES
+  ('123', 'abc', 4, 5),
+  ('abcdefg', 'qwertyasdfg', 1, 7),
+  ('xyz', 'ijk', 1, 2),
+  ('Txxxxas', 'hom', 2, 4),
+  (NULL, 'hom', 2, 4),
+  ('Txxxxas', 'hom', NULL, 4),
+  ('Txxxxas', 'hom', 2, NULL),
+  ('Txxxxas', NULL, 2, 4)
+;
+
+query T
+SELECT overlay(str placing characters from pos for len) from over_test
+----
+abc
+qwertyasdfg
+ijkz
+Thomas
+NULL
+NULL
+NULL
+NULL
+
+query T
+SELECT overlay(str placing characters from pos) from over_test
+----
+abc
+qwertyasdfg
+ijk
+Thomxas
+NULL
+NULL
+Thomxas
+NULL
diff --git a/docs/source/user-guide/sql/scalar_functions.md 
b/docs/source/user-guide/sql/scalar_functions.md
index 2959e82024..099c903122 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -635,6 +635,7 @@ nullif(expression1, expression2)
 - [trim](#trim)
 - [upper](#upper)
 - [uuid](#uuid)
+- [overlay](#overlay)
 
 ### `ascii`
 
@@ -1120,6 +1121,22 @@ Returns UUID v4 string value which is unique per row.
 uuid()
 ```
 
+### `overlay`
+
+Returns the string with a substring replaced by another string, starting at the specified position and spanning the specified number of characters.
+For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas`
+
+```
+overlay(str PLACING substr FROM pos [FOR count])
+```
+
+#### Arguments
+
+- **str**: String expression to operate on.
+- **substr**: The string that replaces part of str.
+- **pos**: The 1-based position in str at which the replacement starts.
+- **count**: The number of characters of str to replace, starting at pos. If not specified, the length of substr is used.
+
 ## Binary String Functions
 
 - [decode](#decode)

Reply via email to