This is an automated email from the ASF dual-hosted git repository.
yangjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 7618e4d9c1 feat:implement calcite style 'levenshtein' string function
(#8168)
7618e4d9c1 is described below
commit 7618e4d9c1801d76335164a1e70960d37012c516
Author: Syleechan <[email protected]>
AuthorDate: Fri Nov 17 11:35:03 2023 +0800
feat:implement calcite style 'levenshtein' string function (#8168)
* feat:implement calcite style 'levenshtein' string function
* format doc style
* cargo lock
---
datafusion/expr/src/built_in_function.rs | 12 ++++
datafusion/expr/src/expr_fn.rs | 2 +
datafusion/physical-expr/src/functions.rs | 13 +++++
datafusion/physical-expr/src/string_expressions.rs | 67 +++++++++++++++++++++-
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 3 +
datafusion/proto/src/generated/prost.rs | 3 +
datafusion/proto/src/logical_plan/from_proto.rs | 7 ++-
datafusion/proto/src/logical_plan/to_proto.rs | 1 +
datafusion/sqllogictest/test_files/functions.slt | 22 ++++++-
docs/source/user-guide/sql/scalar_functions.md | 15 +++++
11 files changed, 142 insertions(+), 4 deletions(-)
diff --git a/datafusion/expr/src/built_in_function.rs
b/datafusion/expr/src/built_in_function.rs
index 1b48c37406..fc6f9c28e1 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -298,6 +298,8 @@ pub enum BuiltinScalarFunction {
ArrowTypeof,
/// overlay
OverLay,
+ /// levenshtein
+ Levenshtein,
}
/// Maps the sql function name to `BuiltinScalarFunction`
@@ -464,6 +466,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::FromUnixtime => Volatility::Immutable,
BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
BuiltinScalarFunction::OverLay => Volatility::Immutable,
+ BuiltinScalarFunction::Levenshtein => Volatility::Immutable,
// Stable builtin functions
BuiltinScalarFunction::Now => Volatility::Stable,
@@ -829,6 +832,10 @@ impl BuiltinScalarFunction {
utf8_to_str_type(&input_expr_types[0], "overlay")
}
+ BuiltinScalarFunction::Levenshtein => {
+ utf8_to_int_type(&input_expr_types[0], "levenshtein")
+ }
+
BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
@@ -1293,6 +1300,10 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
+ BuiltinScalarFunction::Levenshtein => Signature::one_of(
+ vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8,
LargeUtf8])],
+ self.volatility(),
+ ),
BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
@@ -1457,6 +1468,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static
[&'static str] {
BuiltinScalarFunction::Trim => &["trim"],
BuiltinScalarFunction::Upper => &["upper"],
BuiltinScalarFunction::Uuid => &["uuid"],
+ BuiltinScalarFunction::Levenshtein => &["levenshtein"],
// regex functions
BuiltinScalarFunction::RegexpMatch => &["regexp_match"],
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index bcf1aa0ca7..75b7628044 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -909,6 +909,7 @@ scalar_expr!(
);
scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type");
+scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the
Levenshtein distance between the two given strings");
scalar_expr!(
Struct,
@@ -1195,6 +1196,7 @@ mod test {
test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position,
len);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
+ test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
}
#[test]
diff --git a/datafusion/physical-expr/src/functions.rs
b/datafusion/physical-expr/src/functions.rs
index 1e8500079f..b46249d26d 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -846,6 +846,19 @@ pub fn create_physical_fun(
"Unsupported data type {other:?} for function overlay",
))),
}),
+ BuiltinScalarFunction::Levenshtein => {
+ Arc::new(|args| match args[0].data_type() {
+ DataType::Utf8 => {
+
make_scalar_function(string_expressions::levenshtein::<i32>)(args)
+ }
+ DataType::LargeUtf8 => {
+
make_scalar_function(string_expressions::levenshtein::<i64>)(args)
+ }
+ other => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {other:?} for function levenshtein",
+ ))),
+ })
+ }
})
}
diff --git a/datafusion/physical-expr/src/string_expressions.rs
b/datafusion/physical-expr/src/string_expressions.rs
index 7e954fdcfd..91d21f95e4 100644
--- a/datafusion/physical-expr/src/string_expressions.rs
+++ b/datafusion/physical-expr/src/string_expressions.rs
@@ -23,11 +23,12 @@
use arrow::{
array::{
- Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array,
OffsetSizeTrait,
- StringArray,
+ Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array,
Int64Array,
+ OffsetSizeTrait, StringArray,
},
datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType},
};
+use datafusion_common::utils::datafusion_strsim;
use datafusion_common::{
cast::{
as_generic_string_array, as_int64_array, as_primitive_array,
as_string_array,
@@ -643,12 +644,59 @@ pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) ->
Result<ArrayRef> {
}
}
+/// Returns the Levenshtein distance between the two given strings.
+/// For example, `LEVENSHTEIN('kitten', 'sitting') = 3`.
+pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+ if args.len() != 2 {
+ return Err(DataFusionError::Internal(format!(
+ "levenshtein function requires two arguments, got {}",
+ args.len()
+ )));
+ }
+ let str1_array = as_generic_string_array::<T>(&args[0])?;
+ let str2_array = as_generic_string_array::<T>(&args[1])?;
+ match args[0].data_type() {
+ DataType::Utf8 => {
+ let result = str1_array
+ .iter()
+ .zip(str2_array.iter())
+ .map(|(string1, string2)| match (string1, string2) {
+ (Some(string1), Some(string2)) => {
+ Some(datafusion_strsim::levenshtein(string1, string2)
as i32)
+ }
+ _ => None,
+ })
+ .collect::<Int32Array>();
+ Ok(Arc::new(result) as ArrayRef)
+ }
+ DataType::LargeUtf8 => {
+ let result = str1_array
+ .iter()
+ .zip(str2_array.iter())
+ .map(|(string1, string2)| match (string1, string2) {
+ (Some(string1), Some(string2)) => {
+ Some(datafusion_strsim::levenshtein(string1, string2)
as i64)
+ }
+ _ => None,
+ })
+ .collect::<Int64Array>();
+ Ok(Arc::new(result) as ArrayRef)
+ }
+ other => {
+ internal_err!(
+ "levenshtein was called with {other} datatype arguments. It
requires Utf8 or LargeUtf8."
+ )
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use crate::string_expressions;
use arrow::{array::Int32Array, datatypes::Int32Type};
use arrow_array::Int64Array;
+ use datafusion_common::cast::as_int32_array;
use super::*;
@@ -707,4 +755,19 @@ mod tests {
Ok(())
}
+
+ #[test]
+ fn to_levenshtein() -> Result<()> {
+ let string1_array =
+ Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"]));
+ let string2_array =
+ Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"]));
+ let res = levenshtein::<i32>(&[string1_array, string2_array]).unwrap();
+ let result =
+ as_int32_array(&res).expect("failed to initialized function
levenshtein");
+ let expected = Int32Array::from(vec![2, 3, 2, 3]);
+ assert_eq!(&expected, result);
+
+ Ok(())
+ }
}
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 66c34c7a12..a5c3d3b603 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -639,6 +639,7 @@ enum ScalarFunction {
OverLay = 121;
Range = 122;
ArrayPopFront = 123;
+ Levenshtein = 124;
}
message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 628adcc411..3faacca18c 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20938,6 +20938,7 @@ impl serde::Serialize for ScalarFunction {
Self::OverLay => "OverLay",
Self::Range => "Range",
Self::ArrayPopFront => "ArrayPopFront",
+ Self::Levenshtein => "Levenshtein",
};
serializer.serialize_str(variant)
}
@@ -21073,6 +21074,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"OverLay",
"Range",
"ArrayPopFront",
+ "Levenshtein",
];
struct GeneratedVisitor;
@@ -21237,6 +21239,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"OverLay" => Ok(ScalarFunction::OverLay),
"Range" => Ok(ScalarFunction::Range),
"ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront),
+ "Levenshtein" => Ok(ScalarFunction::Levenshtein),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 317b888447..2555a31f6f 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2570,6 +2570,7 @@ pub enum ScalarFunction {
OverLay = 121,
Range = 122,
ArrayPopFront = 123,
+ Levenshtein = 124,
}
impl ScalarFunction {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -2702,6 +2703,7 @@ impl ScalarFunction {
ScalarFunction::OverLay => "OverLay",
ScalarFunction::Range => "Range",
ScalarFunction::ArrayPopFront => "ArrayPopFront",
+ ScalarFunction::Levenshtein => "Levenshtein",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -2831,6 +2833,7 @@ impl ScalarFunction {
"OverLay" => Some(Self::OverLay),
"Range" => Some(Self::Range),
"ArrayPopFront" => Some(Self::ArrayPopFront),
+ "Levenshtein" => Some(Self::Levenshtein),
_ => None,
}
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 94c9f98066..f14da70485 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -50,7 +50,7 @@ use datafusion_expr::{
date_part, date_trunc, decode, degrees, digest, encode, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero,
lcm, left,
- ln, log, log10, log2,
+ levenshtein, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi,
power,
radians, random, regexp_match, regexp_replace, repeat, replace, reverse,
right,
@@ -549,6 +549,7 @@ impl From<&protobuf::ScalarFunction> for
BuiltinScalarFunction {
ScalarFunction::Iszero => Self::Iszero,
ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
ScalarFunction::OverLay => Self::OverLay,
+ ScalarFunction::Levenshtein => Self::Levenshtein,
}
}
}
@@ -1630,6 +1631,10 @@ pub fn parse_expr(
))
}
}
+ ScalarFunction::Levenshtein => Ok(levenshtein(
+ parse_expr(&args[0], registry)?,
+ parse_expr(&args[1], registry)?,
+ )),
ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0],
registry)?)),
ScalarFunction::ToTimestampMillis => {
Ok(to_timestamp_millis(parse_expr(&args[0], registry)?))
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 649be05b88..de81a1f4ca 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1556,6 +1556,7 @@ impl TryFrom<&BuiltinScalarFunction> for
protobuf::ScalarFunction {
BuiltinScalarFunction::Iszero => Self::Iszero,
BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
BuiltinScalarFunction::OverLay => Self::OverLay,
+ BuiltinScalarFunction::Levenshtein => Self::Levenshtein,
};
Ok(scalar_function)
diff --git a/datafusion/sqllogictest/test_files/functions.slt
b/datafusion/sqllogictest/test_files/functions.slt
index 8f42304384..9c8bb2c5f8 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -788,7 +788,7 @@ INSERT INTO products (product_id, product_name, price)
VALUES
(1, 'OldBrand Product 1', 19.99),
(2, 'OldBrand Product 2', 29.99),
(3, 'OldBrand Product 3', 39.99),
-(4, 'OldBrand Product 4', 49.99)
+(4, 'OldBrand Product 4', 49.99)
query ITR
SELECT * REPLACE (price*2 AS price) FROM products
@@ -857,3 +857,23 @@ NULL
NULL
Thomxas
NULL
+
+query I
+SELECT levenshtein('kitten', 'sitting')
+----
+3
+
+query I
+SELECT levenshtein('kitten', NULL)
+----
+NULL
+
+query ?
+SELECT levenshtein(NULL, 'sitting')
+----
+NULL
+
+query ?
+SELECT levenshtein(NULL, NULL)
+----
+NULL
diff --git a/docs/source/user-guide/sql/scalar_functions.md
b/docs/source/user-guide/sql/scalar_functions.md
index baaea3926f..f9f45a1b0a 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -636,6 +636,7 @@ nullif(expression1, expression2)
- [upper](#upper)
- [uuid](#uuid)
- [overlay](#overlay)
+- [levenshtein](#levenshtein)
### `ascii`
@@ -1137,6 +1138,20 @@ overlay(str PLACING substr FROM pos [FOR count])
- **pos**: the start position to replace of str.
- **count**: the count of characters to be replaced from start position of
str. If not specified, will use substr length instead.
+### `levenshtein`
+
+Returns the Levenshtein distance between the two given strings.
+For example, `levenshtein('kitten', 'sitting') = 3`.
+
+```
+levenshtein(str1, str2)
+```
+
+#### Arguments
+
+- **str1**: String expression to compute Levenshtein distance with str2.
+- **str2**: String expression to compute Levenshtein distance with str1.
+
## Binary String Functions
- [decode](#decode)