This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 7873e5c249 Add union_extract scalar function (#12116)
7873e5c249 is described below

commit 7873e5c24922d27981b06e0dbbb95e78a5e659a1
Author: gstvg <28798827+gs...@users.noreply.github.com>
AuthorDate: Fri Feb 14 11:36:49 2025 -0300

    Add union_extract scalar function (#12116)
    
    * feat: add union_extract scalar function
    
    * fix:  docs fmt, add clippy atr, sql error msg
    
    * use arrow-rs implementation
    
    * docs: add union functions section
    
    * docs: simplify union_extract docs
    
    * test: simplify union_extract sqllogictests
    
    * refactor(union_extract): new udf api, docs macro, use any signature
    
    * fix: remove user_doc include attribute
    
    * fix: generate docs
    
    * fix: manually trim sqllogictest generated errors
    
    * fix: fmt
    
    * docs: add union functions section description
    
    * docs: update functions docs
    
    * docs: clarify union_extract description
    
    Co-authored-by: Bruce Ritchie <bruce.ritc...@veeva.com>
    
    * fix: use return_type_from_args, tests, docs
    
    ---------
    
    Co-authored-by: Bruce Ritchie <bruce.ritc...@veeva.com>
---
 datafusion/expr/src/udf.rs                         |   8 +
 datafusion/functions/src/core/mod.rs               |   8 +
 datafusion/functions/src/core/union_extract.rs     | 255 +++++++++++++++++++++
 datafusion/sqllogictest/src/test_context.rs        |  32 ++-
 .../sqllogictest/test_files/union_function.slt     |  47 ++++
 docs/source/user-guide/sql/scalar_functions.md     |  34 +++
 6 files changed, 381 insertions(+), 3 deletions(-)

diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index b41d975203..54a692f2f3 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -980,6 +980,7 @@ pub mod scalar_doc_sections {
             DOC_SECTION_STRUCT,
             DOC_SECTION_MAP,
             DOC_SECTION_HASHING,
+            DOC_SECTION_UNION,
             DOC_SECTION_OTHER,
         ]
     }
@@ -996,6 +997,7 @@ pub mod scalar_doc_sections {
             DOC_SECTION_STRUCT,
             DOC_SECTION_MAP,
             DOC_SECTION_HASHING,
+            DOC_SECTION_UNION,
             DOC_SECTION_OTHER,
         ]
     }
@@ -1070,4 +1072,10 @@ The following regular expression functions are 
supported:"#,
         label: "Other Functions",
         description: None,
     };
+
+    /// Documentation section grouping the scalar functions that operate on the
+    /// Arrow union data type (e.g. `union_extract`).
+    pub const DOC_SECTION_UNION: DocSection = DocSection {
+        include: true,
+        label: "Union Functions",
+        description: Some("Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator"),
+    };
 }
diff --git a/datafusion/functions/src/core/mod.rs 
b/datafusion/functions/src/core/mod.rs
index 76fb4bbe5b..425ce78dec 100644
--- a/datafusion/functions/src/core/mod.rs
+++ b/datafusion/functions/src/core/mod.rs
@@ -34,6 +34,7 @@ pub mod nvl;
 pub mod nvl2;
 pub mod planner;
 pub mod r#struct;
+pub mod union_extract;
 pub mod version;
 
 // create UDFs
@@ -48,6 +49,7 @@ make_udf_function!(getfield::GetFieldFunc, get_field);
 make_udf_function!(coalesce::CoalesceFunc, coalesce);
 make_udf_function!(greatest::GreatestFunc, greatest);
 make_udf_function!(least::LeastFunc, least);
