peter-toth commented on code in PR #56070:
URL: https://github.com/apache/spark/pull/56070#discussion_r3303126758
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala:
##########
@@ -140,6 +140,19 @@ object SchemaPruning extends SQLConfHelper {
*/
private[catalyst] def getRootFields(expr: Expression): Seq[RootField] = {
expr match {
+ case ArrayTransform(argument, lambda: LambdaFunction) =>
Review Comment:
This case, and the matching `ArrayTransform` branch in
`ProjectionOverSchema`, are the only two extension points for lambda-aware
nested pruning, and they're hardcoded to `ArrayTransform`. The same shape
applies to several other higher-order functions whose element is *consumed*
rather than passed through to the output:

- `ArrayExists`, `ArrayForAll`: predicate over the element; output is
  `Boolean`.
- `ArrayAggregate`: aggregation; output is the merge type.
- `MapFilter`: predicate over `(key, value)`; output type is the original
  map.

For all of those, narrowing the input element struct based on what the
lambda body accesses is sound (no downstream consumer sees the original element
type). Each one currently requires duplicating both the `getRootFields` branch
and the `ProjectionOverSchema` branch, including the lambda-variable rewrite
mechanics.

A generalized refactor (could be a follow-up rather than blocking this PR):

- Lift `collectLambdaVariableFields`, `LambdaVariableField`, and the
  per-element pruning into a helper that takes any expression with a
  `LambdaFunction` child whose first argument binds an `array<struct<...>>` (or
  `map<k, struct<...>>`) element.
- In `ProjectionOverSchema`, dispatch on `case h: HigherOrderFunction if
  eligible(h)` instead of by class, and rewrite via the same
  `ProjectionOverLambdaVariable` logic.

The counter-arguments are real but bounded:

- `ArrayFilter` / `ArraySort` / `ZipWith` pass the element through to the
  output, so they can't reuse this without also tracking *output* consumers;
  leave them out of the generalized set.
- The `ArraySort` case is bigger because the lambda has two element
  variables; that argues for limiting v1 of the generalization to one-element
  HOFs.
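
To make the `eligible(h)` idea concrete, here is a minimal sketch on a toy model. Nothing in it is Spark API: `HofShape`, its two fields, and the hand-filled table are hypothetical stand-ins, and only four of the functions named above are classified. It encodes the proposed v1 rule as "exactly one element variable, and the function itself never emits input elements".

```scala
// Toy model only: HofShape is a hypothetical stand-in, NOT a Catalyst class.
// The flags below are filled in by hand for a few functions, for illustration.
object HofEligibilitySketch {

  /** Hypothetical summary of how a higher-order function uses its input elements. */
  final case class HofShape(
      name: String,
      elementVars: Int,            // lambda variables bound to input elements
      emitsInputElements: Boolean) // true if input elements appear in the result as-is

  // Hand-written classification; an assumption of this sketch, not derived from Spark.
  val known: Seq[HofShape] = Seq(
    HofShape("ArrayTransform", elementVars = 1, emitsInputElements = false),
    HofShape("ArrayExists", elementVars = 1, emitsInputElements = false),
    HofShape("ArrayFilter", elementVars = 1, emitsInputElements = true),
    HofShape("ArraySort", elementVars = 2, emitsInputElements = true))

  /** Proposed v1 rule: one element variable, and the element never reaches the output. */
  def eligible(h: HofShape): Boolean =
    h.elementVars == 1 && !h.emitsInputElements

  def main(args: Array[String]): Unit =
    known.foreach(h => println(s"${h.name}: eligible=${eligible(h)}"))
}
```

In the real rule the predicate would presumably inspect the expression itself rather than a lookup table; the table only keeps the sketch self-contained.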
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala:
##########
@@ -82,4 +107,37 @@ case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
case _ =>
None
}
+
+ private object ArrayTypeProjection {
+ def unapply(expr: Expression): Option[StructType] = expr.dataType match {
+ case ArrayType(projectedElementSchema: StructType, _) => Some(projectedElementSchema)
+ case _ => None
+ }
+ }
+
+ /**
+ * Rewrites references rooted at one bound lambda element to use its projected type and
+ * recomputes nested field ordinals against each projected struct in the access path.
+ */
+ private case class ProjectionOverLambdaVariable(
+ original: NamedLambdaVariable,
+ projected: NamedLambdaVariable) {
+ def unapply(expr: Expression): Option[Expression] = project(expr)
+
+ private def project(expr: Expression): Option[Expression] = expr match {
+ case variable: NamedLambdaVariable if variable.semanticEquals(original) =>
+ Some(projected)
+ case GetStructFieldObject(child, field: StructField) =>
+ project(child).map { projection =>
+ projection.dataType match {
+ case projectedSchema: StructType =>
+ GetStructField(projection, projectedSchema.fieldIndex(field.name))
+ case dataType =>
+ throw new IllegalStateException(
Review Comment:
The other two "shouldn't happen" branches in this file (lines 61, 76 in the
post-PR file — `GetArrayStructFields` and `GetStructFieldObject` mismatches)
throw `SparkException.internalError`. Suggest matching:
```suggestion
throw SparkException.internalError(
  s"unmatched lambda child schema for GetStructField: ${dataType.toString}")
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala:
##########
@@ -162,6 +175,80 @@ object SchemaPruning extends SQLConfHelper {
}
}
+ private def getArrayTransformRootField(
+ argument: Expression,
+ function: Expression,
+ elementVar: NamedLambdaVariable): Option[StructField] = {
+ argument.dataType match {
+ case ArrayType(_: StructType, containsNull) =>
+ val selectedFields = collectLambdaVariableFields(function, elementVar)
+ if (selectedFields.exists(_.nonEmpty)) {
+ val mergedElementSchema = selectedFields
+ .get
+ .map(field => StructType(Array(field)))
+ .reduceLeft(_ merge _)
+ SelectedField.withDataType(
+ argument,
+ ArrayType(mergedElementSchema, containsNull))
+ } else {
+ None
+ }
+ case _ => None
+ }
+ }
+
+ /**
+ * Collects statically identifiable nested fields read from `elementVar`.
+ *
+ * `Some(Seq.empty)` means this subtree does not reference the element variable, and
+ * `Some(fields)` means every reference can be satisfied by the listed nested fields. `None`
+ * means the full element is required somewhere (for example, `x => struct(x.a, x)`), so it is
+ * not safe to prune the element struct.
+ */
+ private def collectLambdaVariableFields(
+ expr: Expression,
+ elementVar: NamedLambdaVariable): Option[Seq[StructField]] = {
+ expr match {
+ case LambdaVariableField(field, variable) if variable.semanticEquals(elementVar) =>
+ Some(field :: Nil)
+ case variable: NamedLambdaVariable if variable.semanticEquals(elementVar) =>
+ None
+ case _ =>
+ expr.children.foldLeft(Option(Seq.empty[StructField])) {
+ case (Some(fields), child) =>
+ collectLambdaVariableFields(child, elementVar).map(fields ++ _)
+ case (None, _) => None
+ }
+ }
+ }
+
+ /**
+ * Converts a field access rooted at the lambda element into the single nested
+ * [[StructField]] shape needed by the input array schema. For example,
+ * `x.company.address` becomes `company: struct<address: ...>`.
+ */
+ private object LambdaVariableField {
Review Comment:
The chain walker here only handles `GetStructField`, but `SelectedField`'s
`selectField` (which this is parallel to) handles `GetArrayStructFields`,
`GetArrayItem`, `ElementAt`, `GetMapValue`, `MapKeys`, and `MapValues` as well.
As written, queries like
```sql
SELECT transform(arr, x -> x.subArr[0].field) FROM t;
SELECT transform(arr, x -> element_at(x.subArr, 1).field) FROM t;
SELECT transform(arr, x -> x.mapField['k'].field) FROM t;
```
will read the full inner type for `subArr` / `mapField` because the chain is
broken at the first non-`GetStructField` node: `collectLambdaVariableFields`
falls through to `_`, recurses into the children, and only the `x.subArr` /
`x.mapField` prefix matches, so that field is kept with its full type.

As an improvement, `LambdaVariableField` could mirror `SelectedField`'s case
set (with `selectField` taking the lambda-variable leaf instead of `Attribute`
as terminator). The two extractors would then differ only in the leaf case and
could share a parameterized helper. `array<struct>` columns nested under
another `array<struct>` are common, so closing this gap covers a non-trivial
set of real queries.
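
A minimal sketch of the shared, leaf-parameterized walker suggested above, written against a toy AST rather than Catalyst. Every name in it (`Expr`, `GetField`, `GetItem`, `GetValue`, `fieldPath`) is hypothetical, and it returns only the struct field names on the access path; a real version would build the nested `StructField` and re-wrap array and map types along the way.

```scala
// Toy AST only: these classes are hypothetical stand-ins, NOT Spark's expressions.
object ChainWalkerSketch {
  sealed trait Expr
  final case class LambdaVar(name: String) extends Expr
  final case class Attr(name: String) extends Expr
  final case class GetField(child: Expr, field: String) extends Expr // x.field
  final case class GetItem(child: Expr, index: Int) extends Expr     // arr[i]
  final case class GetValue(child: Expr, key: String) extends Expr   // map['k']

  /**
   * Walks an access chain down to a leaf accepted by `isLeaf` and returns the
   * struct field names along the way. Array and map hops are stepped through
   * instead of breaking the chain. Returns None if the chain is not rooted at
   * an accepted leaf.
   */
  def fieldPath(expr: Expr, isLeaf: Expr => Boolean): Option[List[String]] = expr match {
    case e if isLeaf(e)         => Some(Nil)
    case GetField(child, field) => fieldPath(child, isLeaf).map(_ :+ field)
    case GetItem(child, _)      => fieldPath(child, isLeaf)
    case GetValue(child, _)     => fieldPath(child, isLeaf)
    case _                      => None
  }

  def main(args: Array[String]): Unit = {
    val x = LambdaVar("x")
    // Models x.subArr[0].field from the first query above.
    val access = GetField(GetItem(GetField(x, "subArr"), 0), "field")
    println(fieldPath(access, _ == x))
  }
}
```

The two extractors then differ only in the `isLeaf` argument: a lambda-variable check for the lambda case, an attribute check for the column case.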
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]