gstvg opened a new pull request, #18921: URL: https://github.com/apache/datafusion/pull/18921
This PR adds support for lambdas with column capture, plus the array_transform scalar function, which is used to test the lambda implementation. The changes are extensive, touching various parts of the codebase, mostly tree traversals. This text aims to justify those changes and to show alternatives which may require fewer changes, although not without trade-offs, so we can decide whether to move this forward or not, and if so, with what approach. For those who want to take a look at the code, don't waste time on the second commit, as it only adds a new field to a struct. [There's a build of the documentation of this branch available online](https://gstvg.github.io/datafusion/index.html?search=lambda)

Example array_transform usage:

```sql
array_transform([1, 2], v -> v*2)
-- [2, 4]
```

<details><summary>Lambda Logical and Physical Representation</summary>

```rust
pub struct Lambda {
    pub params: Vec<String>, // in the example above, vec!["v"]
    pub body: Box<Expr>,     // in the example above, v*2
}

pub struct LambdaExpr {
    params: Vec<String>,         // in the example above, vec!["v"]
    body: Arc<dyn PhysicalExpr>, // in the example above, v * 2
}
```

</details>

# Changes in ScalarUDF[Impl] to support lambdas

Since lambda parameters are defined by the UDF implementation, DataFusion doesn't know their types, nullability, or metadata. So we add a method to ScalarUDF[Impl] where the implementation returns a Field for each parameter of each of its lambdas.

<details><summary>ScalarUDF[Impl] lambdas_parameters method</summary>

```rust
struct/trait ScalarUDF[Impl] {
    /// Returns, for each argument, the parameters its lambda supports,
    /// or `None` if the argument is not a lambda
    fn lambdas_parameters(
        &self,
        args: &[ValueOrLambdaParameter],
    ) -> Result<Vec<Option<Vec<Field>>>> {
        Ok(vec![None; args.len()])
    }
}

pub enum ValueOrLambdaParameter<'a> {
    /// A columnar value with the given field
    Value(FieldRef),
    /// A lambda with the given parameter names and a flag indicating whether it captures any columns
    Lambda(&'a [String], bool),
}
```

</details>

<details><summary>ArrayTransform lambdas_parameters implementation</summary>

```rust
impl ScalarUDFImpl for ArrayTransform {
    fn lambdas_parameters(
        &self,
        args: &[ValueOrLambdaParameter],
    ) -> Result<Vec<Option<Vec<Field>>>> {
        let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda(_, _)] = args
        else {
            return exec_err!(
                "{} expects a value followed by a lambda, got {:?}",
                self.name(),
                args
            );
        };

        let (field, index_type) = match list.data_type() {
            DataType::List(field) => (field, DataType::Int32),
            DataType::LargeList(field) => (field, DataType::Int64),
            DataType::FixedSizeList(field, _) => (field, DataType::Int32),
            _ => return exec_err!("expected list, got {list}"),
        };

        // we don't need to omit the index when the lambda doesn't declare it, e.g. array_transform([], v -> v*2),
        // nor check whether the lambda declares more than two parameters, e.g. array_transform([], (v, i, j) -> v+i+j),
        // as datafusion will do that for us
        let value = Field::new("value", field.data_type().clone(), field.is_nullable())
            .with_metadata(field.metadata().clone());

        let index = Field::new("index", index_type, false);

        Ok(vec![None, Some(vec![value, index])])
    }
}
```

</details>

<br>

To compute its return field, a UDF usually needs to know the return field of its lambdas, which is the case for array_transform.
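Since column capture is central to what follows, here is a small illustrative query (a hypothetical table `t` with an integer list column `b` and an integer column `c`, not one of the tables used elsewhere in this description) where the lambda body references both its own parameter and a captured column:

```sql
-- v is the lambda parameter bound to each element of b;
-- c is a regular column of t, captured inside the lambda body
select array_transform(b, v -> v + c) from t;
```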
Because lambdas may capture columns, computing the return field of a lambda requires knowing the Field of the captured columns, which is only available in the schema and is not passed as an argument to return_fields[_from_args]. To avoid changing that, we internally use the fields returned by the newly added ScalarUDFImpl::lambdas_parameters method, paired with the schema, to compute the return field of the lambdas, and pass them in ReturnFieldArgs.arg_fields. We also add a slice of bools indicating whether the arg in position `i` is a lambda or not. Finally, we add a helper method to_lambda_args which merges the data in arg_fields and lambdas into a vec of ValueOrLambdaField enums, allowing convenient pattern matching instead of having to inspect both arg_fields and lambdas.

<br>

<details><summary>ReturnFieldArgs changes</summary>

```rust
pub struct ReturnFieldArgs<'a> {
    /// The data types of the arguments to the function
    ///
    /// If argument `i` to the function is a lambda, it will be the return field of the
    /// lambda expression when evaluated with the arguments returned from
    /// `ScalarUDFImpl::lambdas_parameters`
    ///
    /// For example, with `array_transform([1], v -> v == 5)`
    /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]`
    pub arg_fields: &'a [FieldRef],

    ... skipped fields

    /// Is argument `i` to the function a lambda?
    ///
    /// For example, with `array_transform([1], v -> v == 5)`
    /// this field will be `[false, true]`
    pub lambdas: &'a [bool],
}

/// A tagged Field indicating whether it corresponds to a value or a lambda argument
#[derive(Debug)]
pub enum ValueOrLambdaField<'a> {
    /// The Field of a ColumnarValue argument
    Value(&'a FieldRef),
    /// The Field of the return of the lambda body when evaluated with the parameters from ScalarUDF::lambdas_parameters
    Lambda(&'a FieldRef),
}

impl<'a> ReturnFieldArgs<'a> {
    /// Based on self.lambdas, encodes self.arg_fields as tagged enums
    /// indicating whether each corresponds to a value or a lambda argument
    pub fn to_lambda_args(&self) -> Vec<ValueOrLambdaField<'a>> {
        std::iter::zip(self.arg_fields, self.lambdas)
            .map(|(field, is_lambda)| {
                if *is_lambda {
                    ValueOrLambdaField::Lambda(field)
                } else {
                    ValueOrLambdaField::Value(field)
                }
            })
            .collect()
    }
}
```

</details>

<details><summary>ArrayTransform return_field_from_args implementation</summary>

```rust
impl ScalarUDFImpl for ArrayTransform {
    fn return_field_from_args(
        &self,
        args: datafusion_expr::ReturnFieldArgs,
    ) -> Result<Arc<Field>> {
        let args = args.to_lambda_args();

        let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] =
            take_function_args(self.name(), &args)?
        else {
            return exec_err!(
                "{} expects a value followed by a lambda, got {:?}",
                self.name(),
                args
            );
        };

        // lambda is the resulting field of evaluating the lambda body
        // with the parameters returned by lambdas_parameters
        let field = Arc::new(Field::new(
            Field::LIST_FIELD_DEFAULT_NAME,
            lambda.data_type().clone(),
            lambda.is_nullable(),
        ));

        let return_type = match list.data_type() {
            DataType::List(_) => DataType::List(field),
            DataType::LargeList(_) => DataType::LargeList(field),
            DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size),
            _ => exec_err!("expected list, got {list}")?,
        };

        Ok(Arc::new(Field::new("", return_type, list.is_nullable())))
    }
}
```

</details>

<br>

For execution, we add a lambdas field to ScalarFunctionArgs along with a to_lambda_args helper method, mirroring the changes to ReturnFieldArgs and its to_lambda_args:

<br>

<details><summary>ScalarFunctionArgs changes</summary>

```rust
pub struct ScalarFunctionArgs {
    /// The evaluated arguments to the function
    /// If argument `i` is a lambda, this will be `ColumnarValue::Scalar(ScalarValue::Null)`
    ///
    /// For example, with `array_transform([1], v -> v == 5)`
    /// this field will be `[ColumnarValue::Scalar(ScalarValue::List([1])), ColumnarValue::Scalar(ScalarValue::Null)]`
    pub args: Vec<ColumnarValue>,

    /// Field associated with each arg, if it exists
    pub arg_fields: Vec<FieldRef>,

    ....

    /// The lambdas passed to the function
    /// If argument `i` is not a lambda, the entry will be `None`
    ///
    /// For example, with `array_transform([1], v -> v == 5)`
    /// this field will be `[None, Some(...)]`
    pub lambdas: Option<Vec<Option<ScalarFunctionLambdaArg>>>,
}

/// A lambda argument to a ScalarFunction
#[derive(Clone, Debug)]
pub struct ScalarFunctionLambdaArg {
    /// The parameters defined in this lambda
    ///
    /// For example, for `array_transform([2], v -> -v)`,
    /// this will be `vec![Field::new("v", DataType::Int32, true)]`
    pub params: Vec<FieldRef>,

    /// The body of the lambda
    ///
    /// For example, for `array_transform([2], v -> -v)`,
    /// this will be the physical expression of `-v`
    pub body: Arc<dyn PhysicalExpr>,

    /// A RecordBatch containing at least the columns captured inside this lambda body, if any.
    /// Note that it may contain additional, unspecified columns, but that's an implementation detail
    ///
    /// For example, for `array_transform([2], v -> v + a + b)`,
    /// this will be a `RecordBatch` with at least columns `a` and `b`
    pub captures: Option<RecordBatch>,
}

impl ScalarFunctionArgs {
    pub fn to_lambda_args(&self) -> Vec<ValueOrLambda<'_>> {
        match &self.lambdas {
            Some(lambdas) => std::iter::zip(&self.args, lambdas)
                .map(|(arg, lambda)| match lambda {
                    Some(lambda) => ValueOrLambda::Lambda(lambda),
                    None => ValueOrLambda::Value(arg),
                })
                .collect(),
            None => self.args.iter().map(ValueOrLambda::Value).collect(),
        }
    }
}

/// An argument to a higher-order scalar function
pub enum ValueOrLambda<'a> {
    Value(&'a ColumnarValue),
    Lambda(&'a ScalarFunctionLambdaArg),
}
```

</details>

<details><summary>ArrayTransform invoke_with_args implementation</summary>

```rust
impl ScalarUDFImpl for ArrayTransform {
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // args.to_lambda_args() allows the convenient match below, instead of inspecting both args.args and args.lambdas
        let lambda_args = args.to_lambda_args();

        let [list_value, lambda] = take_function_args(self.name(), &lambda_args)?;

        let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) = (list_value, lambda)
        else {
            return exec_err!(
                "{} expects a value followed by a lambda, got {:?}",
                self.name(),
                &lambda_args
            );
        };

        let list_array = list_value.to_array(args.number_rows)?;

        // if any column got captured, we need to adjust it to the values arrays,
        // duplicating the values of lists with multiple elements and removing the values of empty lists.
        // list_indices is not cheap, so it is important to avoid it when no column is captured
        let adjusted_captures = lambda
            .captures
            .as_ref()
            .map(|captures| take_record_batch(captures, &list_indices(&list_array)?))
            .transpose()?;

        // use closures and merge_captures_with_lazy_args so that only the needed ones are called,
        // based on the number of arguments, avoiding unnecessary computations
        let values_param = || Ok(Arc::clone(list_values(&list_array)?));
        let indices_param = || elements_indices(&list_array);

        // the order of the merged schema is an unspecified implementation detail that may change in the future;
        // using this function is the correct way to merge, as it returns the correct ordering and will change in sync with
        // the implementation without the need for fixes. It also computes only the parameters requested
        let lambda_batch = merge_captures_with_lazy_args(
            adjusted_captures.as_ref(),
            &lambda.params,
            // ScalarUDF already merged the fields returned by lambdas_parameters with the parameter names defined in the lambda, so we don't need to
            &[&values_param, &indices_param],
        )?;

        // call the transforming expression with the record batch composed of the list values merged with the captured columns
        let transformed_values = lambda
            .body
            .evaluate(&lambda_batch)?
            .into_array(lambda_batch.num_rows())?;

        let field = match args.return_field.data_type() {
            DataType::List(field)
            | DataType::LargeList(field)
            | DataType::FixedSizeList(field, _) => Arc::clone(field),
            _ => {
                return exec_err!(
                    "{} expected ScalarFunctionArgs.return_field to be a list, got {}",
                    self.name(),
                    args.return_field
                )
            }
        };

        let transformed_list = match list_array.data_type() {
            DataType::List(_) => {
                let list = list_array.as_list();
                Arc::new(ListArray::new(
                    field,
                    list.offsets().clone(),
                    transformed_values,
                    list.nulls().cloned(),
                )) as ArrayRef
            }
            DataType::LargeList(_) => {
                let large_list = list_array.as_list();
                Arc::new(LargeListArray::new(
                    field,
                    large_list.offsets().clone(),
                    transformed_values,
                    large_list.nulls().cloned(),
                ))
            }
            DataType::FixedSizeList(_, value_length) => Arc::new(FixedSizeListArray::new(
                field,
                *value_length,
                transformed_values,
                list_array.as_fixed_size_list().nulls().cloned(),
            )),
            other => exec_err!("expected list, got {other}")?,
        };

        Ok(ColumnarValue::Array(transformed_list))
    }
}
```

</details>

<br>

# Changes in tree traversals

Using this query as an example:

```sql
create table t as select 1 as a, [[2, 3]] as b, 1 as c;

select a, b, c, array_transform(b, (b, i) -> array_transform(b, b -> b + c + i)) from t;

a | b | c | array_transform(b, (b, i) -> array_transform(b, b -> b + c + i))
1, [[2, 3]], 1, [[4, 5]]
```

Detailing the identifiers in the query:

Lambda parameter definitions:

- `(b, i) ->`, the parameters of the outer lambda:
  - `b`, arbitrarily named: the element of the array being transformed; shadows t.b. In the example, a List(Int32) column of `[2, 3]`.
  - `i`, arbitrarily named: the 1-based index of the element being transformed. In the example, an Int32 column of `1`.
- `b ->`, the only parameter of the inner lambda, arbitrarily named: the element of the array being transformed; shadows the parameter `b` from the outer lambda. The index parameter is omitted. In the example, an Int32 column with the values `2, 3`.
Column and parameter references:

- `a`, `b`, `c` in the select list: projections of columns `a`, `b`, and `c` from table `t` (t.a, t.b, t.c).
- `b`, the first argument of the outer array_transform: column `b` from table `t`, t.b, being passed as an argument to the outer array_transform.
- `b`, the first argument of the inner array_transform: the parameter `b` from the outer lambda being passed as an argument to the inner array_transform.
- `b` in the inner lambda body: a reference to the parameter `b` from the inner lambda.
- `c` in the inner lambda body: a reference to the column `c` captured from table `t`, t.c.
- `i` in the inner lambda body: a reference to the parameter `i` captured from the outer lambda.

Note that:

1. lambdas may be nested
2. lambdas may capture columns from the outer scope, be it from the input plan or from another, outer lambda
3. lambdas introduce parameters that shadow columns from the outer scope
4. lambdas may support multiple parameters, and it is possible to omit trailing ones that aren't used. Omitting unused parameters positioned before a used parameter is currently not supported and may incur unnecessary computations

# Representing columns that refer to lambda parameters while being able to differentiate them from regular columns in Expr tree traversals

Because they are identical to regular columns, it is intuitive to use the same logical and physical expressions to represent columns that refer to lambda parameters. However, the existing tree traversals were written without taking into account that a column may refer to a lambda parameter rather than a column from the input plan, and so they would behave erratically. In the example query, projection pushdown would try to push down the lambda parameter "i", which doesn't exist in table "t". Another example:

```rust
fn minimize_join_filter(expr: Arc<dyn PhysicalExpr>, ...) -> JoinFilter {
    let mut used_columns = HashSet::new();

    expr.apply(|expr| {
        if let Some(col) = expr.as_any().downcast_ref::<Column>() {
            // if this is a lambda column, the code below will break
            used_columns.insert(col.index());
        }
        Ok(TreeNodeRecursion::Continue)
    });
    ...
}
```

Therefore, we either make available a way to differentiate them, or use two different expressions:

## Option 1. Use the same Column expression, differentiate with a new set of TreeNode methods, *_with_lambdas_params

This PR uses the existing column expression, always unqualified, and adds a new set of [TreeNode-like methods](https://gstvg.github.io/datafusion/index.html?search=with_lambdas_params) on expressions that start traversals with an empty HashSet<String>; while traversing the expr tree, whenever a lambda is found, the set is cloned, the lambda's parameters are added to it, and it is passed to the visiting/transforming closure so that the closure can differentiate the columns (a sketch of the idea follows the signature below).

```rust
impl Expr {
    pub fn transform_with_lambdas_params<
        F: FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
    >(
        self,
        mut f: F,
    ) -> Result<Transformed<Self>> {}
}
```
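To make the scoping rule concrete, here is a minimal, self-contained sketch of the idea on a hypothetical `ToyExpr` type (not the actual DataFusion Expr, nor the implementation in this PR): the traversal threads a set of in-scope lambda parameter names, cloning and extending it whenever it descends into a lambda body.

```rust
use std::collections::HashSet;

// Toy expression tree, only to illustrate how the parameter set is threaded.
enum ToyExpr {
    Column(String),
    Lambda { params: Vec<String>, body: Box<ToyExpr> },
    Call(Vec<ToyExpr>),
}

// Visit every node, telling the callback which names refer to lambda
// parameters at that point of the tree.
fn visit_with_lambdas_params(
    expr: &ToyExpr,
    in_scope: &HashSet<String>,
    f: &mut impl FnMut(&ToyExpr, &HashSet<String>),
) {
    f(expr, in_scope);
    match expr {
        ToyExpr::Column(_) => {}
        // Entering a lambda: clone the set and add this lambda's parameters,
        // so columns in the body using those names are recognized as parameter
        // references (they shadow outer columns of the same name).
        ToyExpr::Lambda { params, body } => {
            let mut inner = in_scope.clone();
            inner.extend(params.iter().cloned());
            visit_with_lambdas_params(body, &inner, f);
        }
        // Any other node: children are visited with the unchanged set.
        ToyExpr::Call(args) => {
            for arg in args {
                visit_with_lambdas_params(arg, in_scope, f);
            }
        }
    }
}
```

The actual methods added in this PR apply the same bookkeeping to the logical and physical expression trees, on top of the usual TreeNode Transformed plumbing.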
<details><summary>Expr tree traversal with lambdas_params. This query is a modified version of the example query where the inner lambda's second parameter is used</summary>

```text
array_transform(b, (b, i) -> array_transform(b, (b, j) -> b + i + j))    lambdas_params = {}
        │
        ▼
(b, i) -> array_transform(b, (b, j) -> b + i + j)                        lambdas_params = { b, i }
        │
        ▼
array_transform(b, (b, j) -> b + i + j)                                  lambdas_params = { b, i }
        │
        ▼
(b, j) -> b + i + j                                                      lambdas_params = { b, i, j }
```

</details>

How minimize_join_filter would look:

```rust
fn minimize_join_filter(expr: Arc<dyn PhysicalExpr>, ...) -> JoinFilter {
    let mut used_columns = HashSet::new();

    expr.apply_with_lambdas_params(|expr, lambdas_params| {
        if let Some(col) = expr.as_any().downcast_ref::<Column>() {
            // don't include lambda parameters
            if !lambdas_params.contains(col.name()) {
                used_columns.insert(col.index());
            }
        }
        Ok(TreeNodeRecursion::Continue)
    })
    ...
}
```

I started with this option without an idea of how big it would be. It requires using the new TreeNode methods and checking the column in 30+ tree traversals, and downstream code will have to do the same. I chose to keep it in this PR so that we only choose or discard it knowing how big it is.

## Option 2. New Expr LambdaColumn

Create a new Expr, LambdaColumn, which doesn't require new TreeNode methods, but requires that expr_api users use this new expr when applicable. It requires a fix in expr_simplifier for expressions with a lambda column, and may require similar work downstream.

```rust
// logical
struct LambdaColumn {
    name: String,
    spans: Spans,
}

// physical
struct LambdaColumn {
    name: String,
    index: usize,
}
```

Existing code, unaltered:

```rust
fn minimize_join_filter(expr: Arc<dyn PhysicalExpr>, ...) -> JoinFilter {
    let mut used_columns = HashSet::new();

    expr.apply(|expr| {
        if let Some(col) = expr.as_any().downcast_ref::<Column>() {
            // no need to check, because a Column is never a LambdaColumn
            used_columns.insert(col.index());
        }
        Ok(TreeNodeRecursion::Continue)
    })
    ...
}
```

## Option 3. Add an is_lambda_parameter boolean field to Column

Add is_lambda_parameter to the existing column expressions. It won't require new TreeNode methods, but still requires checking the field everywhere the new TreeNode methods are currently used in this PR.

```rust
// logical
struct Column {
    pub relation: Option<TableReference>,
    pub name: String,
    pub spans: Spans,
    pub is_lambda_parameter: bool,
}

// physical
struct Column {
    name: String,
    index: usize,
    is_lambda_parameter: bool,
}
```

How minimize_join_filter would look:

```rust
fn minimize_join_filter(expr: Arc<dyn PhysicalExpr>, ...) -> JoinFilter {
    let mut used_columns = HashSet::new();

    expr.apply(|expr| {
        if let Some(col) = expr.as_any().downcast_ref::<Column>() {
            // don't include lambda parameters
            if !col.is_lambda_parameter {
                used_columns.insert(col.index());
            }
        }
        Ok(TreeNodeRecursion::Continue)
    })
    ...
}
```

### Comparison between options:

| | Expr::Column/ColumnExpr | Expr::LambdaColumn/LambdaColumnExpr | is_lambda_parameter flag in Column |
|---|---|---|---|
| Internal code change | New TreeNode methods; use them where applicable | Add the new expr. Use the correct expr while parsing SQL (instrumentation already exists to skip normalization of lambda parameters in SQL parsing) | Check the field where applicable. Set the field while parsing SQL |
| Downstream code change | If lambda support is desired, use the new TreeNode methods where applicable. Otherwise none | None | If lambda support is desired, check the field where applicable. Otherwise none |
| Creates a new Expr | No | Yes | No |
| Requires new TreeNode methods | Yes, *_with_lambdas_params for both logical and physical expressions | No | No |
| Inhibits existing optimizations for exprs with columns of lambda parameters | No | expr_simplifier, fixable. But may happen downstream too | No |
| expr_api users must reason about it | No | Yes, use the correct expression type | Yes, set the flag |
| Two almost identical expressions whose major difference is where in the expr tree they may appear | No | Yes | No |
| Data from the tree leaks into the column node | No | No | Yes, is_lambda_parameter is information about the column's place in the tree, not about the column node itself |

Ultimately, I don't lean towards any option; I believe it's a decision for the maintainers and those with more contact with downstream users, who have a better idea of which option is easier to use. I think that having the most laborious option already implemented puts us in a good position to make a choice, and it would be easy to change to another option.

Continued in the comment below.
