This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 40a20559fb Fix coalesce expr_fn function to take multiple arguments (#10321)
40a20559fb is described below
commit 40a20559fb6971c23479ad43b2a932d9b182f31a
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon May 6 20:34:14 2024 -0400
Fix coalesce expr_fn function to take multiple arguments (#10321)
---
.../core/tests/dataframe/dataframe_functions.rs | 177 ++++++++++++++++++++-
datafusion/functions/src/core/mod.rs | 79 +++++++--
2 files changed, 244 insertions(+), 12 deletions(-)
diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs
index 7806461bb1..2ffac6a775 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -30,7 +30,7 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion::assert_batches_eq;
-use datafusion_common::DFSchema;
+use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::Alias;
use datafusion_expr::ExprSchemable;
@@ -161,6 +161,181 @@ async fn test_fn_btrim_with_chars() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn test_fn_nullif() -> Result<()> {
+ let expr = nullif(col("a"), lit("abcDEF"));
+
+ let expected = [
+ "+-------------------------------+",
+ "| nullif(test.a,Utf8(\"abcDEF\")) |",
+ "+-------------------------------+",
+ "| |",
+ "| abc123 |",
+ "| CBAdef |",
+ "| 123AbcDef |",
+ "+-------------------------------+",
+ ];
+
+ assert_fn_batches!(expr, expected);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_fn_arrow_cast() -> Result<()> {
+ let expr = arrow_typeof(arrow_cast(col("b"), lit("Float64")));
+
+ let expected = [
+ "+--------------------------------------------------+",
+ "| arrow_typeof(arrow_cast(test.b,Utf8(\"Float64\"))) |",
+ "+--------------------------------------------------+",
+ "| Float64 |",
+ "| Float64 |",
+ "| Float64 |",
+ "| Float64 |",
+ "+--------------------------------------------------+",
+ ];
+
+ assert_fn_batches!(expr, expected);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_nvl() -> Result<()> {
+ let lit_null = lit(ScalarValue::Utf8(None));
+ // nvl(CASE WHEN a = 'abcDEF' THEN NULL ELSE a END, 'TURNED_NULL')
+ let expr = nvl(
+ when(col("a").eq(lit("abcDEF")), lit_null)
+ .otherwise(col("a"))
+ .unwrap(),
+ lit("TURNED_NULL"),
+ )
+ .alias("nvl_expr");
+
+ let expected = [
+ "+-------------+",
+ "| nvl_expr |",
+ "+-------------+",
+ "| TURNED_NULL |",
+ "| abc123 |",
+ "| CBAdef |",
+ "| 123AbcDef |",
+ "+-------------+",
+ ];
+
+ assert_fn_batches!(expr, expected);
+
+ Ok(())
+}
+#[tokio::test]
+async fn test_nvl2() -> Result<()> {
+ let lit_null = lit(ScalarValue::Utf8(None));
+ // nvl2(CASE WHEN a = 'abcDEF' THEN NULL ELSE a END, 'NON_NUll', 'TURNED_NULL')
+ let expr = nvl2(
+ when(col("a").eq(lit("abcDEF")), lit_null)
+ .otherwise(col("a"))
+ .unwrap(),
+ lit("NON_NULL"),
+ lit("TURNED_NULL"),
+ )
+ .alias("nvl2_expr");
+
+ let expected = [
+ "+-------------+",
+ "| nvl2_expr |",
+ "+-------------+",
+ "| TURNED_NULL |",
+ "| NON_NULL |",
+ "| NON_NULL |",
+ "| NON_NULL |",
+ "+-------------+",
+ ];
+
+ assert_fn_batches!(expr, expected);
+
+ Ok(())
+}
+#[tokio::test]
+async fn test_fn_arrow_typeof() -> Result<()> {
+ let expr = arrow_typeof(col("l"));
+
+ let expected = [
+ "+------------------------------------------------------------------------------------------------------------------+",
+ "| arrow_typeof(test.l) |",
+ "+------------------------------------------------------------------------------------------------------------------+",
+ "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
+ "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
+ "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
+ "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
+ "+------------------------------------------------------------------------------------------------------------------+",
+ ];
+
+ assert_fn_batches!(expr, expected);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_fn_struct() -> Result<()> {
+ let expr = r#struct(vec![col("a"), col("b")]);
+
+ let expected = [
+ "+--------------------------+",
+ "| struct(test.a,test.b) |",
+ "+--------------------------+",
+ "| {c0: abcDEF, c1: 1} |",
+ "| {c0: abc123, c1: 10} |",
+ "| {c0: CBAdef, c1: 10} |",
+ "| {c0: 123AbcDef, c1: 100} |",
+ "+--------------------------+",
+ ];
+
+ assert_fn_batches!(expr, expected);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_fn_named_struct() -> Result<()> {
+ let expr = named_struct(vec![lit("column_a"), col("a"), lit("column_b"), col("b")]);
+
+ let expected = [
+ "+---------------------------------------------------------------+",
+ "| named_struct(Utf8(\"column_a\"),test.a,Utf8(\"column_b\"),test.b) |",
+ "+---------------------------------------------------------------+",
+ "| {column_a: abcDEF, column_b: 1} |",
+ "| {column_a: abc123, column_b: 10} |",
+ "| {column_a: CBAdef, column_b: 10} |",
+ "| {column_a: 123AbcDef, column_b: 100} |",
+ "+---------------------------------------------------------------+",
+ ];
+
+ assert_fn_batches!(expr, expected);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_fn_coalesce() -> Result<()> {
+ let expr = coalesce(vec![lit(ScalarValue::Utf8(None)), lit("ab")]);
+
+ let expected = [
+ "+---------------------------------+",
+ "| coalesce(Utf8(NULL),Utf8(\"ab\")) |",
+ "+---------------------------------+",
+ "| ab |",
+ "| ab |",
+ "| ab |",
+ "| ab |",
+ "+---------------------------------+",
+ ];
+
+ assert_fn_batches!(expr, expected);
+
+ Ok(())
+}
+
#[tokio::test]
async fn test_fn_approx_median() -> Result<()> {
let expr = approx_median(col("b"));
diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs
index 753134bdfd..d60e6017dd 100644
--- a/datafusion/functions/src/core/mod.rs
+++ b/datafusion/functions/src/core/mod.rs
@@ -17,6 +17,9 @@
//! "core" DataFusion functions
+use datafusion_expr::ScalarUDF;
+use std::sync::Arc;
+
pub mod arrow_cast;
pub mod arrowtypeof;
pub mod coalesce;
@@ -39,14 +42,68 @@ make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field);
make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce);
// Export the functions out of this package, both as expr_fn as well as a list of functions
-export_functions!(
- (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."),
- (arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. This can be used to cast to a specific `arrow_type`."),
- (nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"),
- (nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."),
- (arrow_typeof, arg_1, "Returns the Arrow type of the input expression."),
- (r#struct, args, "Returns a struct with the given arguments"),
- (named_struct, args, "Returns a struct with the given names and arguments pairs"),
- (get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct"),
- (coalesce, args, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL")
-);
+pub mod expr_fn {
+ use datafusion_expr::Expr;
+
+ /// returns NULL if value1 equals value2; otherwise it returns value1. This
+ /// can be used to perform the inverse operation of the COALESCE expression
+ pub fn nullif(arg1: Expr, arg2: Expr) -> Expr {
+ super::nullif().call(vec![arg1, arg2])
+ }
+
+ /// returns value1 cast to the `arrow_type` given the second argument. This
+ /// can be used to cast to a specific `arrow_type`.
+ pub fn arrow_cast(arg1: Expr, arg2: Expr) -> Expr {
+ super::arrow_cast().call(vec![arg1, arg2])
+ }
+
+ /// Returns value2 if value1 is NULL; otherwise it returns value1
+ pub fn nvl(arg1: Expr, arg2: Expr) -> Expr {
+ super::nvl().call(vec![arg1, arg2])
+ }
+
+ /// Returns value2 if value1 is not NULL; otherwise, it returns value3.
+ pub fn nvl2(arg1: Expr, arg2: Expr, arg3: Expr) -> Expr {
+ super::nvl2().call(vec![arg1, arg2, arg3])
+ }
+
+ /// Returns the Arrow type of the input expression.
+ pub fn arrow_typeof(arg1: Expr) -> Expr {
+ super::arrow_typeof().call(vec![arg1])
+ }
+
+ /// Returns a struct with the given arguments
+ pub fn r#struct(args: Vec<Expr>) -> Expr {
+ super::r#struct().call(args)
+ }
+
+ /// Returns a struct with the given names and arguments pairs
+ pub fn named_struct(args: Vec<Expr>) -> Expr {
+ super::named_struct().call(args)
+ }
+
+ /// Returns the value of the field with the given name from the struct
+ pub fn get_field(arg1: Expr, arg2: Expr) -> Expr {
+ super::get_field().call(vec![arg1, arg2])
+ }
+
+ /// Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL
+ pub fn coalesce(args: Vec<Expr>) -> Expr {
+ super::coalesce().call(args)
+ }
+}
+
+/// Return a list of all functions in this package
+pub fn functions() -> Vec<Arc<ScalarUDF>> {
+ vec![
+ nullif(),
+ arrow_cast(),
+ nvl(),
+ nvl2(),
+ arrow_typeof(),
+ r#struct(),
+ named_struct(),
+ get_field(),
+ coalesce(),
+ ]
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]