This is an automated email from the ASF dual-hosted git repository. agrove pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push: new 99daafd7b7 feat(spark): implement Spark conditional function if (#16946) 99daafd7b7 is described below commit 99daafd7b738831f5d6f95007536e0712a90ba5c Author: Chen Chongchen <chenkov...@qq.com> AuthorDate: Sat Aug 30 20:44:20 2025 +0800 feat(spark): implement Spark conditional function if (#16946) --- datafusion/spark/src/function/conditional/if.rs | 101 ++++++++++++++ datafusion/spark/src/function/conditional/mod.rs | 13 +- .../test_files/spark/conditional/if.slt | 147 ++++++++++++++++++++- 3 files changed, 255 insertions(+), 6 deletions(-) diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs new file mode 100644 index 0000000000..aee43dd8d0 --- /dev/null +++ b/datafusion/spark/src/function/conditional/if.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_expr::{ + binary::try_type_union_resolution, simplify::ExprSimplifyResult, when, ColumnarValue, + Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkIf { + signature: Signature, +} + +impl Default for SparkIf { + fn default() -> Self { + Self::new() + } +} + +impl SparkIf { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkIf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "if" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { + if arg_types.len() != 3 { + return plan_err!( + "Function 'if' expects 3 arguments but received {}", + arg_types.len() + ); + } + + if arg_types[0] != DataType::Boolean && arg_types[0] != DataType::Null { + return plan_err!( + "For function 'if' {} is not a boolean or null", + arg_types[0] + ); + } + + let target_types = try_type_union_resolution(&arg_types[1..])?; + let mut result = vec![DataType::Boolean]; + result.extend(target_types); + Ok(result) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + Ok(arg_types[1].clone()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { + internal_err!("if should have been simplified to case") + } + + fn simplify( + &self, + args: Vec<Expr>, + _info: &dyn datafusion_expr::simplify::SimplifyInfo, + ) -> Result<ExprSimplifyResult> { + let condition = args[0].clone(); + let then_expr = args[1].clone(); + let else_expr = args[2].clone(); + + // Convert IF(condition, then_expr, else_expr) to + // CASE WHEN condition THEN then_expr ELSE else_expr END + let case_expr = when(condition, then_expr).otherwise(else_expr)?; + + Ok(ExprSimplifyResult::Simplified(case_expr)) + } +} diff --git a/datafusion/spark/src/function/conditional/mod.rs b/datafusion/spark/src/function/conditional/mod.rs index a87df9a2c8..4301d7642b 100644 --- a/datafusion/spark/src/function/conditional/mod.rs +++ b/datafusion/spark/src/function/conditional/mod.rs @@ -16,10 +16,19 @@ // under the License. use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; use std::sync::Arc; -pub mod expr_fn {} +mod r#if; + +make_udf_function!(r#if::SparkIf, r#if); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((r#if, "If arg1 evaluates to true, then returns arg2; otherwise returns arg3", arg1 arg2 arg3)); +} pub fn functions() -> Vec<Arc<ScalarUDF>> { - vec![] + vec![r#if()] } diff --git a/datafusion/sqllogictest/test_files/spark/conditional/if.slt b/datafusion/sqllogictest/test_files/spark/conditional/if.slt index 7baedad745..b4380e065b 100644 --- a/datafusion/sqllogictest/test_files/spark/conditional/if.slt +++ b/datafusion/sqllogictest/test_files/spark/conditional/if.slt @@ -21,7 +21,146 @@ # For more information, please see: # https://github.com/apache/datafusion/issues/15914 -## Original Query: SELECT if(1 < 2, 'a', 'b'); -## PySpark 3.5.5 Result: {'(IF((1 < 2), a, b))': 'a', 'typeof((IF((1 < 2), a, b)))': 'string', 'typeof((1 < 2))': 'boolean', 'typeof(a)': 'string', 'typeof(b)': 'string'} -#query -#SELECT if((1 < 2)::boolean, 'a'::string, 'b'::string); +## Basic IF function tests + +# Test basic true condition +query T +SELECT if(true, 'yes', 'no'); +---- +yes + +# Test basic false condition +query T +SELECT if(false, 'yes', 'no'); +---- +no + +# Test with comparison operators +query T +SELECT if(1 < 2, 'a', 'b'); +---- +a + +query T +SELECT if(1 > 2, 'a', 'b'); +---- +b + + +## Numeric type tests + +# Test with integers +query I +SELECT if(true, 10, 20); +---- +10 + +query I +SELECT if(false, 10, 20); +---- +20 + +# Test with different integer types +query I +SELECT if(true, 100, 200); +---- +100 + +## Float type tests + +# Test with floating point numbers +query R +SELECT if(true, 1.5, 2.5); +---- +1.5 + +query R +SELECT if(false, 1.5, 2.5); +---- +2.5 + +## String type tests + +# Test with different string values +query T +SELECT if(true, 'hello', 'world'); +---- +hello + +query T +SELECT if(false, 'hello', 'world'); +---- +world + +## NULL handling tests + +# Test with NULL condition +query T +SELECT if(NULL, 'yes', 'no'); +---- +no + +query T +SELECT if(NOT NULL, 'yes', 'no'); +---- +no + +# Test with NULL true value +query T +SELECT if(true, NULL, 'no'); +---- +NULL + +# Test with NULL false value +query T +SELECT if(false, 'yes', NULL); +---- +NULL + +# Test with all NULL +query ? +SELECT if(true, NULL, NULL); +---- +NULL + +## Type coercion tests + +# Test integer to float coercion +query R +SELECT if(true, 10, 20.5); +---- +10 + +query R +SELECT if(false, 10, 20.5); +---- +20.5 + +# Test float to integer coercion +query R +SELECT if(true, 10.5, 20); +---- +10.5 + +query R +SELECT if(false, 10.5, 20); +---- +20 + +statement error Int64 is not a boolean or null +SELECT if(1, 10.5, 20); + + +statement error Utf8 is not a boolean or null +SELECT if('x', 10.5, 20); + +query II +SELECT v, IF(v < 0, 10/0, 1) FROM (VALUES (1), (2)) t(v) +---- +1 1 +2 1 + +query I +SELECT IF(true, 1 / 1, 1 / 0); +---- +1 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org