This is an automated email from the ASF dual-hosted git repository.

yangjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 7618e4d9c1 feat:implement calcite style 'levenshtein' string function (#8168)
7618e4d9c1 is described below

commit 7618e4d9c1801d76335164a1e70960d37012c516
Author: Syleechan <[email protected]>
AuthorDate: Fri Nov 17 11:35:03 2023 +0800

    feat:implement calcite style 'levenshtein' string function (#8168)
    
    * feat:implement calcite style 'levenshtein' string function
    
    * format doc style
    
    * cargo lock
---
 datafusion/expr/src/built_in_function.rs           | 12 ++++
 datafusion/expr/src/expr_fn.rs                     |  2 +
 datafusion/physical-expr/src/functions.rs          | 13 +++++
 datafusion/physical-expr/src/string_expressions.rs | 67 +++++++++++++++++++++-
 datafusion/proto/proto/datafusion.proto            |  1 +
 datafusion/proto/src/generated/pbjson.rs           |  3 +
 datafusion/proto/src/generated/prost.rs            |  3 +
 datafusion/proto/src/logical_plan/from_proto.rs    |  7 ++-
 datafusion/proto/src/logical_plan/to_proto.rs      |  1 +
 datafusion/sqllogictest/test_files/functions.slt   | 22 ++++++-
 docs/source/user-guide/sql/scalar_functions.md     | 15 +++++
 11 files changed, 142 insertions(+), 4 deletions(-)

diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 1b48c37406..fc6f9c28e1 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -298,6 +298,8 @@ pub enum BuiltinScalarFunction {
     ArrowTypeof,
     /// overlay
     OverLay,
+    /// levenshtein
+    Levenshtein,
 }
 
 /// Maps the sql function name to `BuiltinScalarFunction`
@@ -464,6 +466,7 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::FromUnixtime => Volatility::Immutable,
             BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
             BuiltinScalarFunction::OverLay => Volatility::Immutable,
+            BuiltinScalarFunction::Levenshtein => Volatility::Immutable,
 
             // Stable builtin functions
             BuiltinScalarFunction::Now => Volatility::Stable,
@@ -829,6 +832,10 @@ impl BuiltinScalarFunction {
                 utf8_to_str_type(&input_expr_types[0], "overlay")
             }
 
+            BuiltinScalarFunction::Levenshtein => {
+                utf8_to_int_type(&input_expr_types[0], "levenshtein")
+            }
+
             BuiltinScalarFunction::Acos
             | BuiltinScalarFunction::Asin
             | BuiltinScalarFunction::Atan
@@ -1293,6 +1300,10 @@ impl BuiltinScalarFunction {
                 ],
                 self.volatility(),
             ),
+            BuiltinScalarFunction::Levenshtein => Signature::one_of(
+                vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
+                self.volatility(),
+            ),
             BuiltinScalarFunction::Acos
             | BuiltinScalarFunction::Asin
             | BuiltinScalarFunction::Atan
@@ -1457,6 +1468,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
         BuiltinScalarFunction::Trim => &["trim"],
         BuiltinScalarFunction::Upper => &["upper"],
         BuiltinScalarFunction::Uuid => &["uuid"],
+        BuiltinScalarFunction::Levenshtein => &["levenshtein"],
 
         // regex functions
         BuiltinScalarFunction::RegexpMatch => &["regexp_match"],
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index bcf1aa0ca7..75b7628044 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -909,6 +909,7 @@ scalar_expr!(
 );
 
 scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type");
+scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings");
 
 scalar_expr!(
     Struct,
@@ -1195,6 +1196,7 @@ mod test {
         test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
         test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len);
         test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
+        test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
     }
 
     #[test]
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 1e8500079f..b46249d26d 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -846,6 +846,19 @@ pub fn create_physical_fun(
                 "Unsupported data type {other:?} for function overlay",
             ))),
         }),
+        BuiltinScalarFunction::Levenshtein => {
+            Arc::new(|args| match args[0].data_type() {
+                DataType::Utf8 => {
+                    make_scalar_function(string_expressions::levenshtein::<i32>)(args)
+                }
+                DataType::LargeUtf8 => {
+                    make_scalar_function(string_expressions::levenshtein::<i64>)(args)
+                }
+                other => Err(DataFusionError::Internal(format!(
+                    "Unsupported data type {other:?} for function levenshtein",
+                ))),
+            })
+        }
     })
 }
 
diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs
index 7e954fdcfd..91d21f95e4 100644
--- a/datafusion/physical-expr/src/string_expressions.rs
+++ b/datafusion/physical-expr/src/string_expressions.rs
@@ -23,11 +23,12 @@
 
 use arrow::{
     array::{
-        Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, OffsetSizeTrait,
-        StringArray,
+        Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array,
+        OffsetSizeTrait, StringArray,
+        OffsetSizeTrait, StringArray,
     },
     datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType},
 };
+use datafusion_common::utils::datafusion_strsim;
 use datafusion_common::{
     cast::{
         as_generic_string_array, as_int64_array, as_primitive_array, as_string_array,
@@ -643,12 +644,59 @@ pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
     }
 }
 
+///Returns the Levenshtein distance between the two given strings.
+/// LEVENSHTEIN('kitten', 'sitting') = 3
+pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 2 {
+        return Err(DataFusionError::Internal(format!(
+            "levenshtein function requires two arguments, got {}",
+            args.len()
+        )));
+    }
+    let str1_array = as_generic_string_array::<T>(&args[0])?;
+    let str2_array = as_generic_string_array::<T>(&args[1])?;
+    match args[0].data_type() {
+        DataType::Utf8 => {
+            let result = str1_array
+                .iter()
+                .zip(str2_array.iter())
+                .map(|(string1, string2)| match (string1, string2) {
+                    (Some(string1), Some(string2)) => {
+                        Some(datafusion_strsim::levenshtein(string1, string2) as i32)
+                    }
+                    _ => None,
+                })
+                .collect::<Int32Array>();
+            Ok(Arc::new(result) as ArrayRef)
+        }
+        DataType::LargeUtf8 => {
+            let result = str1_array
+                .iter()
+                .zip(str2_array.iter())
+                .map(|(string1, string2)| match (string1, string2) {
+                    (Some(string1), Some(string2)) => {
+                        Some(datafusion_strsim::levenshtein(string1, string2) as i64)
+                    }
+                    _ => None,
+                })
+                .collect::<Int64Array>();
+            Ok(Arc::new(result) as ArrayRef)
+        }
+        other => {
+            internal_err!(
+                "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8."
+            )
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
 
     use crate::string_expressions;
     use arrow::{array::Int32Array, datatypes::Int32Type};
     use arrow_array::Int64Array;
+    use datafusion_common::cast::as_int32_array;
 
     use super::*;
 
@@ -707,4 +755,19 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn to_levenshtein() -> Result<()> {
+        let string1_array =
+            Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"]));
+        let string2_array =
+            Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"]));
+        let res = levenshtein::<i32>(&[string1_array, string2_array]).unwrap();
+        let result =
+            as_int32_array(&res).expect("failed to initialized function levenshtein");
+        let expected = Int32Array::from(vec![2, 3, 2, 3]);
+        assert_eq!(&expected, result);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 66c34c7a12..a5c3d3b603 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -639,6 +639,7 @@ enum ScalarFunction {
   OverLay = 121;
   Range = 122;
   ArrayPopFront = 123;
+  Levenshtein = 124;
 }
 
 message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 628adcc411..3faacca18c 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20938,6 +20938,7 @@ impl serde::Serialize for ScalarFunction {
             Self::OverLay => "OverLay",
             Self::Range => "Range",
             Self::ArrayPopFront => "ArrayPopFront",
+            Self::Levenshtein => "Levenshtein",
         };
         serializer.serialize_str(variant)
     }
@@ -21073,6 +21074,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
             "OverLay",
             "Range",
             "ArrayPopFront",
+            "Levenshtein",
         ];
 
         struct GeneratedVisitor;
@@ -21237,6 +21239,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
                     "OverLay" => Ok(ScalarFunction::OverLay),
                     "Range" => Ok(ScalarFunction::Range),
                     "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront),
+                    "Levenshtein" => Ok(ScalarFunction::Levenshtein),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 317b888447..2555a31f6f 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2570,6 +2570,7 @@ pub enum ScalarFunction {
     OverLay = 121,
     Range = 122,
     ArrayPopFront = 123,
+    Levenshtein = 124,
 }
 impl ScalarFunction {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -2702,6 +2703,7 @@ impl ScalarFunction {
             ScalarFunction::OverLay => "OverLay",
             ScalarFunction::Range => "Range",
             ScalarFunction::ArrayPopFront => "ArrayPopFront",
+            ScalarFunction::Levenshtein => "Levenshtein",
         }
     }
     /// Creates an enum from field names used in the ProtoBuf definition.
@@ -2831,6 +2833,7 @@ impl ScalarFunction {
             "OverLay" => Some(Self::OverLay),
             "Range" => Some(Self::Range),
             "ArrayPopFront" => Some(Self::ArrayPopFront),
+            "Levenshtein" => Some(Self::Levenshtein),
             _ => None,
         }
     }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 94c9f98066..f14da70485 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -50,7 +50,7 @@ use datafusion_expr::{
     date_part, date_trunc, decode, degrees, digest, encode, exp,
     expr::{self, InList, Sort, WindowFunction},
     factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left,
-    ln, log, log10, log2,
+    levenshtein, ln, log, log10, log2,
     logical_plan::{PlanType, StringifiedPlan},
     lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power,
     radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right,
@@ -549,6 +549,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
             ScalarFunction::Iszero => Self::Iszero,
             ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
             ScalarFunction::OverLay => Self::OverLay,
+            ScalarFunction::Levenshtein => Self::Levenshtein,
         }
     }
 }
