This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 99daafd7b7 feat(spark): implement Spark conditional function if (#16946)
99daafd7b7 is described below

commit 99daafd7b738831f5d6f95007536e0712a90ba5c
Author: Chen Chongchen <chenkov...@qq.com>
AuthorDate: Sat Aug 30 20:44:20 2025 +0800

    feat(spark): implement Spark conditional function if (#16946)
---
 datafusion/spark/src/function/conditional/if.rs    | 101 ++++++++++++++
 datafusion/spark/src/function/conditional/mod.rs   |  13 +-
 .../test_files/spark/conditional/if.slt            | 147 ++++++++++++++++++++-
 3 files changed, 255 insertions(+), 6 deletions(-)

diff --git a/datafusion/spark/src/function/conditional/if.rs 
b/datafusion/spark/src/function/conditional/if.rs
new file mode 100644
index 0000000000..aee43dd8d0
--- /dev/null
+++ b/datafusion/spark/src/function/conditional/if.rs
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::DataType;
+use datafusion_common::{internal_err, plan_err, Result};
+use datafusion_expr::{
+    binary::try_type_union_resolution, simplify::ExprSimplifyResult, when, 
ColumnarValue,
+    Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+
+/// Scalar UDF providing the Spark-compatible `if(condition, then, else)`
+/// conditional function.
+///
+/// Calls are not evaluated directly: `simplify` rewrites them into an
+/// equivalent `CASE WHEN ... THEN ... ELSE ... END` expression.
+#[derive(Debug, Eq, PartialEq, Hash)]
+pub struct SparkIf {
+    /// Signature reported to the planner; built as a user-defined signature
+    /// in `new`, with argument coercion handled by `coerce_types`.
+    signature: Signature,
+}
+
+impl Default for SparkIf {
+    /// Builds the same value as [`SparkIf::new`].
+    fn default() -> Self {
+        SparkIf::new()
+    }
+}
+
+impl SparkIf {
+    /// Creates the `if` UDF with a user-defined signature and immutable
+    /// volatility.
+    pub fn new() -> Self {
+        let signature = Signature::user_defined(Volatility::Immutable);
+        Self { signature }
+    }
+}
+
+impl ScalarUDFImpl for SparkIf {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    /// Name of the function: `if`.
+    fn name(&self) -> &str {
+        "if"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Validates the argument types and returns the types to coerce them to.
+    ///
+    /// # Errors
+    /// Returns a plan error when the number of arguments is not exactly 3 or
+    /// when the first argument is neither `Boolean` nor `Null`. Any error from
+    /// `try_type_union_resolution` on the two branch types is propagated.
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        if arg_types.len() != 3 {
+            return plan_err!(
+                "Function 'if' expects 3 arguments but received {}",
+                arg_types.len()
+            );
+        }
+
+        if arg_types[0] != DataType::Boolean && arg_types[0] != DataType::Null {
+            return plan_err!(
+                "For function 'if' {} is not a boolean or null",
+                arg_types[0]
+            );
+        }
+
+        // The condition is always reported as Boolean (this also covers a
+        // Null condition); the two branch types come from the resolver.
+        let target_types = try_type_union_resolution(&arg_types[1..])?;
+        let mut result = vec![DataType::Boolean];
+        result.extend(target_types);
+        Ok(result)
+    }
+
+    /// Returns the type of the second argument (the `then` branch).
+    ///
+    /// Presumably called with types that already went through
+    /// `coerce_types`; a too-short slice yields a plan error instead of an
+    /// out-of-bounds panic.
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        match arg_types.get(1) {
+            Some(then_type) => Ok(then_type.clone()),
+            None => plan_err!(
+                "Function 'if' expects 3 arguments but received {}",
+                arg_types.len()
+            ),
+        }
+    }
+
+    /// Always fails: `simplify` is expected to have replaced the call with a
+    /// `CASE` expression before execution.
+    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        internal_err!("if should have been simplified to case")
+    }
+
+    /// Rewrites `IF(condition, then_expr, else_expr)` into
+    /// `CASE WHEN condition THEN then_expr ELSE else_expr END`.
+    ///
+    /// # Errors
+    /// Returns an internal error when `args` does not hold exactly three
+    /// expressions, and propagates any error from building the `CASE`
+    /// expression.
+    fn simplify(
+        &self,
+        args: Vec<Expr>,
+        _info: &dyn datafusion_expr::simplify::SimplifyInfo,
+    ) -> Result<ExprSimplifyResult> {
+        // `args` is owned, so the expressions are moved out rather than
+        // cloned; an unexpected arity becomes an error instead of a panic.
+        let [condition, then_expr, else_expr] = match <[Expr; 3]>::try_from(args) {
+            Ok(exprs) => exprs,
+            Err(args) => {
+                return internal_err!(
+                    "Function 'if' expects 3 arguments but received {}",
+                    args.len()
+                );
+            }
+        };
+
+        let case_expr = when(condition, then_expr).otherwise(else_expr)?;
+
+        Ok(ExprSimplifyResult::Simplified(case_expr))
+    }
+}
diff --git a/datafusion/spark/src/function/conditional/mod.rs 
b/datafusion/spark/src/function/conditional/mod.rs
index a87df9a2c8..4301d7642b 100644
--- a/datafusion/spark/src/function/conditional/mod.rs
+++ b/datafusion/spark/src/function/conditional/mod.rs
@@ -16,10 +16,19 @@
 // under the License.
 
 use datafusion_expr::ScalarUDF;
+use datafusion_functions::make_udf_function;
 use std::sync::Arc;
 
-pub mod expr_fn {}
+mod r#if;
+
+make_udf_function!(r#if::SparkIf, r#if);
+
+pub mod expr_fn {
+    use datafusion_functions::export_functions;
+
+    export_functions!((r#if, "If arg1 evaluates to true, then returns arg2; 
otherwise returns arg3", arg1 arg2 arg3));
+}
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![]
+    vec![r#if()]
 }
diff --git a/datafusion/sqllogictest/test_files/spark/conditional/if.slt 
b/datafusion/sqllogictest/test_files/spark/conditional/if.slt
index 7baedad745..b4380e065b 100644
--- a/datafusion/sqllogictest/test_files/spark/conditional/if.slt
+++ b/datafusion/sqllogictest/test_files/spark/conditional/if.slt
@@ -21,7 +21,146 @@
 # For more information, please see:
 #   https://github.com/apache/datafusion/issues/15914
 
-## Original Query: SELECT if(1 < 2, 'a', 'b');
-## PySpark 3.5.5 Result: {'(IF((1 < 2), a, b))': 'a', 'typeof((IF((1 < 2), a, 
b)))': 'string', 'typeof((1 < 2))': 'boolean', 'typeof(a)': 'string', 
'typeof(b)': 'string'}
-#query
-#SELECT if((1 < 2)::boolean, 'a'::string, 'b'::string);
+## Basic IF function tests
+
+# Test basic true condition
+query T
+SELECT if(true, 'yes', 'no');
+----
+yes
+
+# Test basic false condition
+query T
+SELECT if(false, 'yes', 'no');
+----
+no
+
+# Test with comparison operators
+query T
+SELECT if(1 < 2, 'a', 'b');
+----
+a
+
+query T
+SELECT if(1 > 2, 'a', 'b');
+----
+b
+
+
+## Numeric type tests
+
+# Test with integers
+query I
+SELECT if(true, 10, 20);
+----
+10
+
+query I
+SELECT if(false, 10, 20);
+----
+20
+
+# Test with different integer types
+query I
+SELECT if(true, 100, 200);
+----
+100
+
+## Float type tests
+
+# Test with floating point numbers
+query R
+SELECT if(true, 1.5, 2.5);
+----
+1.5
+
+query R
+SELECT if(false, 1.5, 2.5);
+----
+2.5
+
+## String type tests
+
+# Test with different string values
+query T
+SELECT if(true, 'hello', 'world');
+----
+hello
+
+query T
+SELECT if(false, 'hello', 'world');
+----
+world
+
+## NULL handling tests
+
+# Test with NULL condition
+query T
+SELECT if(NULL, 'yes', 'no');
+----
+no
+
+query T
+SELECT if(NOT NULL, 'yes', 'no');
+----
+no
+
+# Test with NULL true value
+query T
+SELECT if(true, NULL, 'no');
+----
+NULL
+
+# Test with NULL false value
+query T
+SELECT if(false, 'yes', NULL);
+----
+NULL
+
+# Test with all NULL
+query ?
+SELECT if(true, NULL, NULL);
+----
+NULL
+
+## Type coercion tests
+
+# Test integer to float coercion
+query R
+SELECT if(true, 10, 20.5);
+----
+10
+
+query R
+SELECT if(false, 10, 20.5);
+----
+20.5
+
+# Test float to integer coercion
+query R
+SELECT if(true, 10.5, 20);
+----
+10.5
+
+query R
+SELECT if(false, 10.5, 20);
+----
+20
+
+statement error Int64 is not a boolean or null
+SELECT if(1, 10.5, 20);
+
+
+statement error Utf8 is not a boolean or null
+SELECT if('x', 10.5, 20);
+
+query II
+SELECT v, IF(v < 0, 10/0, 1) FROM (VALUES (1), (2)) t(v)
+----
+1 1
+2 1
+
+query I
+SELECT IF(true, 1 / 1, 1 / 0);
+----
+1


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to