benbellick commented on code in PR #21193: URL: https://github.com/apache/datafusion/pull/21193#discussion_r3219978510
##########
datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs:
##########
@@ -35,7 +38,68 @@ pub fn from_higher_order_function(
fun: &expr::HigherOrderFunction,
schema: &DFSchemaRef,
) -> datafusion::common::Result<Expression> {
- from_function(producer, fun.name(), &fun.args, schema)
+ let mut lambda_parameters = fun.lambda_parameters(schema)?.into_iter();
+
+ let num_lambdas = fun
+ .args
+ .iter()
+ .filter(|arg| matches!(arg, Expr::Lambda(_)))
+ .count();
+
+ if lambda_parameters.len() != num_lambdas {
+ return substrait_err!(
+ "{} returned {} lambdas but {num_lambdas} expected",
+ fun.name(),
+ lambda_parameters.len()
+ );
+ }
+
+ let arguments = fun
+ .args
+ .iter()
+ .map(|arg| {
+ let arg = match arg {
+ Expr::Lambda(l) => {
+ let lambda_parameters =
+ lambda_parameters.next().ok_or_else(|| {
+ internal_datafusion_err!(
+ "lambda_parameters len should have been
checked above"
+ )
+ })?;
+
+ let named_lambda_parameters =
+ std::iter::zip(&l.params, lambda_parameters)
Review Comment:
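What happens if the lengths of `l.params` and `lambda_parameters` differ? `std::iter::zip` stops as soon as the shorter input is exhausted, so any extra parameters would be silently dropped rather than reported as an error. Can this happen?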
##########
datafusion/substrait/tests/cases/serialize.rs:
##########
@@ -196,6 +205,176 @@ mod tests {
panic!("plan did not match expected structure")
}
+ #[tokio::test]
+ async fn higher_order_function42() -> Result<()> {
Review Comment:
I don't dispute that the implementation is correct, but this test is tough
to read.
My understanding is that roundtrip tests generally catch most consistency
issues, so the goal of the tests here is specifically to validate the things
that could be incorrect but are difficult to validate via roundtrip tests.
What do you think about focusing exclusively on the lambda parameter references?
Something like:
```rust
/// Checks how lambda parameter references are encoded when one lambda is
/// nested inside another.
///
/// The outer lambda binds `(v, i)`. The inner lambda re-binds `v` and binds
/// `j`, so its body `v * i * j` mixes local references with a reference that
/// reaches one level out. Each expected tuple is what
/// `lambda_param_refs_in_expression` reports for one lambda parameter
/// reference, in the order they appear in the first projected expression.
/// NOTE(review): the tuples appear to be `(steps_out, field_index)`, which
/// matches the `stepsOut`/`field` values in the JSON fixture. Confirm against
/// the helper's definition.
#[tokio::test]
async fn serialize_nested_lambda_references() -> Result<()> {
    let ctx = higher_order_function_ctx().await?;

    let sql = "SELECT array_transform2(
[[data3.p1]],
(v, i) -> array_transform2(v, (v, j) -> v * i * j)
) FROM data3";
    let optimized = ctx.sql(sql).await?.into_optimized_plan()?;

    let substrait_plan = to_substrait_plan(&optimized, &ctx.state())?
        .as_ref()
        .clone();
    let refs = lambda_param_refs_in_expression(project_expression(&substrait_plan, 0));

    assert_eq!(
        refs,
        vec![
            // inner array_transform2 argument: outer v
            (0, 0),
            // inner lambda body: v * i * j
            (0, 0),
            (1, 1),
            (0, 1),
        ]
    );
    Ok(())
}
```
##########
datafusion/substrait/src/logical_plan/consumer/expr/lambda.rs:
##########
Review Comment:
Should we include tests here to check that each of the following results in
an error?
- missing parameters or body
- an invalid `steps_out` or field index
Alternatively, these tests could live somewhere else if you think that would be more appropriate.
##########
datafusion/substrait/tests/testdata/test_plans/higher_order_function.json:
##########
@@ -0,0 +1,338 @@
+{
+ "version": {
+ "minorNumber": 85,
+ "producer": "datafusion"
+ },
+ "extensions": [
+ {
+ "extensionFunction": {
+ "extensionUrnReference": 4294967295,
+ "functionAnchor": 2,
+ "name": "array_transform2"
+ }
+ },
+ {
+ "extensionFunction": {
+ "extensionUrnReference": 4294967295,
+ "name": "make_array"
+ }
+ },
+ {
+ "extensionFunction": {
+ "extensionUrnReference": 4294967295,
+ "functionAnchor": 3,
+ "name": "array_concat"
+ }
+ },
+ {
+ "extensionFunction": {
+ "extensionUrnReference": 4294967295,
+ "functionAnchor": 1,
+ "name": "multiply"
+ }
+ }
+ ],
+ "relations": [
+ {
+ "root": {
+ "input": {
+ "project": {
+ "common": {
+ "emit": {
+ "outputMapping": [
+ 1
+ ]
+ }
+ },
+ "input": {
+ "read": {
+ "baseSchema": {
+ "names": [
+ "p1"
+ ],
+ "struct": {
+ "types": [
+ {
+ "i64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ }
+ ],
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "projection": {
+ "select": {
+ "structItems": [
+ {}
+ ]
+ }
+ },
+ "namedTable": {
+ "names": [
+ "data3"
+ ]
+ }
+ }
+ },
+ "expressions": [
+ {
+ "scalarFunction": {
+ "functionReference": 2,
+ "arguments": [
+ {
+ "value": {
+ "scalarFunction": {
+ "arguments": [
+ {
+ "value": {
+ "scalarFunction": {
+ "arguments": [
+ {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ }
+ }
+ ]
+ }
+ }
+ }
+ ]
+ }
+ }
+ },
+ {
+ "value": {
+ "lambda": {
+ "parameters": {
+ "types": [
+ {
+ "list": {
+ "type": {
+ "i64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "i64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ }
+ ],
+ "nullability": "NULLABILITY_REQUIRED"
+ },
+ "body": {
+ "scalarFunction": {
+ "functionReference": 3,
+ "arguments": [
+ {
+ "value": {
+ "scalarFunction": {
+ "functionReference": 2,
+ "arguments": [
+ {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "lambdaParameterReference": {}
+ }
+ }
+ },
+ {
+ "value": {
+ "lambda": {
+ "parameters": {
+ "types": [
+ {
+ "i64": {
+ "nullability":
"NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "i64": {
+ "nullability":
"NULLABILITY_NULLABLE"
+ }
+ }
+ ],
+ "nullability":
"NULLABILITY_REQUIRED"
+ },
+ "body": {
+ "scalarFunction": {
+ "functionReference": 1,
+ "arguments": [
+ {
+ "value": {
+ "scalarFunction": {
+ "functionReference":
1,
+ "arguments": [
+ {
+ "value": {
+ "selection": {
+
"directReference": {
+
"structField": {}
+ },
+
"lambdaParameterReference": {}
+ }
+ }
+ },
+ {
+ "value": {
+ "selection": {
+
"directReference": {
+
"structField": {
+ "field":
1
+ }
+ },
+
"lambdaParameterReference": {
+
"stepsOut": 1
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+ },
+ {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 1
+ }
+ },
+
"lambdaParameterReference": {}
+ }
+ }
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+ },
+ {
+ "value": {
+ "scalarFunction": {
+ "functionReference": 2,
+ "arguments": [
+ {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "lambdaParameterReference": {}
+ }
+ }
+ },
+ {
+ "value": {
+ "lambda": {
+ "parameters": {
+ "types": [
+ {
+ "i64": {
+ "nullability":
"NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "i64": {
+ "nullability":
"NULLABILITY_NULLABLE"
+ }
+ }
+ ],
+ "nullability":
"NULLABILITY_REQUIRED"
+ },
+ "body": {
+ "scalarFunction": {
+ "functionReference": 1,
+ "arguments": [
+ {
+ "value": {
+ "scalarFunction": {
+ "functionReference":
1,
+ "arguments": [
+ {
+ "value": {
+ "selection": {
+
"directReference": {
+
"structField": {}
+ },
+
"lambdaParameterReference": {}
+ }
+ }
+ },
+ {
+ "value": {
+ "selection": {
+
"directReference": {
+
"structField": {
+ "field":
1
+ }
+ },
+
"lambdaParameterReference": {
+
"stepsOut": 1
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+ },
+ {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 1
+ }
+ },
+
"lambdaParameterReference": {}
+ }
+ }
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+ ]
+ }
+ },
+ "names": [
+ "array_transform2(make_array(make_array(data3.p1)),(v, i) ->
array_concat(array_transform2(v,(v, j) -> v * i * j),array_transform2(v,(v, j)
-> v * i * j)))"
+ ]
+ }
+ }
+ ]
+}
Review Comment:
There are a few things wrong with this plan:
- missing URN declarations
- missing `outputType` in scalar function invocations
The Substrait validation in DataFusion does not catch these issues, but in
the interest of keeping checked-in Substrait fixtures structurally valid, I
think it is better to use this updated plan.
There are a few other validity issues I noticed, but I don't think this PR
needs to solve the broader DataFusion Substrait extension story:
- Function names in Substrait extension declarations should be signatures,
e.g. `add:i8_i8` rather than just `add`. [DataFusion has permissive handling
for this on the consumer
side](https://github.com/apache/datafusion/blob/2c234394f52d44a0f5b0011773f0ecd720a19625/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs#L97-L105),
but I think that should be fixed separately.
- Some referenced functions here, such as `array_transform2`, `make_array`,
and `array_concat`, do not correspond to extension YAML declarations. That also
seems like a broader DataFusion Substrait issue and can be resolved later.
```suggestion
{
"version": {
"minorNumber": 85,
"producer": "datafusion"
},
"extensions": [
{
"extensionFunction": {
"extensionUrnReference": 2,
"functionAnchor": 2,
"name": "array_transform2"
}
},
{
"extensionFunction": {
"extensionUrnReference": 2,
"name": "make_array"
}
},
{
"extensionFunction": {
"extensionUrnReference": 2,
"functionAnchor": 3,
"name": "array_concat"
}
},
{
"extensionFunction": {
"extensionUrnReference": 1,
"functionAnchor": 1,
"name": "multiply"
}
}
],
"relations": [
{
"root": {
"input": {
"project": {
"common": {
"emit": {
"outputMapping": [
1
]
}
},
"input": {
"read": {
"baseSchema": {
"names": [
"p1"
],
"struct": {
"types": [
{
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"projection": {
"select": {
"structItems": [
{}
]
}
},
"namedTable": {
"names": [
"data3"
]
}
}
},
"expressions": [
{
"scalarFunction": {
"functionReference": 2,
"arguments": [
{
"value": {
"scalarFunction": {
"arguments": [
{
"value": {
"scalarFunction": {
"arguments": [
{
"value": {
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
}
}
],
"outputType": {
"list": {
"type": {
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
},
"nullability": "NULLABILITY_NULLABLE"
}
}
}
}
}
],
"outputType": {
"list": {
"type": {
"list": {
"type": {
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
},
"nullability": "NULLABILITY_NULLABLE"
}
},
"nullability": "NULLABILITY_NULLABLE"
}
}
}
}
},
{
"value": {
"lambda": {
"parameters": {
"types": [
{
"list": {
"type": {
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
},
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
},
"body": {
"scalarFunction": {
"functionReference": 3,
"arguments": [
{
"value": {
"scalarFunction": {
"functionReference": 2,
"arguments": [
{
"value": {
"selection": {
"directReference": {
"structField": {}
},
"lambdaParameterReference": {}
}
}
},
{
"value": {
"lambda": {
"parameters": {
"types": [
{
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
},
{
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
}
],
"nullability":
"NULLABILITY_REQUIRED"
},
"body": {
"scalarFunction": {
"functionReference": 1,
"arguments": [
{
"value": {
"scalarFunction": {
"functionReference": 1,
"arguments": [
{
"value": {
"selection":
{
"directReference": {
"structField": {}
},
"lambdaParameterReference": {}
}
}
},
{
"value": {
"selection":
{
"directReference": {
"structField": {
"field": 1
}
},
"lambdaParameterReference": {
"stepsOut": 1
}
}
}
}
],
"outputType": {
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
}
}
}
},
{
"value": {
"selection": {
"directReference":
{
"structField": {
"field": 1
}
},
"lambdaParameterReference": {}
}
}
}
],
"outputType": {
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
}
}
}
}
}
}
],
"outputType": {
"list": {
"type": {
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
},
"nullability":
"NULLABILITY_NULLABLE"
}
}
}
}
},
{
"value": {
"scalarFunction": {
"functionReference": 2,
"arguments": [
{
"value": {
"selection": {
"directReference": {
"structField": {}
},
"lambdaParameterReference": {}
}
}
},
{
"value": {
"lambda": {
"parameters": {
"types": [
{
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
},
{
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
}
],
"nullability":
"NULLABILITY_REQUIRED"
},
"body": {
"scalarFunction": {
"functionReference": 1,
"arguments": [
{
"value": {
"scalarFunction": {
"functionReference": 1,
"arguments": [
{
"value": {
"selection":
{
"directReference": {
"structField": {}
},
"lambdaParameterReference": {}
}
}
},
{
"value": {
"selection":
{
"directReference": {
"structField": {
"field": 1
}
},
"lambdaParameterReference": {
"stepsOut": 1
}
}
}
}
],
"outputType": {
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
}
}
}
},
{
"value": {
"selection": {
"directReference":
{
"structField": {
"field": 1
}
},
"lambdaParameterReference": {}
}
}
}
],
"outputType": {
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
}
}
}
}
}
}
],
"outputType": {
"list": {
"type": {
"i64": {
"nullability":
"NULLABILITY_NULLABLE"
}
},
"nullability":
"NULLABILITY_NULLABLE"
}
}
}
}
}
],
"outputType": {
"list": {
"type": {
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
},
"nullability": "NULLABILITY_NULLABLE"
}
}
}
}
}
}
}
],
"outputType": {
"list": {
"type": {
"list": {
"type": {
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
},
"nullability": "NULLABILITY_NULLABLE"
}
},
"nullability": "NULLABILITY_NULLABLE"
}
}
}
}
]
}
},
"names": [
"array_transform2(make_array(make_array(data3.p1)),(v, i) ->
array_concat(array_transform2(v,(v, j) -> v * i * j),array_transform2(v,(v, j)
-> v * i * j)))"
]
}
}
],
"extensionUrns": [
{
"extensionUrnAnchor": 1,
"urn": "extension:io.substrait:functions_arithmetic"
},
{
"extensionUrnAnchor": 2,
"urn": "extension:io.substrait:functions_list"
}
]
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