+make_udf_function!(union_extract::UnionExtractFun, union_extract);
 make_udf_function!(version::VersionFunc, version);
 
 pub mod expr_fn {
@@ -99,6 +101,11 @@ pub mod expr_fn {
     pub fn get_field(arg1: Expr, arg2: impl Literal) -> Expr {
         super::get_field().call(vec![arg1, arg2.lit()])
     }
+
+    #[doc = "Returns the value of the union field named by `arg2` when that field is the one selected, or NULL otherwise"]
+    pub fn union_extract(arg1: Expr, arg2: impl Literal) -> Expr {
+        let field_name = arg2.lit();
+        let call_args = vec![arg1, field_name];
+        super::union_extract().call(call_args)
+    }
 }
 
 /// Returns all DataFusion functions defined in this package
@@ -121,6 +128,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         coalesce(),
         greatest(),
         least(),
+        union_extract(),
         version(),
         r#struct(),
     ]
diff --git a/datafusion/functions/src/core/union_extract.rs 
b/datafusion/functions/src/core/union_extract.rs
new file mode 100644
index 0000000000..d54627f735
--- /dev/null
+++ b/datafusion/functions/src/core/union_extract.rs
@@ -0,0 +1,255 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::Array;
+use arrow::datatypes::{DataType, FieldRef, UnionFields};
+use datafusion_common::cast::as_union_array;
+use datafusion_common::{
+    exec_datafusion_err, exec_err, internal_err, Result, ScalarValue,
+};
+use datafusion_doc::Documentation;
+use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, 
ScalarFunctionArgs};
+use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+use datafusion_macros::user_doc;
+
+#[user_doc(
+    doc_section(label = "Union Functions"),
+    description = "Returns the value of the given field in the union when 
selected, or NULL otherwise.",
+    syntax_example = "union_extract(union, field_name)",
+    sql_example = r#"```sql
+❯ select union_column, union_extract(union_column, 'a'), 
union_extract(union_column, 'b') from table_with_union;
++--------------+----------------------------------+----------------------------------+
+| union_column | union_extract(union_column, 'a') | 
union_extract(union_column, 'b') |
++--------------+----------------------------------+----------------------------------+
+| {a=1}        | 1                                |                            
      |
+| {b=3.0}      |                                  | 3.0                        
      |
+| {a=4}        | 4                                |                            
      |
+| {b=}         |                                  |                            
      |
+| {a=}         |                                  |                            
      |
