findepi commented on code in PR #13756:
URL: https://github.com/apache/datafusion/pull/13756#discussion_r1883943935
##########
datafusion/functions-nested/src/extract.rs:
##########
@@ -993,3 +993,84 @@ where
let data = mutable.freeze();
Ok(arrow::array::make_array(data))
}
+
+#[cfg(test)]
+mod tests {
+ use super::array_element_udf;
+ use arrow_schema::{DataType, Field};
+ use datafusion_common::{Column, DFSchema, ScalarValue};
+ use datafusion_expr::expr::ScalarFunction;
+ use datafusion_expr::{cast, Expr, ExprSchemable};
+ use std::collections::HashMap;
+
+ #[test]
+ fn test_array_element_return_type() {
+ let complex_type = DataType::FixedSizeList(
+ Field::new("some_arbitrary_test_field", DataType::Int32,
false).into(),
+ 13,
+ );
+ let array_type =
+ DataType::List(Field::new_list_field(complex_type.clone(),
true).into());
+ let index_type = DataType::Int64;
+
+ let schema = DFSchema::from_unqualified_fields(
+ vec![
+ Field::new("my_array", array_type.clone(), false),
+ Field::new("my_index", index_type.clone(), false),
+ ]
+ .into(),
+ HashMap::default(),
+ )
+ .unwrap();
+
+ let udf = array_element_udf();
+
+ // ScalarUDFImpl::return_type
+ assert_eq!(
+ udf.return_type(&[array_type.clone(), index_type.clone()])
+ .unwrap(),
+ complex_type
+ );
+
+ // ScalarUDFImpl::return_type_from_exprs with typed exprs
+ assert_eq!(
+ udf.return_type_from_exprs(
+ &[
+ cast(Expr::Literal(ScalarValue::Null), array_type.clone()),
+ cast(Expr::Literal(ScalarValue::Null), index_type.clone()),
+ ],
+ &schema,
+ &[array_type.clone(), index_type.clone()]
+ )
+ .unwrap(),
+ complex_type
+ );
+
+ // ScalarUDFImpl::return_type_from_exprs with exprs not carrying type
+ assert_eq!(
+ udf.return_type_from_exprs(
+ &[
+ Expr::Column(Column::new_unqualified("my_array")),
+ Expr::Column(Column::new_unqualified("my_index")),
+ ],
+ &schema,
+ &[array_type.clone(), index_type.clone()]
+ )
+ .unwrap(),
+ complex_type
+ );
+
+ // Via ExprSchemable::get_type (e.g. SimplifyInfo)
+ let udf_expr = Expr::ScalarFunction(ScalarFunction {
+ func: array_element_udf(),
+ args: vec![
+ Expr::Column(Column::new_unqualified("my_array")),
+ Expr::Column(Column::new_unqualified("my_index")),
+ ],
+ });
+ assert_eq!(
+ ExprSchemable::get_type(&udf_expr, &schema).unwrap(),
+ complex_type
+ );
Review Comment:
This final assertion (via `ExprSchemable::get_type`) failed before the change, whereas the assertions above it already passed.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]