Copilot commented on code in PR #22890: URL: https://github.com/apache/datafusion/pull/22890#discussion_r3390797746
########## datafusion/spark/src/function/string/levenshtein.rs: ########## @@ -0,0 +1,594 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_generic_string_array, as_int32_array, as_string_view_array, +}; +use datafusion_common::types::{NativeType, logical_int32, logical_string}; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::type_coercion::binary::{ + binary_to_string_coercion, string_coercion, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Spark-compatible `levenshtein` function. +/// +/// Differs from DataFusion core's `levenshtein` in that it supports an optional +/// third argument `threshold`. When the computed Levenshtein distance exceeds +/// the threshold, the function returns -1 instead of the actual distance. +/// +/// ```sql +/// levenshtein('kitten', 'sitting') -- returns 3 +/// levenshtein('kitten', 'sitting', 2) -- returns -1 (distance 3 > threshold 2) +/// levenshtein('kitten', 'sitting', 4) -- returns 3 (distance 3 <= threshold 4) +/// ``` +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLevenshtein { + signature: Signature, +} + +impl Default for SparkLevenshtein { + fn default() -> Self { + Self::new() + } +} + +impl SparkLevenshtein { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkLevenshtein { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + if arg_types.len() != 2 && arg_types.len() != 3 { + return exec_err!( + "levenshtein expects 2 or 3 arguments, got {}", + arg_types.len() + ); + } + if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1]) + .or_else(|| binary_to_string_coercion(&arg_types[0], &arg_types[1])) + { + match coercion_data_type { + DataType::LargeUtf8 => Ok(DataType::Int64), + DataType::Utf8 | DataType::Utf8View => Ok(DataType::Int32), + other => exec_err!( + "levenshtein requires Utf8, LargeUtf8 or Utf8View, got {other}" + ), + } + } else { + exec_err!( + "Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View" + ) + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + let ScalarFunctionArgs { args, .. } = args; + + // Determine the coerced string type (handles mixed Utf8 + LargeUtf8) + let coerced_type = string_coercion(&args[0].data_type(), &args[1].data_type()) + .or_else(|| { + binary_to_string_coercion(&args[0].data_type(), &args[1].data_type()) + }) + .unwrap_or(DataType::Utf8); Review Comment: Defaulting to `DataType::Utf8` when coercion fails is inconsistent with `return_type`, which returns an error for unsupported types. This can mask type errors and produce confusing runtime behavior; return an `exec_err!` when both coercions return `None` (instead of `unwrap_or(DataType::Utf8)`), matching `return_type`’s logic. ########## datafusion/spark/src/function/string/mod.rs: ########## @@ -92,6 +94,11 @@ pub mod expr_fn { "Returns the character length of string data or number of bytes of binary data. The length of string data includes the trailing spaces. The length of binary data includes binary zeros.", arg1 )); + export_functions!(( + levenshtein, + "Returns the Levenshtein distance between two strings. Optionally accepts a threshold; returns -1 if the distance exceeds it.", + str1 str2 threshold + )); Review Comment: The `expr_fn` export lists `threshold` as a required parameter name, but the SQL function supports both 2-arg and 3-arg forms. If `export_functions!` generates fixed-arity Rust helpers, this likely exposes only a 3-arg builder (making the 2-arg form unavailable/misleading in the expression API). Consider exporting two arities (separate exports/wrappers) or using whichever macro pattern the codebase uses for optional arguments so both `levenshtein(str1, str2)` and `levenshtein(str1, str2, threshold)` are supported. ########## datafusion/spark/src/function/string/levenshtein.rs: ########## @@ -0,0 +1,594 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_generic_string_array, as_int32_array, as_string_view_array, +}; +use datafusion_common::types::{NativeType, logical_int32, logical_string}; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::type_coercion::binary::{ + binary_to_string_coercion, string_coercion, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Spark-compatible `levenshtein` function. +/// +/// Differs from DataFusion core's `levenshtein` in that it supports an optional +/// third argument `threshold`. When the computed Levenshtein distance exceeds +/// the threshold, the function returns -1 instead of the actual distance. +/// +/// ```sql +/// levenshtein('kitten', 'sitting') -- returns 3 +/// levenshtein('kitten', 'sitting', 2) -- returns -1 (distance 3 > threshold 2) +/// levenshtein('kitten', 'sitting', 4) -- returns 3 (distance 3 <= threshold 4) +/// ``` +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLevenshtein { + signature: Signature, +} + +impl Default for SparkLevenshtein { + fn default() -> Self { + Self::new() + } +} + +impl SparkLevenshtein { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkLevenshtein { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + if arg_types.len() != 2 && arg_types.len() != 3 { + return exec_err!( + "levenshtein expects 2 or 3 arguments, got {}", + arg_types.len() + ); + } + if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1]) + .or_else(|| binary_to_string_coercion(&arg_types[0], &arg_types[1])) + { + match coercion_data_type { + DataType::LargeUtf8 => Ok(DataType::Int64), + DataType::Utf8 | DataType::Utf8View => Ok(DataType::Int32), + other => exec_err!( + "levenshtein requires Utf8, LargeUtf8 or Utf8View, got {other}" + ), + } + } else { + exec_err!( + "Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View" + ) + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + let ScalarFunctionArgs { args, .. } = args; + + // Determine the coerced string type (handles mixed Utf8 + LargeUtf8) + let coerced_type = string_coercion(&args[0].data_type(), &args[1].data_type()) + .or_else(|| { + binary_to_string_coercion(&args[0].data_type(), &args[1].data_type()) + }) + .unwrap_or(DataType::Utf8); Review Comment: `invoke_with_args` indexes `args[0]` / `args[1]` without validating arity first; if the UDF is invoked with <2 args this will panic. Add the same `2 or 3 arguments` check used in `return_type` (and `spark_levenshtein`) at the start of `invoke_with_args` and return an `exec_err!` on mismatch. ########## datafusion/sqllogictest/test_files/spark/string/levenshtein.slt: ########## @@ -21,12 +21,224 @@ # For more information, please see: # https://github.com/apache/datafusion/issues/15914 -## Original Query: SELECT levenshtein('kitten', 'sitting'); -## PySpark 3.5.5 Result: {'levenshtein(kitten, sitting)': 3, 'typeof(levenshtein(kitten, sitting))': 'int', 'typeof(kitten)': 'string', 'typeof(sitting)': 'string'} -#query -#SELECT levenshtein('kitten'::string, 'sitting'::string); - -## Original Query: SELECT levenshtein('kitten', 'sitting', 2); -## PySpark 3.5.5 Result: {'levenshtein(kitten, sitting, 2)': -1, 'typeof(levenshtein(kitten, sitting, 2))': 'int', 'typeof(kitten)': 'string', 'typeof(sitting)': 'string', 'typeof(2)': 'int'} -#query -#SELECT levenshtein('kitten'::string, 'sitting'::string, 2::int); +## ── Basic usage ───────────────────────────────────────────── + +## Basic distance +query I +SELECT levenshtein('kitten', 'sitting'); +---- +3 + +## Identical strings +query I +SELECT levenshtein('hello', 'hello'); +---- +0 + +## Empty string vs non-empty +query I +SELECT levenshtein('', 'abc'); +---- +3 + +## Both empty strings +query I +SELECT levenshtein('', ''); +---- +0 + +## Single character difference +query I +SELECT levenshtein('abc', 'adc'); +---- +1 + +## ── Threshold (3-argument form) ───────────────────────────── + +## Distance within threshold +query I +SELECT levenshtein('kitten', 'sitting', 4); +---- +3 + +## Distance exceeds threshold → returns -1 +query I +SELECT levenshtein('kitten', 'sitting', 2); +---- +-1 + +## Distance equals threshold (boundary) → returns distance +query I +SELECT levenshtein('kitten', 'sitting', 3); +---- +3 + +## Threshold zero with different strings +query I +SELECT levenshtein('abc', 'def', 0); +---- +-1 + +## Threshold zero with identical strings +query I +SELECT levenshtein('abc', 'abc', 0); +---- +0 + +## ── Null handling ─────────────────────────────────────────── + +## First argument null +query I +SELECT levenshtein(CAST(NULL AS STRING), 'hello'); +---- +NULL + +## Second argument null +query I +SELECT levenshtein('hello', CAST(NULL AS STRING)); +---- +NULL + +## Both arguments null +query I +SELECT levenshtein(CAST(NULL AS STRING), CAST(NULL AS STRING)); +---- +NULL + +## Null threshold +query I +SELECT levenshtein('kitten', 'sitting', CAST(NULL AS INT)); +---- +NULL + +## ── Unicode and special characters ────────────────────────── + +## Unicode strings +query I +SELECT levenshtein('café', 'cafe'); +---- +1 + +## Strings with spaces +query I +SELECT levenshtein('hello world', 'hello world!'); +---- +1 + +## ── Column expressions ────────────────────────────────────── + +## Levenshtein on columns from inline table +query I +SELECT levenshtein(s1, s2) AS result FROM VALUES ('abc', 'abc'), ('abc', 'def'), ('kitten', 'sitting') AS t(s1, s2); +---- +0 +3 +3 + +## Threshold on columns from inline table +query I +SELECT levenshtein(s1, s2, 2) AS result FROM VALUES ('abc', 'abc'), ('abc', 'def'), ('kitten', 'sitting') AS t(s1, s2); +---- +0 +-1 +-1 + +## ── Per-row threshold ─────────────────────────────────────── + +## Different threshold per row +query I +SELECT levenshtein(s1, s2, t) AS result FROM VALUES ('abc', 'def', 2), ('abc', 'def', 5), ('abc', 'def', 3) AS t(s1, s2, t); Review Comment: This query uses `t` as both the table alias and a column name (`AS t(s1, s2, t)`), which is unnecessarily confusing and can make failures harder to interpret. Rename either the alias (e.g., `v`) or the threshold column (e.g., `threshold`) to avoid the collision. ########## datafusion/spark/src/function/string/levenshtein.rs: ########## @@ -0,0 +1,594 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_generic_string_array, as_int32_array, as_string_view_array, +}; +use datafusion_common::types::{NativeType, logical_int32, logical_string}; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::type_coercion::binary::{ + binary_to_string_coercion, string_coercion, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Spark-compatible `levenshtein` function. +/// +/// Differs from DataFusion core's `levenshtein` in that it supports an optional +/// third argument `threshold`. When the computed Levenshtein distance exceeds +/// the threshold, the function returns -1 instead of the actual distance. +/// +/// ```sql +/// levenshtein('kitten', 'sitting') -- returns 3 +/// levenshtein('kitten', 'sitting', 2) -- returns -1 (distance 3 > threshold 2) +/// levenshtein('kitten', 'sitting', 4) -- returns 3 (distance 3 <= threshold 4) +/// ``` +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLevenshtein { + signature: Signature, +} + +impl Default for SparkLevenshtein { + fn default() -> Self { + Self::new() + } +} + +impl SparkLevenshtein { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkLevenshtein { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + if arg_types.len() != 2 && arg_types.len() != 3 { + return exec_err!( + "levenshtein expects 2 or 3 arguments, got {}", + arg_types.len() + ); + } + if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1]) + .or_else(|| binary_to_string_coercion(&arg_types[0], &arg_types[1])) + { + match coercion_data_type { + DataType::LargeUtf8 => Ok(DataType::Int64), + DataType::Utf8 | DataType::Utf8View => Ok(DataType::Int32), + other => exec_err!( + "levenshtein requires Utf8, LargeUtf8 or Utf8View, got {other}" + ), + } + } else { + exec_err!( + "Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View" + ) + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + let ScalarFunctionArgs { args, .. } = args; + + // Determine the coerced string type (handles mixed Utf8 + LargeUtf8) + let coerced_type = string_coercion(&args[0].data_type(), &args[1].data_type()) + .or_else(|| { + binary_to_string_coercion(&args[0].data_type(), &args[1].data_type()) + }) + .unwrap_or(DataType::Utf8); + + // Spark returns NULL when any scalar argument is NULL. + let null_int = |dt: &DataType| match dt { + DataType::LargeUtf8 => ColumnarValue::Scalar(ScalarValue::Int64(None)), + _ => ColumnarValue::Scalar(ScalarValue::Int32(None)), + }; + for arg in &args { + if matches!(arg, ColumnarValue::Scalar(s) if s.is_null()) { + return Ok(null_int(&coerced_type)); + } + } + + match coerced_type { + DataType::Utf8View | DataType::Utf8 => { + make_scalar_function(spark_levenshtein::<i32>, vec![])(&args) + } + DataType::LargeUtf8 => { + make_scalar_function(spark_levenshtein::<i64>, vec![])(&args) + } Review Comment: The unit tests cover mixed `Utf8 + LargeUtf8` for the 2-arg form, but there are no direct unit tests exercising the `Utf8View` path and no unit test covering the `LargeUtf8` path with the 3-arg threshold form. Adding targeted tests for (1) `Utf8View` inputs and (2) `LargeUtf8` inputs with a threshold would help ensure the per-branch logic remains consistent. ########## datafusion/spark/src/function/string/levenshtein.rs: ########## @@ -0,0 +1,594 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_generic_string_array, as_int32_array, as_string_view_array, +}; +use datafusion_common::types::{NativeType, logical_int32, logical_string}; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::type_coercion::binary::{ + binary_to_string_coercion, string_coercion, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Spark-compatible `levenshtein` function. +/// +/// Differs from DataFusion core's `levenshtein` in that it supports an optional +/// third argument `threshold`. When the computed Levenshtein distance exceeds +/// the threshold, the function returns -1 instead of the actual distance. +/// +/// ```sql +/// levenshtein('kitten', 'sitting') -- returns 3 +/// levenshtein('kitten', 'sitting', 2) -- returns -1 (distance 3 > threshold 2) +/// levenshtein('kitten', 'sitting', 4) -- returns 3 (distance 3 <= threshold 4) +/// ``` +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLevenshtein { + signature: Signature, +} + +impl Default for SparkLevenshtein { + fn default() -> Self { + Self::new() + } +} + +impl SparkLevenshtein { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkLevenshtein { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + if arg_types.len() != 2 && arg_types.len() != 3 { + return exec_err!( + "levenshtein expects 2 or 3 arguments, got {}", + arg_types.len() + ); + } + if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1]) + .or_else(|| binary_to_string_coercion(&arg_types[0], &arg_types[1])) + { + match coercion_data_type { + DataType::LargeUtf8 => Ok(DataType::Int64), + DataType::Utf8 | DataType::Utf8View => Ok(DataType::Int32), + other => exec_err!( + "levenshtein requires Utf8, LargeUtf8 or Utf8View, got {other}" + ), + } + } else { + exec_err!( + "Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View" + ) + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + let ScalarFunctionArgs { args, .. } = args; + + // Determine the coerced string type (handles mixed Utf8 + LargeUtf8) + let coerced_type = string_coercion(&args[0].data_type(), &args[1].data_type()) + .or_else(|| { + binary_to_string_coercion(&args[0].data_type(), &args[1].data_type()) + }) + .unwrap_or(DataType::Utf8); + + // Spark returns NULL when any scalar argument is NULL. + let null_int = |dt: &DataType| match dt { + DataType::LargeUtf8 => ColumnarValue::Scalar(ScalarValue::Int64(None)), + _ => ColumnarValue::Scalar(ScalarValue::Int32(None)), + }; + for arg in &args { + if matches!(arg, ColumnarValue::Scalar(s) if s.is_null()) { + return Ok(null_int(&coerced_type)); + } + } + + match coerced_type { + DataType::Utf8View | DataType::Utf8 => { + make_scalar_function(spark_levenshtein::<i32>, vec![])(&args) + } + DataType::LargeUtf8 => { + make_scalar_function(spark_levenshtein::<i64>, vec![])(&args) + } + other => { + exec_err!("Unsupported data type {other:?} for function levenshtein") + } + } + } +} + +/// Spark-compatible Levenshtein distance with optional threshold. +fn spark_levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { + if args.len() < 2 || args.len() > 3 { + return exec_err!("levenshtein expects 2 or 3 arguments, got {}", args.len()); + } + + let str1 = &args[0]; + let str2 = &args[1]; + let threshold = if args.len() == 3 { + Some(as_int32_array(&args[2])?) + } else { + None + }; + + if let Some(coercion_data_type) = string_coercion(str1.data_type(), str2.data_type()) + .or_else(|| binary_to_string_coercion(str1.data_type(), str2.data_type())) + { + let str1 = if str1.data_type() == &coercion_data_type { + Arc::clone(str1) + } else { + arrow::compute::kernels::cast::cast(str1, &coercion_data_type)? + }; + let str2 = if str2.data_type() == &coercion_data_type { + Arc::clone(str2) + } else { + arrow::compute::kernels::cast::cast(str2, &coercion_data_type)? + }; + + match coercion_data_type { + DataType::Utf8View => { + let str1_array = as_string_view_array(&str1)?; + let str2_array = as_string_view_array(&str2)?; + let mut cache = Vec::new(); + + let result = str1_array + .iter() + .zip(str2_array.iter()) + .enumerate() + .map(|(i, (string1, string2))| match (string1, string2) { + (Some(string1), Some(string2)) => { + let dist = datafusion_strsim::levenshtein_with_buffer( + string1, string2, &mut cache, + ) as i32; + match &threshold { + Some(t) => { + let thresh = + if t.is_null(i) { 0 } else { t.value(i) }; + if dist > thresh { + Some(-1i32) + } else { + Some(dist) + } + } + None => Some(dist), + } + } Review Comment: The per-row threshold handling logic is duplicated across the `Utf8View`, `Utf8`, and `LargeUtf8` branches (with only minor type differences). This increases the chance of subtle behavior drift between branches; consider extracting a small helper that maps `(dist, threshold_value_or_none)` to `Option<dist_or_-1>` so each branch shares exactly the same decision logic. ########## datafusion/spark/src/function/string/levenshtein.rs: ########## @@ -0,0 +1,594 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_generic_string_array, as_int32_array, as_string_view_array, +}; +use datafusion_common::types::{NativeType, logical_int32, logical_string}; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::type_coercion::binary::{ + binary_to_string_coercion, string_coercion, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Spark-compatible `levenshtein` function. +/// +/// Differs from DataFusion core's `levenshtein` in that it supports an optional +/// third argument `threshold`. When the computed Levenshtein distance exceeds +/// the threshold, the function returns -1 instead of the actual distance. +/// +/// ```sql +/// levenshtein('kitten', 'sitting') -- returns 3 +/// levenshtein('kitten', 'sitting', 2) -- returns -1 (distance 3 > threshold 2) +/// levenshtein('kitten', 'sitting', 4) -- returns 3 (distance 3 <= threshold 4) +/// ``` +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLevenshtein { + signature: Signature, +} + +impl Default for SparkLevenshtein { + fn default() -> Self { + Self::new() + } +} + +impl SparkLevenshtein { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkLevenshtein { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + if arg_types.len() != 2 && arg_types.len() != 3 { + return exec_err!( + "levenshtein expects 2 or 3 arguments, got {}", + arg_types.len() + ); + } + if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1]) + .or_else(|| binary_to_string_coercion(&arg_types[0], &arg_types[1])) + { + match coercion_data_type { + DataType::LargeUtf8 => Ok(DataType::Int64), + DataType::Utf8 | DataType::Utf8View => Ok(DataType::Int32), + other => exec_err!( + "levenshtein requires Utf8, LargeUtf8 or Utf8View, got {other}" + ), + } + } else { + exec_err!( + "Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View" + ) + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + let ScalarFunctionArgs { args, .. } = args; + + // Determine the coerced string type (handles mixed Utf8 + LargeUtf8) + let coerced_type = string_coercion(&args[0].data_type(), &args[1].data_type()) + .or_else(|| { + binary_to_string_coercion(&args[0].data_type(), &args[1].data_type()) + }) + .unwrap_or(DataType::Utf8); + + // Spark returns NULL when any scalar argument is NULL. + let null_int = |dt: &DataType| match dt { + DataType::LargeUtf8 => ColumnarValue::Scalar(ScalarValue::Int64(None)), + _ => ColumnarValue::Scalar(ScalarValue::Int32(None)), + }; + for arg in &args { + if matches!(arg, ColumnarValue::Scalar(s) if s.is_null()) { + return Ok(null_int(&coerced_type)); + } + } Review Comment: The code explicitly documents NULL propagation for *scalar* arguments, but in the array path NULL thresholds are treated as `0` (see `if t.is_null(i) { 0 } ...`). This mixed behavior is surprising; please add an in-code comment explaining why Spark’s behavior differs between scalar-NULL vs row-NULL threshold (and ideally link to a Spark issue/test/reference), or adjust the array behavior if Spark null-propagates per row. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
