Copilot commented on code in PR #20927: URL: https://github.com/apache/datafusion/pull/20927#discussion_r2932341299
########## datafusion/spark/src/function/string/levenshtein.rs: ########## @@ -0,0 +1,576 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_int32_array, as_string_view_array}; +use datafusion_common::types::{logical_int32, logical_string, NativeType}; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::type_coercion::binary::{binary_to_string_coercion, string_coercion}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Spark-compatible `levenshtein` function. +/// +/// Differs from DataFusion core's `levenshtein` in that it supports an optional +/// third argument `threshold`. When the computed Levenshtein distance exceeds +/// the threshold, the function returns -1 instead of the actual distance. 
+/// +/// ```sql +/// levenshtein('kitten', 'sitting') -- returns 3 +/// levenshtein('kitten', 'sitting', 2) -- returns -1 (distance 3 > threshold 2) +/// levenshtein('kitten', 'sitting', 4) -- returns 3 (distance 3 <= threshold 4) +/// ``` +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLevenshtein { + signature: Signature, +} + +impl Default for SparkLevenshtein { + fn default() -> Self { + Self::new() + } +} + +impl SparkLevenshtein { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkLevenshtein { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + if let Some(coercion_data_type) = + string_coercion(&arg_types[0], &arg_types[1]) + .or_else(|| binary_to_string_coercion(&arg_types[0], &arg_types[1])) + { + match coercion_data_type { + DataType::LargeUtf8 => Ok(DataType::Int64), + DataType::Utf8 | DataType::Utf8View => Ok(DataType::Int32), + other => exec_err!( + "levenshtein requires Utf8, LargeUtf8 or Utf8View, got {other}" + ), + } + } else { + exec_err!( + "Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View" + ) + } + } Review Comment: `return_type` indexes `arg_types[0]`/`arg_types[1]` without validating the slice length. 
If this UDF is ever invoked/planned with an unexpected arity (e.g., planner bug, misuse in a lower-level API, or future signature changes), this will panic instead of returning a `Result` error. Add an explicit arity check (2 or 3) and return an `exec_err!` on mismatch. ########## datafusion/spark/src/function/string/mod.rs: ########## @@ -84,6 +86,11 @@ pub mod expr_fn { "Returns the character length of string data or number of bytes of binary data. The length of string data includes the trailing spaces. The length of binary data includes binary zeros.", arg1 )); + export_functions!(( + levenshtein, + "Returns the Levenshtein distance between two strings. Optionally accepts a threshold; returns -1 if the distance exceeds it.", + str1 str2 + )); export_functions!(( Review Comment: The expr_fn export lists only `str1 str2`, but the UDF signature supports an optional third argument `threshold`. This can lead to an expr_fn API that doesn’t expose the 3-arg form (or generates misleading docs/metadata). Consider exporting a 3-arg variant (or a variadic definition if supported by the macro) so callers can build `levenshtein(str1, str2, threshold)` expressions via `expr_fn` as well.
########## datafusion/sqllogictest/test_files/spark/string/levenshtein.slt: ########## @@ -21,12 +21,167 @@ # For more information, please see: # https://github.com/apache/datafusion/issues/15914 -## Original Query: SELECT levenshtein('kitten', 'sitting'); -## PySpark 3.5.5 Result: {'levenshtein(kitten, sitting)': 3, 'typeof(levenshtein(kitten, sitting))': 'int', 'typeof(kitten)': 'string', 'typeof(sitting)': 'string'} -#query -#SELECT levenshtein('kitten'::string, 'sitting'::string); - -## Original Query: SELECT levenshtein('kitten', 'sitting', 2); -## PySpark 3.5.5 Result: {'levenshtein(kitten, sitting, 2)': -1, 'typeof(levenshtein(kitten, sitting, 2))': 'int', 'typeof(kitten)': 'string', 'typeof(sitting)': 'string', 'typeof(2)': 'int'} -#query -#SELECT levenshtein('kitten'::string, 'sitting'::string, 2::int); +## ── Basic usage ───────────────────────────────────────────── + +## Basic distance +query I +SELECT levenshtein('kitten', 'sitting'); +---- +3 + +## Identical strings +query I +SELECT levenshtein('hello', 'hello'); +---- +0 + +## Empty string vs non-empty +query I +SELECT levenshtein('', 'abc'); +---- +3 + +## Both empty strings +query I +SELECT levenshtein('', ''); +---- +0 + +## Single character difference +query I +SELECT levenshtein('abc', 'adc'); +---- +1 + +## ── Threshold (3-argument form) ───────────────────────────── + +## Distance within threshold +query I +SELECT levenshtein('kitten', 'sitting', 4); +---- +3 + +## Distance exceeds threshold → returns -1 +query I +SELECT levenshtein('kitten', 'sitting', 2); +---- +-1 + +## Distance equals threshold (boundary) → returns distance +query I +SELECT levenshtein('kitten', 'sitting', 3); +---- +3 + +## Threshold zero with different strings +query I +SELECT levenshtein('abc', 'def', 0); +---- +-1 + +## Threshold zero with identical strings +query I +SELECT levenshtein('abc', 'abc', 0); +---- +0 + +## ── Null handling ─────────────────────────────────────────── + +## First argument null +query I +SELECT 
levenshtein(CAST(NULL AS STRING), 'hello'); +---- +NULL + +## Second argument null +query I +SELECT levenshtein('hello', CAST(NULL AS STRING)); +---- +NULL + +## Both arguments null +query I +SELECT levenshtein(CAST(NULL AS STRING), CAST(NULL AS STRING)); +---- +NULL + +## Null threshold +query I +SELECT levenshtein('kitten', 'sitting', CAST(NULL AS INT)); +---- +NULL Review Comment: The PR description states 'NULL threshold is treated as `0` (Spark behavior)', but this SLT expects `NULL` for a NULL scalar threshold. This conflicts with both the PR description and the per-row SLT cases later that expect NULL threshold → treated as 0. Please reconcile the intended Spark behavior (verify in Spark) and then make the SLT expectations consistent with the chosen semantics. ########## datafusion/spark/src/function/string/levenshtein.rs: ########## @@ -0,0 +1,576 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_int32_array, as_string_view_array}; +use datafusion_common::types::{logical_int32, logical_string, NativeType}; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::type_coercion::binary::{binary_to_string_coercion, string_coercion}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Spark-compatible `levenshtein` function. +/// +/// Differs from DataFusion core's `levenshtein` in that it supports an optional +/// third argument `threshold`. When the computed Levenshtein distance exceeds +/// the threshold, the function returns -1 instead of the actual distance. 
+/// +/// ```sql +/// levenshtein('kitten', 'sitting') -- returns 3 +/// levenshtein('kitten', 'sitting', 2) -- returns -1 (distance 3 > threshold 2) +/// levenshtein('kitten', 'sitting', 4) -- returns 3 (distance 3 <= threshold 4) +/// ``` +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLevenshtein { + signature: Signature, +} + +impl Default for SparkLevenshtein { + fn default() -> Self { + Self::new() + } +} + +impl SparkLevenshtein { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkLevenshtein { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + if let Some(coercion_data_type) = + string_coercion(&arg_types[0], &arg_types[1]) + .or_else(|| binary_to_string_coercion(&arg_types[0], &arg_types[1])) + { + match coercion_data_type { + DataType::LargeUtf8 => Ok(DataType::Int64), + DataType::Utf8 | DataType::Utf8View => Ok(DataType::Int32), + other => exec_err!( + "levenshtein requires Utf8, LargeUtf8 or Utf8View, got {other}" + ), + } Review Comment: For Spark-compatibility, it’s worth re-checking the return type for `LargeUtf8`. Spark’s `levenshtein` returns an `INT` (32-bit) regardless of input string storage, but this implementation returns `Int64` when the coerced type is `LargeUtf8`. 
That can produce type mismatches vs Spark semantics in downstream expressions. If the goal is strict Spark parity, consider returning `Int32` consistently and define behavior for distances exceeding i32 (error, clamp, or documented divergence). ########## datafusion/spark/src/function/string/levenshtein.rs: ########## @@ -0,0 +1,576 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_int32_array, as_string_view_array}; +use datafusion_common::types::{logical_int32, logical_string, NativeType}; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::type_coercion::binary::{binary_to_string_coercion, string_coercion}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Spark-compatible `levenshtein` function. 
+/// +/// Differs from DataFusion core's `levenshtein` in that it supports an optional +/// third argument `threshold`. When the computed Levenshtein distance exceeds +/// the threshold, the function returns -1 instead of the actual distance. +/// +/// ```sql +/// levenshtein('kitten', 'sitting') -- returns 3 +/// levenshtein('kitten', 'sitting', 2) -- returns -1 (distance 3 > threshold 2) +/// levenshtein('kitten', 'sitting', 4) -- returns 3 (distance 3 <= threshold 4) +/// ``` +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLevenshtein { + signature: Signature, +} + +impl Default for SparkLevenshtein { + fn default() -> Self { + Self::new() + } +} + +impl SparkLevenshtein { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkLevenshtein { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + if let Some(coercion_data_type) = + string_coercion(&arg_types[0], &arg_types[1]) + .or_else(|| binary_to_string_coercion(&arg_types[0], &arg_types[1])) + { + match coercion_data_type { + DataType::LargeUtf8 => Ok(DataType::Int64), + DataType::Utf8 | DataType::Utf8View => Ok(DataType::Int32), + other => exec_err!( + "levenshtein requires Utf8, LargeUtf8 or Utf8View, got {other}" + ), + } + } else { + exec_err!( + 
"Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View" + ) + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + let ScalarFunctionArgs { args, .. } = args; + + // Determine the coerced string type (handles mixed Utf8 + LargeUtf8) + let coerced_type = string_coercion(&args[0].data_type(), &args[1].data_type()) + .or_else(|| { + binary_to_string_coercion(&args[0].data_type(), &args[1].data_type()) + }) + .unwrap_or(DataType::Utf8); + + // Spark returns NULL when any scalar argument is NULL. + let null_int = |dt: &DataType| match dt { + DataType::LargeUtf8 => ColumnarValue::Scalar(ScalarValue::Int64(None)), + _ => ColumnarValue::Scalar(ScalarValue::Int32(None)), + }; + for arg in &args { + if matches!(arg, ColumnarValue::Scalar(s) if s.is_null()) { Review Comment: This early-return makes `levenshtein(str1, str2, NULL)` return NULL when the NULL threshold is a scalar. However, the array path treats NULL thresholds as `0` (see per-row handling in `spark_levenshtein`). That yields inconsistent semantics between scalar vs column thresholds. Please align behavior to Spark’s actual semantics (either propagate NULL consistently, or treat NULL as 0 consistently) and update the inline comment + tests accordingly. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