++--------------+----------------------------------+----------------------------------+
+```"#,
+    standard_argument(name = "union", prefix = "Union"),
+    argument(
+        name = "field_name",
+        description = "String expression to operate on. Must be a constant."
+    )
+)]
+/// Scalar UDF implementing the SQL function `union_extract(union, field_name)`.
+///
+/// Returns the child value of the union field named `field_name` when that field
+/// is selected, and NULL otherwise.
+#[derive(Debug)]
+pub struct UnionExtractFun {
+    // Built with `Signature::any(2, ...)` in `new`: any two argument types are
+    // accepted here, and the union / string-literal checks happen inside the
+    // `ScalarUDFImpl` methods.
+    signature: Signature,
+}
+
+impl Default for UnionExtractFun {
+    /// Same as [`UnionExtractFun::new`].
+    fn default() -> Self {
+        UnionExtractFun::new()
+    }
+}
+
+impl UnionExtractFun {
+    /// Creates the UDF with an immutable signature taking exactly two arguments
+    /// of any type.
+    pub fn new() -> Self {
+        let signature = Signature::any(2, Volatility::Immutable);
+        Self { signature }
+    }
+}
+
+impl ScalarUDFImpl for UnionExtractFun {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "union_extract"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
+        // The return type depends on the *value* of the `field_name` literal, which
+        // is only available through `return_type_from_args`; this type-only entry
+        // point must not be reached.
+        internal_err!("union_extract should return type from args")
+    }
+
+    /// Resolves the return type to the data type of the union child named by the
+    /// second argument, which must be a non-null string literal. The result is
+    /// always nullable because rows selecting another field produce NULL.
+    ///
+    /// # Errors
+    ///
+    /// Returns an execution error when the argument count is not 2, the first
+    /// argument is not a union, the second is not a non-null string literal, or
+    /// the union has no field with that name.
+    fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
+        if args.arg_types.len() != 2 {
+            return exec_err!(
+                "union_extract expects 2 arguments, got {} instead",
+                args.arg_types.len()
+            );
+        }
+
+        let DataType::Union(fields, _) = &args.arg_types[0] else {
+            return exec_err!(
+                "union_extract first argument must be a union, got {} instead",
+                args.arg_types[0]
+            );
+        };
+
+        // `get` instead of indexing so a short `scalar_arguments` slice yields an
+        // error rather than a panic.
+        let literal = args.scalar_arguments.get(1).copied().flatten();
+
+        let field_name = match literal.and_then(field_name_of) {
+            Some(Some(field_name)) => field_name,
+            Some(None) => {
+                return exec_err!(
+                    "union_extract second argument must be a non-null string literal, got a null instead"
+                );
+            }
+            None => {
+                return exec_err!(
+                    "union_extract second argument must be a non-null string literal, got {} instead",
+                    args.arg_types[1]
+                );
+            }
+        };
+
+        let (_, field) = find_field(fields, field_name)?;
+
+        Ok(ReturnInfo::new_nullable(field.data_type().clone()))
+    }
+
+    /// Extracts the requested field from a union array or a union scalar.
+    ///
+    /// Arrays are delegated to arrow's `union_extract` kernel. For scalars, the
+    /// inner value is returned when its type id matches the named field;
+    /// otherwise the result is a NULL of that field's data type.
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args = args.args;
+
+        if args.len() != 2 {
+            return exec_err!(
+                "union_extract expects 2 arguments, got {} instead",
+                args.len()
+            );
+        }
+
+        // Kept as a `Result` and only unwrapped inside the union branches below,
+        // so an invalid first argument is reported before an invalid field name.
+        let target_name = match &args[1] {
+            ColumnarValue::Scalar(scalar) => field_name_of(scalar),
+            ColumnarValue::Array(_) => None,
+        };
+        let target_name = match target_name {
+            Some(Some(target_name)) => Ok(target_name),
+            Some(None) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"),
+            None => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()),
+        };
+
+        match &args[0] {
+            ColumnarValue::Array(array) => {
+                let union_array = as_union_array(&array).map_err(|_| {
+                    exec_datafusion_err!(
+                        "union_extract first argument must be a union, got {} instead",
+                        array.data_type()
+                    )
+                })?;
+
+                Ok(ColumnarValue::Array(
+                    arrow::compute::kernels::union_extract::union_extract(
+                        union_array,
+                        target_name?,
+                    )?,
+                ))
+            }
+            ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => {
+                let target_name = target_name?;
+                let (target_type_id, target) = find_field(fields, target_name)?;
+
+                let result = match value {
+                    Some((type_id, value)) if target_type_id == *type_id => {
+                        *value.clone()
+                    }
+                    // Empty union scalar or another field selected: typed NULL.
+                    _ => ScalarValue::try_from(target.data_type())?,
+                };
+
+                Ok(ColumnarValue::Scalar(result))
+            }
+            other => exec_err!(
+                "union_extract first argument must be a union, got {} instead",
+                other.data_type()
+            ),
+        }
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+}
+
+/// Returns the payload of a string scalar (`Utf8`, `LargeUtf8` or `Utf8View`):
+/// `Some(Some(name))` for a non-null value, `Some(None)` for a string-typed NULL,
+/// and `None` for any non-string scalar.
+fn field_name_of(value: &ScalarValue) -> Option<Option<&str>> {
+    match value {
+        ScalarValue::Utf8(name)
+        | ScalarValue::LargeUtf8(name)
+        | ScalarValue::Utf8View(name) => Some(name.as_deref()),
+        _ => None,
+    }
+}
+
+/// Looks up the union field called `name`, returning its type id together with
+/// the field itself, or an execution error when no field has that name.
+fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> {
+    match fields.iter().find(|(_, field)| field.name() == name) {
+        Some(found) => Ok(found),
+        None => exec_err!("field {name} not found on union"),
+    }
+}
+
+// Unit tests for the scalar (non-array) path of `union_extract`; the array path
+// is exercised by the `union_function.slt` sqllogictest.
+#[cfg(test)]
+mod tests {
+
+    use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
+
+    use super::UnionExtractFun;
+
+    // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests
+    /// Extracts field `str` (type id 1) from three union scalars: an empty union
+    /// and one holding `int` (type id 3) must both give a Utf8 NULL, while one
+    /// holding `str` must give its inner value.
+    #[test]
+    fn test_scalar_value() -> Result<()> {
+        let fun = UnionExtractFun::new();
+
+        let fields = UnionFields::new(
+            vec![1, 3],
+            vec![
+                Field::new("str", DataType::Utf8, false),
+                Field::new("int", DataType::Int32, false),
+            ],
+        );
+
+        // Union scalar with no selected value -> NULL of the target type.
+        let result = fun.invoke_with_args(ScalarFunctionArgs {
+            args: vec![
+                ColumnarValue::Scalar(ScalarValue::Union(
+                    None,
+                    fields.clone(),
+                    UnionMode::Dense,
+                )),
+                ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
+            ],
+            number_rows: 1,
+            return_type: &DataType::Utf8,
+        })?;
+
+        assert_scalar(result, ScalarValue::Utf8(None));
+
+        // A different field (`int`) is selected -> NULL of the target type.
+        let result = fun.invoke_with_args(ScalarFunctionArgs {
+            args: vec![
+                ColumnarValue::Scalar(ScalarValue::Union(
+                    Some((3, Box::new(ScalarValue::Int32(Some(42))))),
+                    fields.clone(),
+                    UnionMode::Dense,
+                )),
+                ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
+            ],
+            number_rows: 1,
+            return_type: &DataType::Utf8,
+        })?;
+
+        assert_scalar(result, ScalarValue::Utf8(None));
+
+        // The requested field is selected -> its inner value is returned.
+        let result = fun.invoke_with_args(ScalarFunctionArgs {
+            args: vec![
+                ColumnarValue::Scalar(ScalarValue::Union(
+                    Some((1, Box::new(ScalarValue::new_utf8("42")))),
+                    fields.clone(),
+                    UnionMode::Dense,
+                )),
+                ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
+            ],
+            number_rows: 1,
+            return_type: &DataType::Utf8,
+        })?;
+
+        assert_scalar(result, ScalarValue::new_utf8("42"));
+
+        Ok(())
+    }
+
+    /// Asserts that `value` is a scalar equal to `expected`; panics if it is an array.
+    fn assert_scalar(value: ColumnarValue, expected: ScalarValue) {
+        match value {
+            ColumnarValue::Array(array) => panic!("expected scalar got 
{array:?}"),
+            ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected),
+        }
+    }
+}
diff --git a/datafusion/sqllogictest/src/test_context.rs 
b/datafusion/sqllogictest/src/test_context.rs
index 31b1e1e8a1..ce819f1864 100644
--- a/datafusion/sqllogictest/src/test_context.rs
+++ b/datafusion/sqllogictest/src/test_context.rs
@@ -22,10 +22,11 @@ use std::path::Path;
 use std::sync::Arc;
 
 use arrow::array::{
-    ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, 
LargeStringArray,
-    StringArray, TimestampNanosecondArray,
+    Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray,
+    LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray,
 };
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
+use arrow::buffer::ScalarBuffer;
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, 
UnionFields};
 use arrow::record_batch::RecordBatch;
 use datafusion::catalog::{
     CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session,
@@ -113,6 +114,10 @@ impl TestContext {
                 info!("Registering metadata table tables");
                 register_metadata_tables(test_ctx.session_ctx()).await;
             }
+            "union_function.slt" => {
+                info!("Registering table with union column");
+                register_union_table(test_ctx.session_ctx())
+            }
             _ => {
                 info!("Using default SessionContext");
             }
@@ -402,3 +407,24 @@ fn create_example_udf() -> ScalarUDF {
         adder,
     )
 }
+
+/// Registers `union_table` for `union_function.slt`: a two-row batch with a single
+/// non-nullable `union_column`. The column is a union with one field `int`
+/// (Int32, type id 3), built with no offsets buffer, holding the values 1 and 2.
+fn register_union_table(ctx: &SessionContext) {
+    let union = UnionArray::try_new(
+        UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]),
+        ScalarBuffer::from(vec![3, 3]),
+        None,
+        vec![Arc::new(Int32Array::from(vec![1, 2]))],
+    )
+    .expect("type ids and child array are consistent with the union fields");
+
+    let schema = Schema::new(vec![Field::new(
+        "union_column",
+        union.data_type().clone(),
+        false,
+    )]);
+
+    // The schema is not used afterwards, so it is moved into the Arc rather than cloned.
+    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)])
+        .expect("union column matches the schema");
+
+    ctx.register_batch("union_table", batch)
+        .expect("registering union_table in a fresh test context");
+}
diff --git a/datafusion/sqllogictest/test_files/union_function.slt 
b/datafusion/sqllogictest/test_files/union_function.slt
new file mode 100644
index 0000000000..9c70b1011f
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/union_function.slt
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+##########
+## UNION DataType Tests
+##########
+
+query ?I
+select union_column, union_extract(union_column, 'int') from union_table;
+----
+{int=1} 1
+{int=2} 2
+
+query error DataFusion error: Execution error: field bool not found on union
+select union_extract(union_column, 'bool') from union_table;
+
+query error DataFusion error: Error during planning: 'union_extract' does not 
support zero arguments
+select union_extract() from union_table;
+
+query error DataFusion error: Error during planning: The function 
'union_extract' expected 2 arguments but received 1
+select union_extract(union_column) from union_table;
+
+query error DataFusion error: Error during planning: The function 
'union_extract' expected 2 arguments but received 1
+select union_extract('a') from union_table;
+
+query error DataFusion error: Execution error: union_extract first argument 
must be a union, got Utf8 instead
+select union_extract('a', union_column) from union_table;
+
+query error DataFusion error: Execution error: union_extract second argument 
must be a non\-null string literal, got Int64 instead
+select union_extract(union_column, 1) from union_table;
+
+query error DataFusion error: Error during planning: The function 
'union_extract' expected 2 arguments but received 3
+select union_extract(union_column, 'a', 'b') from union_table;
diff --git a/docs/source/user-guide/sql/scalar_functions.md 
b/docs/source/user-guide/sql/scalar_functions.md
index b14bf5b2cc..25ff296879 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -4339,6 +4339,40 @@ sha512(expression)
 +-------------------------------------------+
 ```
 
+## Union Functions
+
+Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator
+
+- [union_extract](#union_extract)
+
+### `union_extract`
+
+Returns the value of the given field in the union when selected, or NULL 
otherwise.
+
+```
+union_extract(union, field_name)
+```
+
+#### Arguments
+
+- **union**: Union expression to operate on. Can be a constant, column, or 
function, and any combination of operators.
+- **field_name**: String expression to operate on. Must be a constant.
+
+#### Example
+
+```sql
+❯ select union_column, union_extract(union_column, 'a'), 
union_extract(union_column, 'b') from table_with_union;
++--------------+----------------------------------+----------------------------------+
+| union_column | union_extract(union_column, 'a') | 
union_extract(union_column, 'b') |
++--------------+----------------------------------+----------------------------------+
+| {a=1}        | 1                                |                            
      |
+| {b=3.0}      |                                  | 3.0                        
      |
+| {a=4}        | 4                                |                            
      |
+| {b=}         |                                  |                            
      |
+| {a=}         |                                  |                            
      |
++--------------+----------------------------------+----------------------------------+
+```
+
 ## Other Functions
 
 - [arrow_cast](#arrow_cast)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to