gstvg commented on PR #18921:
URL: https://github.com/apache/datafusion/pull/18921#issuecomment-3869546935
@linhr
Really really sorry for the delay
<br>
> I think it's possible to resolve lambda parameters and remove them from the arguments list with `ScalarUDF` itself, with a few cons compared to a dedicated `LambdaUDF`, but with far fewer code changes and similar enough to compare the alternatives. I believe I'm already finishing it, and will push to another branch soon.
Unfortunately, trying to resolve lambdas with `ScalarUDF` didn't work well. So, before moving to a `LambdaUDF`-based implementation, I would like to further discuss the current approach.
<br>
>`Expr::Lambda` and `Expr::LambdaColumn` are only fragments whose data types
etc. are not well-defined which may be challenging to work with in other parts
of the library (e.g. `ExprSchemable`).
>> Yeah, this is the major challenge in this PR. Currently, `Expr::Lambda` always returns `DataType::Null`, and `Expr::LambdaVariable` embeds a `FieldRef` which is used to implement `ExprSchemable` and doesn't even look at the `DFSchema`

I mean, this *was* the major challenge of the PR in its first implementation; now that we use `LambdaVariable` with a `Field`, I consider it solved. And currently `Expr::Lambda` actually returns the `return_field` of its body, but I believe we could also return a field with `DataType::Null` if deemed better.
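For concreteness, a minimal sketch of that schema behavior (this is not the PR's actual code; the `Lambda`/`LambdaVariable` variants are this PR's, and accessors like `field` and `body` are hypothetical names):
```rust
use std::sync::Arc;
use arrow::datatypes::{Field, FieldRef};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::{Expr, ExprSchemable};

// Sketch only: `Expr::LambdaVariable` / `Expr::Lambda` come from this PR and the
// `field` / `body` accessors are hypothetical.
fn lambda_aware_field(expr: &Expr, schema: &DFSchema) -> Result<FieldRef> {
    match expr {
        // The variable embeds its own Field, so the DFSchema is never consulted.
        Expr::LambdaVariable(var) => Ok(Arc::clone(&var.field)),
        // The lambda reports the field of its body
        // (a DataType::Null field would also be an option here).
        Expr::Lambda(lambda) => lambda_aware_field(&lambda.body, schema),
        // Everything else goes through the usual ExprSchemable path.
        other => Ok(Arc::new(Field::new(
            "expr",
            other.get_type(schema)?,
            other.nullable(schema)?,
        ))),
    }
}
```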
<br>
> `LambdaFunction` can represent a "resolved" lambda function call in the
logical plan. In contrast, `Expr::Lambda` and `Expr::LambdaColumn` are only
fragments whose data types etc. are not well-defined which may be challenging
to work with in other parts of the library (e.g. `ExprSchemable`).
> ... when the function is actually "invoked" with Arrow arrays, the lambda parameter should have been resolved and removed from the argument list

Sorry, at first I thought this only meant resolving/embedding the lambda body within the UDF itself and removing it from `Expr::ScalarFunction.args`/`LambdaFunction.args`, and also removing the `Expr::Lambda` variant from the logical `Expr` enum and its physical counterpart, but now I'm not sure whether it also means omitting the lambda body from `TreeNode` traversals?
To streamline the discussion, I have a few comments on each of these options, where applicable:
## Omitting the lambda body from TreeNode traversals
At least the following traversals currently use `TreeNode` and would require adjustment if the body is omitted: projection pushdown (for column capture), type coercion (for correctness), and the expression simplifier and CSE (for performance).
Since these are only 4 internal traversals, it's easy to handle them specially, but since DataFusion is an open system, I believe there should be a way for downstream users to visit/transform lambda expressions. We could document that they need to be handled specially, outside `TreeNode`, or we could offer an API for it, and after all, I believe the ideal API would be exactly the `TreeNode` API.
In my first iteration of this PR (the "outdated" sections of the description and of the first comment), I manually checked every `TreeNode` usage on logical and physical expressions and found none where `Expr::Lambda` had to be specially handled (it is simply ignored), and `Expr::LambdaVariable` only required handling in a few places.
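As an illustration of what downstream users get if lambda bodies stay visible to `TreeNode`, here is a minimal sketch (it assumes, as in this PR, that a lambda's body is an ordinary child of its expression; nothing lambda-specific is needed):
```rust
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::Result;
use datafusion_expr::Expr;

/// Collects every column referenced in the expression, including columns
/// captured inside lambda bodies, using only the ordinary `TreeNode` traversal.
fn referenced_columns(expr: &Expr) -> Result<Vec<String>> {
    let mut columns = vec![];
    expr.apply(|e| {
        if let Expr::Column(col) = e {
            columns.push(col.name.clone());
        }
        // No special case for lambdas: their bodies are visited like any other child.
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(columns)
}
```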
## Removing the Expr::Lambda variant from the Expr enum and its physical counterpart
Removing `Expr::Lambda` from the AST makes some tree traversals more difficult, while keeping it doesn't make a difference: it is simply ignored, as most expressions are in most traversals.
One example is `BindLambdaVariable` in this PR. It's not important to understand what it does, just that the version without `Expr::Lambda` is bigger, has more boilerplate, and is harder to read. Implementation with the physical lambda in the AST:
```rust
struct BindLambdaVariable<'a> {
    // (variable value, number of times it has been shadowed by other lambdas'
    // variables in inner scopes at the current position of the tree)
    variables: HashMap<&'a str, (ArrayRef, usize)>,
}

impl TreeNodeRewriter for BindLambdaVariable<'_> {
    type Node = Arc<dyn PhysicalExpr>;

    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        if let Some(lambda_variable) = node.as_any().downcast_ref::<LambdaVariable>() {
            if let Some((value, shadows)) = self.variables.get(lambda_variable.name()) {
                if *shadows == 0 {
                    return Ok(Transformed::yes(Arc::new(
                        lambda_variable.clone().with_value(value.clone()),
                    )));
                }
            }
        } else if let Some(inner_lambda) = node.as_any().downcast_ref::<LambdaExpr>() {
            for param in inner_lambda.params() {
                if let Some((_value, shadows)) = self.variables.get_mut(param.as_str()) {
                    *shadows += 1;
                }
            }
        }
        Ok(Transformed::no(node))
    }

    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        if let Some(inner_lambda) = node.as_any().downcast_ref::<LambdaExpr>() {
            for param in inner_lambda.params() {
                if let Some((_value, shadows)) = self.variables.get_mut(param.as_str()) {
                    *shadows -= 1;
                }
            }
        }
        Ok(Transformed::no(node))
    }
}
```
And the implementation without the physical lambda in the AST:
```rust
fn bind_lambda_variables(
    node: Arc<dyn PhysicalExpr>,
    params: &mut HashMap<&str, (ArrayRef, usize)>,
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
    node.transform_down(|node| {
        if let Some(lambda_variable) = node.as_any().downcast_ref::<LambdaVariable>() {
            if let Some((value, shadows)) = params.get(lambda_variable.name()) {
                if *shadows == 0 {
                    return Ok(Transformed::yes(Arc::new(
                        lambda_variable.clone().with_value(value.clone()),
                    )));
                }
            }
        } else if let Some(fun) = node.as_any().downcast_ref::<LambdaUDFExpr>() {
            let mut transformed = false;
            let new_lambdas = fun
                .lambdas()
                .iter()
                .map(|(params_names, body)| {
                    for param in *params_names {
                        if let Some((_value, shadows)) = params.get_mut(param.as_str()) {
                            *shadows += 1;
                        }
                    }
                    let new_body = bind_lambda_variables(Arc::clone(body), params)?;
                    transformed |= new_body.transformed;
                    for param in *params_names {
                        if let Some((_value, shadows)) = params.get_mut(param.as_str()) {
                            *shadows -= 1;
                        }
                    }
                    Ok((params_names.to_vec(), new_body.data))
                })
                .collect::<Result<_>>()?;
            let new_args = fun
                .args()
                .iter()
                .map(|arg| {
                    let new_arg = bind_lambda_variables(Arc::clone(arg), params)?;
                    transformed |= new_arg.transformed;
                    Ok(new_arg.data)
                })
                .collect::<Result<_>>()?;
            let fun = fun.with_new_children(new_args, new_lambdas);
            return Ok(Transformed::new(
                Arc::new(fun),
                transformed,
                TreeNodeRecursion::Stop,
            ));
        }
        Ok(Transformed::no(node))
    })
}
```
Note that this transformation always returns a constant `TreeNodeRecursion` value. If it didn't, `TreeNodeRecursion` would also need to be handled manually, adding more boilerplate that `TreeNode` handles automatically when `Expr::Lambda` is present.
## Resolve/embed lambdas in the UDF implementation itself (essentially partitioning args from lambdas)
By partitioning args and lambdas we lose positional data, which is necessary to implement `SqlUnparser` and logical and physical formatting. This leads to either the implementation itself having to implement them, adding boilerplate for something that today is handled automatically by `ScalarUDF` and `SqlUnparser`, or the implementation having to save and expose positional data, which is then used by `ScalarUDF` and `SqlUnparser` to de-partition the args list and proceed normally.
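To make that second option concrete, here is one possible (purely hypothetical) shape; none of these names exist in the PR:
```rust
use std::sync::Arc;
use datafusion_expr::Expr;

/// Hypothetical trait for the partitioned design: the UDF owns the lambdas and
/// exposes where they sat in the original call, so that `ScalarUDF`/`SqlUnparser`
/// can de-partition the args list for formatting and unparsing.
trait LambdaUDF {
    fn name(&self) -> &str;
    /// Original index of each lambda in the SQL call, kept only for formatting
    /// and unparsing, not for evaluation.
    fn lambda_positions(&self) -> &[usize];
}

/// The expression then stores only the non-lambda args.
struct LambdaFunction {
    func: Arc<dyn LambdaUDF>,
    args: Vec<Expr>,
}
```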
Some closed systems without support for user-defined functions with lambdas use a fixed position for lambdas and can simply rely on it, without having to store positional data or keep lambdas and the other args in the same ordered list.
Always last for Spark, DuckDB and Snowflake:
```sql
transform(array(1), v -> v*2) -- spark
list_transform([1, 2, NULL, 3], lambda x: x + 1) -- duckdb
transform([1, 2, 3], a INT -> a * 2) -- snowflake
```
Always first for ClickHouse:
```sql
arrayMap(x -> (x + 2), [1, 2, 3])
```
But I think that DataFusion, as an open and extensible system, shouldn't stick to a single convention, as different deployments may want to support one, the other, or both conventions.
Furthermore, my interest in lambda functions is to implement union-manipulating UDFs, which receive multiple lambdas interleaved with regular arguments. Consider a `Union<'str' = Utf8, 'num' = Float64>` that could be transformed like this:
```sql
union_transform(
    union_value,
    'str', str -> trim(str),
    'num', num -> num*2
)
```
where today the only way to do it is this (if/when support for `union_value` lands):
```sql
case union_tag(union_value)
    when 'str' then union_value('str', trim(union_extract(union_value, 'str')))
    when 'num' then union_value('num', union_extract(union_value, 'num') * 2)
end
```
For such cases, storing positional data or manually implementing logical and physical formatting and `SqlUnparser` for every UDF is mandatory, as that positioning doesn't fit simple conventions like always first or always last.
This also omits positional info from the implementation. No lambda UDF discussed so far, including `union_transform`, requires positional info for evaluation, only for logical and physical formatting and `SqlUnparser`, but I don't like the idea of blocking any future hypothetical UDF that does require positional info for evaluation. A theoretical example is `union_transform` itself supporting a slightly shorter syntax for lambdas that return constants:
```sql
union_transform(
    union_value,
    'str', str -> trim(str),
    'num', _num -> 0 -- lambda that returns a constant/scalar
)

union_transform(
    union_value,
    'str', str -> trim(str),
    'num', 0 -- shorter syntax, pass the constant directly
)
```
That syntax does require positional info for evaluation. It's really just an example, and I don't plan to support it in my `union_transform` implementation.
Instead, if we add a new `LambdaFunction` expression and a `LambdaUDF` trait, the implementation could own both the lambdas *and* the other args, so that no partitioning would occur:
```rust
// invocation includes all args, both lambdas and non-lambdas
struct LambdaFunction { invocation: Box<dyn LambdaInvocation> }
Expr::LambdaFunction(LambdaFunction::new(ArrayTransform::new(all_args_including_lambdas)))

// instead of

// func owns only the lambdas; non-lambda args are stored in args
struct LambdaFunction { func: Arc<dyn LambdaUDF>, args: Vec<Expr> }
Expr::LambdaFunction(LambdaFunction::new(ArrayTransform::new(lambdas), args_without_lambdas))
```
# Insights from the Spark implementation
I don't have much experience with Spark and Scala (I only implemented a custom data source a few years ago), so if something below seems wrong, it probably is.
Spark includes both a [HigherOrderFunction trait](https://github.com/apache/spark/blob/eec092c9f9d1ad3df0390b1ea56340c1103ba887/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L153) and a [LambdaFunction class](https://github.com/apache/spark/blob/eec092c9f9d1ad3df0390b1ea56340c1103ba887/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L113), similar to our current `Expr::Lambda` (and not the proposed `LambdaFunction`), as well as a [NamedLambdaVariable class](https://github.com/apache/spark/blob/eec092c9f9d1ad3df0390b1ea56340c1103ba887/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L77), similar to our `Expr::LambdaVariable`, all of them extending [Expression](http://javadoc.io/static/org.apache.spark/spark-catalyst_2.10/2.2.3/index.html#org.apache.spark.sql.catalyst.expressions.Expression), which in turn extends [TreeNode](https://www.javadoc.io/static/org.apache.spark/spark-catalyst_2.10/2.2.3/index.html#org.apache.spark.sql.catalyst.trees.TreeNode).
Some traversals branch directly on `LambdaFunction` expressions, which in my view supports the idea of both keeping `Expr::Lambda` and exposing the lambda body on the `TreeNode` API:
- [LambdaBinder](https://github.com/apache/spark/blob/8b68a172d34d2ed9bd0a2deefcae1840a78143b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/LambdaBinder.scala#L56) (similar to our `BindLambdaVariable`)
- [ResolveLambdaVariables](https://github.com/apache/spark/blob/8b68a172d34d2ed9bd0a2deefcae1840a78143b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala#L59)
- [ColumnResolutionHelper](https://github.com/apache/spark/blob/8b68a172d34d2ed9bd0a2deefcae1840a78143b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala#L121)
- [ColumnNodeToExpressionConverter](https://github.com/apache/spark/blob/f80b919d19ce8e390c029857802cfe40190bb5ea/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala#L143)
- [CheckAnalysis](https://github.com/apache/spark/blob/f80b919d19ce8e390c029857802cfe40190bb5ea/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L352)
- [ColumnNodeToProtoConverter](https://github.com/apache/spark/blob/f80b919d19ce8e390c029857802cfe40190bb5ea/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala#L164)
Some branch on `LambdaVariable`, which in my view supports the idea of exposing the lambda body on the `TreeNode` API:
- [SessionCatalog](https://github.com/apache/spark/blob/620b2f65f0daed19e1800503df3c892b3bb664e1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala#L1641)
- [HigherOrderFunction.functionForEval](https://github.com/apache/spark/blob/620b2f65f0daed19e1800503df3c892b3bb664e1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L221)
- [NormalizePlan](https://github.com/apache/spark/blob/620b2f65f0daed19e1800503df3c892b3bb664e1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala#L107)
- [TableOutputResolver](https://github.com/apache/spark/blob/620b2f65f0daed19e1800503df3c892b3bb664e1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala#L265)
Some branch on `HigherOrderFunction`:
- [CheckAnalysis](https://github.com/apache/spark/blob/620b2f65f0daed19e1800503df3c892b3bb664e1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L355)
- [Analyzer](https://github.com/apache/spark/blob/620b2f65f0daed19e1800503df3c892b3bb664e1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala#L2168)
- [ResolveLambdaVariables](https://github.com/apache/spark/blob/620b2f65f0daed19e1800503df3c892b3bb664e1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala#L55)
It's true that some traversals branch on `HigherOrderFunction` expressions, therefore supporting the idea of a new `Expr::LambdaFunction`, but since checking whether an `Expr::ScalarFunction` invocation contains lambdas is as simple as the code below, and can be made even simpler by adding a helper method `contains_lambdas(&self) -> bool` to `ScalarFunction`, I don't think it alone justifies adding a new `Expr::LambdaFunction` holding a `ScalarFunction`. If a `LambdaUDF` trait is added, then a new `Expr::LambdaFunction` variant has to be added anyway, so this becomes a non-issue.
```rust
match expr {
    Expr::ScalarFunction(fun) if fun.args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => ...,
    // with helper
    Expr::ScalarFunction(fun) if fun.contains_lambdas() => ...,
    ...
}

impl ScalarFunction {
    fn contains_lambdas(&self) -> bool {
        self.args.iter().any(|arg| matches!(arg, Expr::Lambda(_)))
    }

    // Other helpers could be added too, like:
    fn lambda_args(&self) -> impl Iterator<Item = &LambdaFunction> {
        self.args.iter().filter_map(|arg| match arg {
            Expr::Lambda(l) => Some(l),
            _ => None,
        })
    }

    fn non_lambda_args(&self) -> impl Iterator<Item = &Expr> {
        self.args.iter().filter(|arg| !matches!(arg, Expr::Lambda(_)))
    }
}
```
[Higher-order functions, like ArrayTransform, receive both the regular arg and the lambda as regular Expressions, with the lambda expected to be a LambdaFunction](https://github.com/apache/spark/blob/214bf958757c804e2ae637b3fc3d4ca9ec94a1d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L312-L328):
```scala
case class ArrayTransform(
    argument: Expression,
    function: Expression)
  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

  override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)

  override protected def bindInternal(
      f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
    val ArrayType(elementType, containsNull) = argument.dataType
    function match {
      case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
        copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
      case _ =>
        copy(function = f(function, (elementType, containsNull) :: Nil))
    }
  }
```
[And it is instantiated like this](https://github.com/apache/spark/blob/3298e685a9641c3e137917f75f58bc0bcfdfe32e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortExpression.scala#L152-L155):
```scala
val param = NamedLambdaVariable("x", et, containsNull)
val funcBody = insertMapSortRecursively(param)
ArrayTransform(e, LambdaFunction(funcBody, Seq(param)))
```
So, regarding resolving lambdas / partitioning lambdas from args: all args are owned by the implementation, not by `HigherOrderFunction`, which is just a trait. The lambdas are therefore indeed resolved, but so are the other args, and no partitioning occurs. I believe our current approach would look like this in Spark: `trait Expression`, implemented by `class HigherOrderFunction`, which owns a `trait HigherOrderFunctionImpl`, which is implemented by `ArrayTransform`, where the lambda and the arg would be stored in the `class HigherOrderFunction`, not in `ArrayTransform`. The Spark approach implemented in DataFusion would look like the new `LambdaFunction` expression and `LambdaUDF` trait cited above, where the implementation owns both the lambdas *and* the other args: `Expr::LambdaFunction(LambdaFunction::new(ArrayTransform::new(all_args_including_lambdas)))`, instead of `Expr::LambdaFunction(LambdaFunction::new(ArrayTransform::new(lambdas), args_without_lambdas))`.
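Spelled out a bit more, the Spark-like shape in DataFusion terms might look like this (a hedged sketch with illustrative names, mirroring the `LambdaInvocation` idea above; none of this is in the PR as-is):
```rust
use datafusion_expr::Expr;

/// Hypothetical trait: the implementation owns every argument, lambdas included,
/// mirroring how Spark's ArrayTransform owns both `argument` and `function`.
trait LambdaInvocation {
    fn name(&self) -> &str;
    /// Every argument, lambdas and non-lambdas, in call order.
    fn args(&self) -> &[Expr];
}

struct ArrayTransform {
    /// e.g. [array_expr, Expr::Lambda(...)]
    args: Vec<Expr>,
}

impl LambdaInvocation for ArrayTransform {
    fn name(&self) -> &str {
        "array_transform"
    }
    fn args(&self) -> &[Expr] {
        &self.args
    }
}

/// The wrapper expression just owns the boxed invocation; no partitioning occurs.
struct LambdaFunction {
    invocation: Box<dyn LambdaInvocation>,
}
```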
Finally, running this script in `spark-shell` produces the following output, showing that both lambda bodies and `LambdaFunction` nodes are present in `TreeNode` traversals:
```scala
val df = spark.sql("SELECT transform(array(1), v -> v*2)")
val transform = df.queryExecution.logical.expressions(0)
transform.foreach { node =>
  println(s"${node.nodeName}: ${node.toString}")
}
println(transform.treeString)
```
```
UnresolvedAlias: unresolvedalias(transform(array(1), lambdafunction((v * 2), v)))
UnresolvedFunction: transform(array(1), lambdafunction((v * 2), v))
UnresolvedFunction: array(1)
Literal: 1
LambdaFunction: lambdafunction((v * 2), v)
Multiply: (v * 2)
UnresolvedNamedLambdaVariable: v
Literal: 2
UnresolvedNamedLambdaVariable: v
unresolvedalias('transform('array(1), lambdafunction((lambda 'v * 2), lambda 'v, false)))
+- 'transform('array(1), lambdafunction((lambda 'v * 2), lambda 'v, false))
   :- 'array(1)
   :  +- 1
   +- lambdafunction((lambda 'v * 2), lambda 'v, false)
      :- (lambda 'v * 2)
      :  :- lambda 'v
      :  +- 2
      +- lambda 'v
```
In short, if we more or less agree with my points above, adding a new `LambdaUDF` just to avoid modifying `ScalarFunctionArgs`, IMHO, doesn't seem worth the additional code that gets to be reviewed and maintained, both in this PR (which is already big) and in subsequent PRs adding more lambda UDFs like `array_filter` and `array_fold`.
What do you think? Am I missing or misunderstanding something? Thanks!
And again, sorry for the delay.
cc @keen85