This is an automated email from the ASF dual-hosted git repository.
yangjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 4535551120 feat:implement postgres style 'overlay' string function
(#8117)
4535551120 is described below
commit 4535551120dd4c31c160a35851b9e4a33514a44f
Author: Syleechan <[email protected]>
AuthorDate: Tue Nov 14 15:31:25 2023 +0800
feat:implement postgres style 'overlay' string function (#8117)
* feat:implement postgres style 'overlay' string function
* code format
* code format
* code format
* code format
* add sql slt test
* fix modify other case issue
* add test expr
* add annotation
* add overlay function sql reference doc
* add sql case and format doc
---
datafusion/expr/src/built_in_function.rs | 18 +++-
datafusion/expr/src/expr_fn.rs | 7 ++
datafusion/physical-expr/src/functions.rs | 11 +++
datafusion/physical-expr/src/string_expressions.rs | 108 +++++++++++++++++++++
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 3 +
datafusion/proto/src/generated/prost.rs | 3 +
datafusion/proto/src/logical_plan/from_proto.rs | 15 ++-
datafusion/proto/src/logical_plan/to_proto.rs | 1 +
datafusion/sql/src/expr/mod.rs | 40 +++++++-
datafusion/sqllogictest/test_files/functions.slt | 42 ++++++++
docs/source/user-guide/sql/scalar_functions.md | 17 ++++
12 files changed, 260 insertions(+), 6 deletions(-)
diff --git a/datafusion/expr/src/built_in_function.rs
b/datafusion/expr/src/built_in_function.rs
index 0d2c1f2e3c..77c64128e1 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -292,6 +292,8 @@ pub enum BuiltinScalarFunction {
RegexpMatch,
/// arrow_typeof
ArrowTypeof,
+ /// overlay
+ OverLay,
}
/// Maps the sql function name to `BuiltinScalarFunction`
@@ -455,6 +457,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Struct => Volatility::Immutable,
BuiltinScalarFunction::FromUnixtime => Volatility::Immutable,
BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
+ BuiltinScalarFunction::OverLay => Volatility::Immutable,
// Stable builtin functions
BuiltinScalarFunction::Now => Volatility::Stable,
@@ -812,6 +815,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()),
+ BuiltinScalarFunction::OverLay => {
+ utf8_to_str_type(&input_expr_types[0], "overlay")
+ }
+
BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
@@ -1258,7 +1265,15 @@ impl BuiltinScalarFunction {
}
BuiltinScalarFunction::ArrowTypeof => Signature::any(1,
self.volatility()),
BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()),
-
+ BuiltinScalarFunction::OverLay => Signature::one_of(
+ vec![
+ Exact(vec![Utf8, Utf8, Int64, Int64]),
+ Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
+ Exact(vec![Utf8, Utf8, Int64]),
+ Exact(vec![LargeUtf8, LargeUtf8, Int64]),
+ ],
+ self.volatility(),
+ ),
BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
@@ -1517,6 +1532,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static
[&'static str] {
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::ArrayIntersect => &["array_intersect",
"list_intersect"],
+ BuiltinScalarFunction::OverLay => &["overlay"],
// struct functions
BuiltinScalarFunction::Struct => &["struct"],
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 0d920beb41..91674cc092 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -838,6 +838,11 @@ nary_scalar_expr!(
"concatenates several strings, placing a seperator between each one"
);
nary_scalar_expr!(Concat, concat_expr, "concatenates several strings");
+nary_scalar_expr!(
+ OverLay,
+ overlay,
+ "replace the substring of string that starts at the start'th character and
extends for count characters with new substring"
+);
// date functions
scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the
date");
@@ -1174,6 +1179,8 @@ mod test {
test_nary_scalar_expr!(MakeArray, array, input);
test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
+ test_nary_scalar_expr!(OverLay, overlay, string, characters, position,
len);
+ test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
}
#[test]
diff --git a/datafusion/physical-expr/src/functions.rs
b/datafusion/physical-expr/src/functions.rs
index 80c0eaf054..7f8921e86c 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -829,6 +829,17 @@ pub fn create_physical_fun(
"{input_data_type}"
)))))
}),
+ BuiltinScalarFunction::OverLay => Arc::new(|args| match
args[0].data_type() {
+ DataType::Utf8 => {
+ make_scalar_function(string_expressions::overlay::<i32>)(args)
+ }
+ DataType::LargeUtf8 => {
+ make_scalar_function(string_expressions::overlay::<i64>)(args)
+ }
+ other => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {other:?} for function overlay",
+ ))),
+ }),
})
}
diff --git a/datafusion/physical-expr/src/string_expressions.rs
b/datafusion/physical-expr/src/string_expressions.rs
index e6a3d5c331..7e954fdcfd 100644
--- a/datafusion/physical-expr/src/string_expressions.rs
+++ b/datafusion/physical-expr/src/string_expressions.rs
@@ -553,11 +553,102 @@ pub fn uuid(args: &[ColumnarValue]) ->
Result<ColumnarValue> {
Ok(ColumnarValue::Array(Arc::new(array)))
}
+/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2)
+/// Replaces a substring of string1 with string2 starting at the integer bit
+/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas
+/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option,
str2's len is instead
+pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+ match args.len() {
+ 3 => {
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let characters_array = as_generic_string_array::<T>(&args[1])?;
+ let pos_num = as_int64_array(&args[2])?;
+
+ let result = string_array
+ .iter()
+ .zip(characters_array.iter())
+ .zip(pos_num.iter())
+ .map(|((string, characters), start_pos)| {
+ match (string, characters, start_pos) {
+ (Some(string), Some(characters), Some(start_pos)) => {
+ let string_len = string.chars().count();
+ let characters_len = characters.chars().count();
+ let replace_len = characters_len as i64;
+ let mut res =
+
String::with_capacity(string_len.max(characters_len));
+
+ //as sql replace index start from 1 while string
index start from 0
+ if start_pos > 1 && start_pos - 1 < string_len as
i64 {
+ let start = (start_pos - 1) as usize;
+ res.push_str(&string[..start]);
+ }
+ res.push_str(characters);
+ // if start + replace_len - 1 >= string_length,
just to string end
+ if start_pos + replace_len - 1 < string_len as i64
{
+ let end = (start_pos + replace_len - 1) as
usize;
+ res.push_str(&string[end..]);
+ }
+ Ok(Some(res))
+ }
+ _ => Ok(None),
+ }
+ })
+ .collect::<Result<GenericStringArray<T>>>()?;
+ Ok(Arc::new(result) as ArrayRef)
+ }
+ 4 => {
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let characters_array = as_generic_string_array::<T>(&args[1])?;
+ let pos_num = as_int64_array(&args[2])?;
+ let len_num = as_int64_array(&args[3])?;
+
+ let result = string_array
+ .iter()
+ .zip(characters_array.iter())
+ .zip(pos_num.iter())
+ .zip(len_num.iter())
+ .map(|(((string, characters), start_pos), len)| {
+ match (string, characters, start_pos, len) {
+ (Some(string), Some(characters), Some(start_pos),
Some(len)) => {
+ let string_len = string.chars().count();
+ let characters_len = characters.chars().count();
+ let replace_len = len.min(string_len as i64);
+ let mut res =
+
String::with_capacity(string_len.max(characters_len));
+
+ //as sql replace index start from 1 while string
index start from 0
+ if start_pos > 1 && start_pos - 1 < string_len as
i64 {
+ let start = (start_pos - 1) as usize;
+ res.push_str(&string[..start]);
+ }
+ res.push_str(characters);
+ // if start + replace_len - 1 >= string_length,
just to string end
+ if start_pos + replace_len - 1 < string_len as i64
{
+ let end = (start_pos + replace_len - 1) as
usize;
+ res.push_str(&string[end..]);
+ }
+ Ok(Some(res))
+ }
+ _ => Ok(None),
+ }
+ })
+ .collect::<Result<GenericStringArray<T>>>()?;
+ Ok(Arc::new(result) as ArrayRef)
+ }
+ other => {
+ internal_err!(
+ "overlay was called with {other} arguments. It requires 3 or
4."
+ )
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use crate::string_expressions;
use arrow::{array::Int32Array, datatypes::Int32Type};
+ use arrow_array::Int64Array;
use super::*;
@@ -599,4 +690,21 @@ mod tests {
Ok(())
}
+
+    /// Checks the four-argument form of `overlay` on Utf8 input, one row per case.
+    #[test]
+    fn to_overlay() -> Result<()> {
+        let source: ArrayRef =
+            Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"]));
+        let replacement: ArrayRef =
+            Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"]));
+        // 1-based start position of the replaced span
+        let start_pos: ArrayRef = Arc::new(Int64Array::from(vec![4, 1, 1, 2]));
+        // number of characters to replace
+        let replace_len: ArrayRef = Arc::new(Int64Array::from(vec![5, 7, 2, 4]));
+
+        let overlaid =
+            overlay::<i32>(&[source, replacement, start_pos, replace_len]).unwrap();
+        let actual = as_generic_string_array::<i32>(&overlaid).unwrap();
+
+        let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]);
+        assert_eq!(&expected, actual);
+
+        Ok(())
+    }
}
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 5d7c570bc1..d85678a76b 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -636,6 +636,7 @@ enum ScalarFunction {
ToTimestampNanos = 118;
ArrayIntersect = 119;
ArrayUnion = 120;
+ OverLay = 121;
}
message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 12fa73205d..64db9137d6 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20935,6 +20935,7 @@ impl serde::Serialize for ScalarFunction {
Self::ToTimestampNanos => "ToTimestampNanos",
Self::ArrayIntersect => "ArrayIntersect",
Self::ArrayUnion => "ArrayUnion",
+ Self::OverLay => "OverLay",
};
serializer.serialize_str(variant)
}
@@ -21067,6 +21068,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"ToTimestampNanos",
"ArrayIntersect",
"ArrayUnion",
+ "OverLay",
];
struct GeneratedVisitor;
@@ -21228,6 +21230,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos),
"ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect),
"ArrayUnion" => Ok(ScalarFunction::ArrayUnion),
+ "OverLay" => Ok(ScalarFunction::OverLay),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 23be5d9088..131ca11993 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2567,6 +2567,7 @@ pub enum ScalarFunction {
ToTimestampNanos = 118,
ArrayIntersect = 119,
ArrayUnion = 120,
+ OverLay = 121,
}
impl ScalarFunction {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -2696,6 +2697,7 @@ impl ScalarFunction {
ScalarFunction::ToTimestampNanos => "ToTimestampNanos",
ScalarFunction::ArrayIntersect => "ArrayIntersect",
ScalarFunction::ArrayUnion => "ArrayUnion",
+ ScalarFunction::OverLay => "OverLay",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -2822,6 +2824,7 @@ impl ScalarFunction {
"ToTimestampNanos" => Some(Self::ToTimestampNanos),
"ArrayIntersect" => Some(Self::ArrayIntersect),
"ArrayUnion" => Some(Self::ArrayUnion),
+ "OverLay" => Some(Self::OverLay),
_ => None,
}
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 0ecbe05e79..9ca7bb0e89 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -52,10 +52,10 @@ use datafusion_expr::{
factorial, flatten, floor, from_unixtime, gcd, isnan, iszero, lcm, left,
ln, log,
log10, log2,
logical_plan::{PlanType, StringifiedPlan},
- lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power,
radians,
- random, regexp_match, regexp_replace, repeat, replace, reverse, right,
round, rpad,
- rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt,
- starts_with, string_to_array, strpos, struct_fun, substr, substring, tan,
tanh,
+ lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi,
power,
+ radians, random, regexp_match, regexp_replace, repeat, replace, reverse,
right,
+ round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh,
split_part,
+ sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring,
tan, tanh,
to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos,
to_timestamp_seconds, translate, trim, trunc, upper, uuid,
window_frame::regularize,
@@ -546,6 +546,7 @@ impl From<&protobuf::ScalarFunction> for
BuiltinScalarFunction {
ScalarFunction::Isnan => Self::Isnan,
ScalarFunction::Iszero => Self::Iszero,
ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
+ ScalarFunction::OverLay => Self::OverLay,
}
}
}
@@ -1680,6 +1681,12 @@ pub fn parse_expr(
parse_expr(&args[1], registry)?,
parse_expr(&args[2], registry)?,
)),
+ ScalarFunction::OverLay => Ok(overlay(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, registry))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
ScalarFunction::StructFun => {
Ok(struct_fun(parse_expr(&args[0], registry)?))
}
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 4c81ab954a..974d6c5aab 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1553,6 +1553,7 @@ impl TryFrom<&BuiltinScalarFunction> for
protobuf::ScalarFunction {
BuiltinScalarFunction::Isnan => Self::Isnan,
BuiltinScalarFunction::Iszero => Self::Iszero,
BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
+ BuiltinScalarFunction::OverLay => Self::OverLay,
};
Ok(scalar_function)
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 1cf0fc133f..7fa16ced39 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -459,7 +459,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema,
planner_context,
),
-
+ SQLExpr::Overlay {
+ expr,
+ overlay_what,
+ overlay_from,
+ overlay_for,
+ } => self.sql_overlay_to_expr(
+ *expr,
+ *overlay_what,
+ *overlay_from,
+ overlay_for,
+ schema,
+ planner_context,
+ ),
SQLExpr::Nested(e) => {
self.sql_expr_to_logical_expr(*e, schema, planner_context)
}
@@ -645,6 +657,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)))
}
+    /// Plans `OVERLAY(expr PLACING overlay_what FROM overlay_from [FOR overlay_for])`
+    /// as a call to the built-in `OverLay` scalar function. The function
+    /// receives three arguments, or four when the FOR clause is present.
+    fn sql_overlay_to_expr(
+        &self,
+        expr: SQLExpr,
+        overlay_what: SQLExpr,
+        overlay_from: SQLExpr,
+        overlay_for: Option<Box<SQLExpr>>,
+        schema: &DFSchema,
+        planner_context: &mut PlannerContext,
+    ) -> Result<Expr> {
+        // Sub-expressions are planned left to right, in argument order.
+        let mut args = vec![
+            self.sql_expr_to_logical_expr(expr, schema, planner_context)?,
+            self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?,
+            self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?,
+        ];
+        // The FOR clause is optional; only append it when it was written.
+        if let Some(for_expr) = overlay_for {
+            let for_arg =
+                self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?;
+            args.push(for_arg);
+        }
+        Ok(Expr::ScalarFunction(ScalarFunction::new(
+            BuiltinScalarFunction::OverLay,
+            args,
+        )))
+    }
+
fn sql_agg_with_filter_to_expr(
&self,
expr: SQLExpr,
diff --git a/datafusion/sqllogictest/test_files/functions.slt
b/datafusion/sqllogictest/test_files/functions.slt
index 2054752cc5..8f42304384 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -815,3 +815,45 @@ SELECT products.* REPLACE (price*2 AS price,
product_id+1000 AS product_id) FROM
1002 OldBrand Product 2 59.98
1003 OldBrand Product 3 79.98
1004 OldBrand Product 4 99.98
+
+#overlay tests
+statement ok
+CREATE TABLE over_test(
+ str TEXT,
+ characters TEXT,
+ pos INT,
+ len INT
+) as VALUES
+ ('123', 'abc', 4, 5),
+ ('abcdefg', 'qwertyasdfg', 1, 7),
+ ('xyz', 'ijk', 1, 2),
+ ('Txxxxas', 'hom', 2, 4),
+ (NULL, 'hom', 2, 4),
+ ('Txxxxas', 'hom', NULL, 4),
+ ('Txxxxas', 'hom', 2, NULL),
+ ('Txxxxas', NULL, 2, 4)
+;
+
+query T
+SELECT overlay(str placing characters from pos for len) from over_test
+----
+abc
+qwertyasdfg
+ijkz
+Thomas
+NULL
+NULL
+NULL
+NULL
+
+query T
+SELECT overlay(str placing characters from pos) from over_test
+----
+abc
+qwertyasdfg
+ijk
+Thomxas
+NULL
+NULL
+Thomxas
+NULL
diff --git a/docs/source/user-guide/sql/scalar_functions.md
b/docs/source/user-guide/sql/scalar_functions.md
index 2959e82024..099c903122 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -635,6 +635,7 @@ nullif(expression1, expression2)
- [trim](#trim)
- [upper](#upper)
- [uuid](#uuid)
+- [overlay](#overlay)
### `ascii`
@@ -1120,6 +1121,22 @@ Returns UUID v4 string value which is unique per row.
uuid()
```
+### `overlay`
+
+Returns a copy of the string in which the substring that starts at the specified position and spans the specified number of characters is replaced by another string.
+For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas`
+
+```
+overlay(str PLACING substr FROM pos [FOR count])
+```
+
+#### Arguments
+
+- **str**: String expression to operate on.
+- **substr**: the string to replace part of str.
+- **pos**: the position in str (counting from 1) at which the replacement starts.
+- **count**: the number of characters of str to replace, starting at pos. If not specified, the length of substr is used instead.
+
## Binary String Functions
- [decode](#decode)