alamb commented on code in PR #15646:
URL: https://github.com/apache/datafusion/pull/15646#discussion_r2051449542


##########
datafusion/common/src/dfschema.rs:
##########
@@ -969,16 +969,28 @@ impl Display for DFSchema {
 /// widely used in the DataFusion codebase.
 pub trait ExprSchema: std::fmt::Debug {
     /// Is this column reference nullable?
-    fn nullable(&self, col: &Column) -> Result<bool>;
+    fn nullable(&self, col: &Column) -> Result<bool> {

Review Comment:
   It seems like we could (perhaps as a follow-on ticket) deprecate all the other methods on `ExprSchema`, since `to_field` supersedes all of them. It would make a good first-issue ticket.
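   
   For example, each existing default could simply be marked deprecated (the version number below is hypothetical):
   
    ```rust
    #[deprecated(since = "48.0.0", note = "use `to_field` instead")]
    fn nullable(&self, col: &Column) -> Result<bool> {
        Ok(self.to_field(col)?.is_nullable())
    }
    ```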



##########
datafusion/expr/src/expr_schema.rs:
##########
@@ -762,29 +806,25 @@ mod tests {
 
     #[derive(Debug)]
     struct MockExprSchema {
-        nullable: bool,
-        data_type: DataType,
+        field: Field,

Review Comment:
   The fact that this is simpler but also more fully featured suggests to me we are on the right track.



##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -803,17 +806,17 @@ impl ScalarUDFImpl for TakeUDF {
         &self.signature
     }
     fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
-        not_impl_err!("Not called because the return_type_from_args is implemented")
+        not_impl_err!("Not called because the return_field_from_args is implemented")
     }
 
     /// This function returns the type of the first or second argument based on
     /// the third argument:
     ///
     /// 1. If the third argument is '0', return the type of the first argument
     /// 2. If the third argument is '1', return the type of the second argument
-    fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
-        if args.arg_types.len() != 3 {
-            return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len());
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<Field> {

Review Comment:
   This is the core change, as I understand it.



##########
datafusion/expr/src/udf.rs:
##########
@@ -719,6 +682,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
     fn documentation(&self) -> Option<&Documentation> {
         None
     }
+
+    /// This describes the output field associated with this UDF.
+    /// Input field is handled through `ScalarFunctionArgs`
+    fn output_field(&self, _input_schema: &Schema) -> Option<Field> {
+        None

Review Comment:
   Doesn't this function have to be consistent with `return_field_from_args`? It seems like if we leave it as `None`, many function implementations will not implement it and thus it will be inconsistent.
   
   Also, can you please document what `None` vs `Some` means? It isn't clear what returning `None` means here -- does it mean the Field information is not known? If so, what are the implications of that?
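   
   For example, something along these lines (the fallback semantics here are my reading and should be confirmed):
   
    ```rust
    /// Returns the output [`Field`] of this function, including any metadata.
    ///
    /// Returning `None` (the default) means the output field cannot be
    /// determined from the input schema alone; callers must instead derive
    /// it via [`ScalarUDFImpl::return_field_from_args`].
    fn output_field(&self, _input_schema: &Schema) -> Option<Field> {
        None
    }
    ```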



##########
datafusion/common/src/dfschema.rs:
##########
@@ -969,16 +969,28 @@ impl Display for DFSchema {
 /// widely used in the DataFusion codebase.
 pub trait ExprSchema: std::fmt::Debug {
     /// Is this column reference nullable?
-    fn nullable(&self, col: &Column) -> Result<bool>;
+    fn nullable(&self, col: &Column) -> Result<bool> {
+        Ok(self.to_field(col)?.is_nullable())
+    }
 
     /// What is the datatype of this column?
-    fn data_type(&self, col: &Column) -> Result<&DataType>;
+    fn data_type(&self, col: &Column) -> Result<&DataType> {
+        Ok(self.to_field(col)?.data_type())
+    }
 
     /// Returns the column's optional metadata.
-    fn metadata(&self, col: &Column) -> Result<&HashMap<String, String>>;
+    fn metadata(&self, col: &Column) -> Result<&HashMap<String, String>> {
+        Ok(self.to_field(col)?.metadata())
+    }
 
     /// Return the column's datatype and nullability
-    fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)>;
+    fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> {
+        let field = self.to_field(col)?;
+        Ok((field.data_type(), field.is_nullable()))
+    }
+
+    /// Returns the column's [`Field`]
+    fn to_field(&self, col: &Column) -> Result<&Field>;

Review Comment:
   I would personally find `field()` (rather than `to_field`) clearer here, as I normally think of `to_` methods as doing some sort of owned conversion, whereas this is just accessing a field.
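   
   i.e.
   
    ```rust
    /// Returns the column's [`Field`]
    fn field(&self, col: &Column) -> Result<&Field>;
    ```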



##########
datafusion/expr/src/udf.rs:
##########
@@ -293,14 +293,17 @@ where
 
 /// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a
 /// scalar function.
-pub struct ScalarFunctionArgs<'a> {
+pub struct ScalarFunctionArgs<'a, 'b> {
     /// The evaluated arguments to the function
     pub args: Vec<ColumnarValue>,
+    /// Field associated with each arg, if it exists

Review Comment:
   Very minor, but I found it strange that `'a` was used for the second field. I would expect:
   ```rust
       pub arg_fields: Vec<Option<&'a Field>>,
   ...
       pub return_field: &'b Field,
   ```
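   
   so the whole struct would read (a sketch, with just the lifetimes swapped):
   
    ```rust
    pub struct ScalarFunctionArgs<'a, 'b> {
        /// The evaluated arguments to the function
        pub args: Vec<ColumnarValue>,
        /// Field associated with each arg, if it exists
        pub arg_fields: Vec<Option<&'a Field>>,
        /// The field of the value returned by this function
        pub return_field: &'b Field,
    }
    ```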



##########
datafusion/expr/src/udf.rs:
##########
@@ -309,64 +312,18 @@ pub struct ScalarFunctionArgs<'a> {
 /// such as the type of the arguments, any scalar arguments and if the
 /// arguments can (ever) be null
 ///
-/// See [`ScalarUDFImpl::return_type_from_args`] for more information
+/// See [`ScalarUDFImpl::return_field_from_args`] for more information
 #[derive(Debug)]
-pub struct ReturnTypeArgs<'a> {
+pub struct ReturnFieldArgs<'a> {
-    /// The data types of the arguments to the function
-    pub arg_types: &'a [DataType],
+    /// The fields of the arguments to the function
+    pub arg_fields: &'a [Field],
     /// Is argument `i` to the function a scalar (constant)
     ///
     /// If argument `i` is not a scalar, it will be None
     ///
     /// For example, if a function is called like `my_function(column_a, 5)`
     /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]`
     pub scalar_arguments: &'a [Option<&'a ScalarValue>],
-    /// Can argument `i` (ever) null?
-    pub nullables: &'a [bool],
-}
-
-/// Return metadata for this function.
-///
-/// See [`ScalarUDFImpl::return_type_from_args`] for more information
-#[derive(Debug)]
-pub struct ReturnInfo {

Review Comment:
   It is sort of interesting to see that DataFusion reinvented `Field` in several places (`ReturnInfo`, `ExprSchemable`, etc.)
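   
   e.g. `ReturnTypeArgs` carried the pieces of a `Field` in parallel slices; roughly (the struct names below are mine, just to contrast the shapes):
   
    ```rust
    use arrow::datatypes::{DataType, Field};
    
    // Old shape: data type and nullability tracked in parallel slices
    struct PiecewiseArgs<'a> {
        arg_types: &'a [DataType],
        nullables: &'a [bool],
    }
    
    // New shape: one Field per argument (type + nullability + metadata)
    struct FieldArgs<'a> {
        arg_fields: &'a [Field],
    }
    ```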



##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -1367,3 +1370,346 @@ async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> {
 async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
     ctx.sql(sql).await?.collect().await
 }
+
+#[derive(Debug)]
+struct MetadataBasedUdf {
+    name: String,
+    signature: Signature,
+    output_field: Field,
+}
+
+impl MetadataBasedUdf {
+    fn new(metadata: HashMap<String, String>) -> Self {
+        // The name we return must be unique. Otherwise we will not call distinct
+        // instances of this UDF. This is a small hack for the unit tests to get unique
+        // names, but you could do something more elegant with the metadata.
+        let name = format!("metadata_based_udf_{}", metadata.len());
+        let output_field =
+            Field::new(&name, DataType::UInt64, true).with_metadata(metadata);
+        Self {
+            name,
+            signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable),
+            output_field,
+        }
+    }
+}
+
+impl ScalarUDFImpl for MetadataBasedUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+        Ok(DataType::UInt64)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        assert_eq!(args.arg_fields.len(), 1);
+        let should_double = match &args.arg_fields[0] {
+            Some(field) => field
+                .metadata()
+                .get("modify_values")
+                .map(|v| v == "double_output")
+                .unwrap_or(false),
+            None => false,
+        };
+        let multiplier = if should_double { 2 } else { 1 };
+
+        match &args.args[0] {
+            ColumnarValue::Array(array) => {
+                let array_values: Vec<_> = array
+                    .as_any()
+                    .downcast_ref::<UInt64Array>()
+                    .unwrap()
+                    .iter()
+                    .map(|v| v.map(|x| x * multiplier))
+                    .collect();
+                let array_ref = Arc::new(UInt64Array::from(array_values)) as ArrayRef;
+                Ok(ColumnarValue::Array(array_ref))
+            }
+            ColumnarValue::Scalar(value) => {
+                let ScalarValue::UInt64(value) = value else {
+                    return exec_err!("incorrect data type");
+                };
+
+                Ok(ColumnarValue::Scalar(ScalarValue::UInt64(
+                    value.map(|v| v * multiplier),
+                )))
+            }
+        }
+    }
+
+    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
+        self.name == other.name()
+    }
+
+    fn output_field(&self, _input_schema: &Schema) -> Option<Field> {
+        Some(self.output_field.clone())
+    }
+}
+
+#[tokio::test]
+async fn test_metadata_based_udf() -> Result<()> {
+    let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("no_metadata", DataType::UInt64, true),
+        Field::new("with_metadata", DataType::UInt64, true).with_metadata(
+            [("modify_values".to_string(), "double_output".to_string())]
+                .into_iter()
+                .collect(),
+        ),
+    ]));
+    let batch = RecordBatch::try_new(
+        schema,
+        vec![Arc::clone(&data_array), Arc::clone(&data_array)],
+    )?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("t", batch)?;
+    let t = ctx.table("t").await?;
+    let no_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new(HashMap::new()));
+    let with_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new(
+        [("output_metatype".to_string(), "custom_value".to_string())]
+            .into_iter()
+            .collect(),
+    ));
+
+    let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
+        .project(vec![
+            no_output_meta_udf
+                .call(vec![col("no_metadata")])
+                .alias("meta_no_in_no_out"),
+            no_output_meta_udf
+                .call(vec![col("with_metadata")])
+                .alias("meta_with_in_no_out"),
+            with_output_meta_udf
+                .call(vec![col("no_metadata")])
+                .alias("meta_no_in_with_out"),
+            with_output_meta_udf
+                .call(vec![col("with_metadata")])
+                .alias("meta_with_in_with_out"),
+        ])?
+        .build()?;
+
+    let actual = DataFrame::new(ctx.state(), plan).collect().await?;
+
+    // To test for output metadata handling, we set the expected values on the result
+    // To test for input metadata handling, we check the numbers returned
+    let mut output_meta = HashMap::new();
+    let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
+    let expected_schema = Schema::new(vec![
+        Field::new("meta_no_in_no_out", DataType::UInt64, true),
+        Field::new("meta_with_in_no_out", DataType::UInt64, true),
+        Field::new("meta_no_in_with_out", DataType::UInt64, true)
+            .with_metadata(output_meta.clone()),
+        Field::new("meta_with_in_with_out", DataType::UInt64, true)
+            .with_metadata(output_meta.clone()),
+    ]);
+
+    let expected = record_batch!(
+        ("meta_no_in_no_out", UInt64, [0, 5, 10, 15, 20]),
+        ("meta_with_in_no_out", UInt64, [0, 10, 20, 30, 40]),
+        ("meta_no_in_with_out", UInt64, [0, 5, 10, 15, 20]),
+        ("meta_with_in_with_out", UInt64, [0, 10, 20, 30, 40])
+    )?
+    .with_schema(Arc::new(expected_schema))?;
+
+    assert_eq!(expected, actual[0]);
+
+    ctx.deregister_table("t")?;
+    Ok(())
+}
+
+/// This UDF is to test extension handling, both on the input and output
+/// sides. For the input, we will handle the data differently if there is
+/// the canonical extension type Bool8. For the output we will add a
+/// user defined extension type.
+#[derive(Debug)]
+struct ExtensionBasedUdf {
+    name: String,
+    signature: Signature,
+}
+
+impl Default for ExtensionBasedUdf {
+    fn default() -> Self {
+        Self {
+            name: "canonical_extension_udf".to_string(),
+            signature: Signature::exact(vec![DataType::Int8], Volatility::Immutable),
+        }
+    }
+}
+impl ScalarUDFImpl for ExtensionBasedUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        assert_eq!(args.arg_fields.len(), 1);
+        let input_field = args.arg_fields[0].unwrap();
+
+        let output_as_bool = matches!(
+            CanonicalExtensionType::try_from(input_field),

Review Comment:
   I think this particular function, `invoke_with_args`, gets run during the execution phase on each batch.
   
   



##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -1367,3 +1370,346 @@ async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> {
 async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
     ctx.sql(sql).await?.collect().await
 }
+/// This UDF is to test extension handling, both on the input and output
+/// sides. For the input, we will handle the data differently if there is
+/// the canonical extension type Bool8. For the output we will add a
+/// user defined extension type.
+#[derive(Debug)]
+struct ExtensionBasedUdf {
+    name: String,
+    signature: Signature,
+}
+
+impl Default for ExtensionBasedUdf {
+    fn default() -> Self {
+        Self {
+            name: "canonical_extension_udf".to_string(),
+            signature: Signature::exact(vec![DataType::Int8], Volatility::Immutable),
+        }
+    }
+}
+impl ScalarUDFImpl for ExtensionBasedUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        assert_eq!(args.arg_fields.len(), 1);
+        let input_field = args.arg_fields[0].unwrap();
+
+        let output_as_bool = matches!(
+            CanonicalExtensionType::try_from(input_field),
+            Ok(CanonicalExtensionType::Bool8(_))
+        );
+
+        // If we have the extension type set, we are outputting a boolean value.
+        // Otherwise we output a string representation of the numeric value.
+        fn print_value(v: Option<i8>, as_bool: bool) -> Option<String> {
+            v.map(|x| match as_bool {
+                true => format!("{}", x != 0),
+                false => format!("{x}"),
+            })
+        }
+
+        match &args.args[0] {
+            ColumnarValue::Array(array) => {
+                let array_values: Vec<_> = array
+                    .as_any()
+                    .downcast_ref::<Int8Array>()
+                    .unwrap()
+                    .iter()
+                    .map(|v| print_value(v, output_as_bool))
+                    .collect();
+                let array_ref = Arc::new(StringArray::from(array_values)) as ArrayRef;
+                Ok(ColumnarValue::Array(array_ref))
+            }
+            ColumnarValue::Scalar(value) => {

Review Comment:
   It is a good call that we'll need to add the metadata somewhere. I think we can defer figuring out the best place to a follow-on PR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

