timsaucer commented on code in PR #15646: URL: https://github.com/apache/datafusion/pull/15646#discussion_r2050653344
########## datafusion/core/tests/user_defined/user_defined_scalar_functions.rs: ########## @@ -1367,3 +1370,346 @@ async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> { ctx.sql(sql).await?.collect().await } + +#[derive(Debug)] +struct MetadataBasedUdf { + name: String, + signature: Signature, + output_field: Field, +} + +impl MetadataBasedUdf { + fn new(metadata: HashMap<String, String>) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. + let name = format!("metadata_based_udf_{}", metadata.len()); + let output_field = + Field::new(&name, DataType::UInt64, true).with_metadata(metadata); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + output_field, + } + } +} + +impl ScalarUDFImpl for MetadataBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result<DataType> { + Ok(DataType::UInt64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + assert_eq!(args.arg_fields.len(), 1); + let should_double = match &args.arg_fields[0] { + Some(field) => field + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false), + None => false, + }; + let mulitplier = if should_double { 2 } else { 1 }; + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array + .as_any() + .downcast_ref::<UInt64Array>() + .unwrap() + .iter() + .map(|v| v.map(|x| x * mulitplier)) + .collect(); + let array_ref = Arc::new(UInt64Array::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } 
+ ColumnarValue::Scalar(value) => { + let ScalarValue::UInt64(value) = value else { + return exec_err!("incorrect data type"); + }; + + Ok(ColumnarValue::Scalar(ScalarValue::UInt64( + value.map(|v| v * mulitplier), + ))) + } + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name == other.name() + } + + fn output_field(&self, _input_schema: &Schema) -> Option<Field> { + Some(self.output_field.clone()) + } +} + +#[tokio::test] +async fn test_metadata_based_udf() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let no_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new(HashMap::new())); + let with_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .project(vec![ + no_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_with_out"), + ])? 
+ .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_no_out", UInt64, [0, 10, 20, 30, 40]), + ("meta_no_in_with_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_with_out", UInt64, [0, 10, 20, 30, 40]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} + +/// This UDF is to test extension handling, both on the input and output +/// sides. For the input, we will handle the data differently if there is +/// the canonical extension type Bool8. For the output we will add a +/// user defined extension type. +#[derive(Debug)] +struct ExtensionBasedUdf { + name: String, + signature: Signature, +} + +impl Default for ExtensionBasedUdf { + fn default() -> Self { + Self { + name: "canonical_extension_udf".to_string(), + signature: Signature::exact(vec![DataType::Int8], Volatility::Immutable), Review Comment: I *think* it's okay to leave signature as is, because that's mostly going to be used to determine if we need type coercion. In the call to `return_field_from_args` the user will have the opportunity to check the full metadata *after* any coercion has happened and can then fail the plan at that point. 
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscribe@datafusion.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscribe@datafusion.apache.org For additional commands, e-mail: github-help@datafusion.apache.org