@@ -1630,6 +1631,10 @@ pub fn parse_expr(
                         ))
                     }
                 }
+                ScalarFunction::Levenshtein => Ok(levenshtein(
+                    parse_expr(&args[0], registry)?,
+                    parse_expr(&args[1], registry)?,
+                )),
                 ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)),
                 ScalarFunction::ToTimestampMillis => {
                     Ok(to_timestamp_millis(parse_expr(&args[0], registry)?))
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 649be05b88..de81a1f4ca 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1556,6 +1556,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
             BuiltinScalarFunction::Iszero => Self::Iszero,
             BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
             BuiltinScalarFunction::OverLay => Self::OverLay,
+            BuiltinScalarFunction::Levenshtein => Self::Levenshtein,
         };
 
         Ok(scalar_function)
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index 8f42304384..9c8bb2c5f8 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -788,7 +788,7 @@ INSERT INTO products (product_id, product_name, price) VALUES
 (1, 'OldBrand Product 1', 19.99),
 (2, 'OldBrand Product 2', 29.99),
 (3, 'OldBrand Product 3', 39.99),
-(4, 'OldBrand Product 4', 49.99)                    
+(4, 'OldBrand Product 4', 49.99)
 
 query ITR
 SELECT * REPLACE (price*2 AS price) FROM products
@@ -857,3 +857,23 @@ NULL
 NULL
 Thomxas
 NULL
+
+query I
+SELECT levenshtein('kitten', 'sitting')
+----
+3
+
+query I
+SELECT levenshtein('kitten', NULL)
+----
+NULL
+
+query ?
+SELECT levenshtein(NULL, 'sitting')
+----
+NULL
+
+query ?
+SELECT levenshtein(NULL, NULL)
+----
+NULL
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index baaea3926f..f9f45a1b0a 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -636,6 +636,7 @@ nullif(expression1, expression2)
 - [upper](#upper)
 - [uuid](#uuid)
 - [overlay](#overlay)
+- [levenshtein](#levenshtein)
 
 ### `ascii`
 
@@ -1137,6 +1138,20 @@ overlay(str PLACING substr FROM pos [FOR count])
 - **pos**: the start position to replace of str.
 - **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead.
 
+### `levenshtein`
+
+Returns the Levenshtein distance between the two given strings.
+For example, `levenshtein('kitten', 'sitting') = 3`
+
+```
+levenshtein(str1, str2)
+```
+
+#### Arguments
+
+- **str1**: String expression to compute Levenshtein distance with str2.
+- **str2**: String expression to compute Levenshtein distance with str1.
+
 ## Binary String Functions
 
 - [decode](#decode)

Reply via email to