peter-toth commented on code in PR #56070:
URL: https://github.com/apache/spark/pull/56070#discussion_r3303126758


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala:
##########
@@ -140,6 +140,19 @@ object SchemaPruning extends SQLConfHelper {
    */
   private[catalyst] def getRootFields(expr: Expression): Seq[RootField] = {
     expr match {
+      case ArrayTransform(argument, lambda: LambdaFunction) =>

Review Comment:
   This case, and the matching `ArrayTransform` branch in 
`ProjectionOverSchema`, are the only two extension points for lambda-aware 
nested pruning, and they're hardcoded to `ArrayTransform`. The same shape 
applies to several other higher-order functions whose element is *consumed* 
rather than passed through to the output:
   
   - `ArrayExists`, `ArrayForAll` — predicate over the element; output is 
`Boolean`.
   - `ArrayAggregate` — aggregation; output is the merge type.
   - `MapFilter` — predicate over `(key, value)`; output type is the original 
map.
   
   For all of those, narrowing the input element struct based on what the 
lambda body accesses is sound (no downstream consumer sees the original element 
type). Supporting any one of them today would require duplicating both the 
`getRootFields` branch and the `ProjectionOverSchema` branch, including the 
lambda-variable rewrite mechanics.
   
   A generalized refactor (could be a follow-up rather than blocking this PR):
   
   - Lift `collectLambdaVariableFields`, `LambdaVariableField`, and the 
per-element pruning into a helper that takes any expression with a 
`LambdaFunction` child whose first argument binds an `array<struct<...>>` (or 
`map<k, struct<...>>`) element.
   - In `ProjectionOverSchema`, dispatch on `case h: HigherOrderFunction if 
eligible(h)` instead of by class, and rewrite via the same 
`ProjectionOverLambdaVariable` logic.
   
   The counter-arguments are real but bounded:
   - `ArrayFilter` / `ArraySort` / `ZipWith` pass the element through to the 
output, so they can't reuse this without also tracking *output* consumers — 
leave them out of the generalized set.
   - The `ArraySort` case is bigger because the lambda has two element 
variables; that argues for limiting v1 of the generalization to HOFs with a 
single element variable.
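   As a rough illustration of the `eligible(h)` predicate proposed above, here is a self-contained Scala sketch that models a subset of the array higher-order functions named in this comment. `HofShape`, its fields, and the listed shapes are hypothetical and simplified (they are not Spark's `HigherOrderFunction` API); the point is only that eligibility reduces to two checks: a single element variable, and no way for the HOF itself to forward the element to its output.

   ```scala
   // Toy model of the proposed `eligible(h)` predicate. `HofShape` and the
   // values below are hypothetical; they are not Spark's HigherOrderFunction API.
   object HofEligibility {
     /** What nested-schema pruning needs to know about a higher-order function. */
     final case class HofShape(
         name: String,
         // Lambda variables bound to input elements (an accumulator, as in
         // ArrayAggregate's merge lambda, does not count).
         elementVariables: Int,
         // True when the HOF itself forwards input elements to its output,
         // independent of the lambda body. A body returning the bare element
         // (e.g. `x -> x`) is rejected separately by the field collector.
         outputCarriesElement: Boolean)

     // Shapes as described in the review comment (assumed, not read from Spark).
     val Transform = HofShape("ArrayTransform", 1, outputCarriesElement = false)
     val Exists = HofShape("ArrayExists", 1, outputCarriesElement = false)
     val ForAll = HofShape("ArrayForAll", 1, outputCarriesElement = false)
     val Aggregate = HofShape("ArrayAggregate", 1, outputCarriesElement = false)
     val Filter = HofShape("ArrayFilter", 1, outputCarriesElement = true)
     val Sort = HofShape("ArraySort", 2, outputCarriesElement = true)

     /** Narrowing is sound only if no consumer sees the original element type;
       * v1 additionally restricts itself to a single element variable. */
     def eligible(h: HofShape): Boolean =
       h.elementVariables == 1 && !h.outputCarriesElement

     /** Names of the eligible HOFs, in input order. */
     def eligibleNames(hofs: Seq[HofShape]): Seq[String] =
       hofs.filter(eligible).map(_.name)
   }
   ```

   With these shapes, `eligibleNames(Seq(Transform, Exists, ForAll, Aggregate, Filter, Sort))` keeps the first four. Keeping `outputCarriesElement` independent of the lambda body is deliberate: a body such as `x -> x` is already rejected by `collectLambdaVariableFields`, which returns `None` for a bare element reference.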



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala:
##########
@@ -82,4 +107,37 @@ case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
       case _ =>
         None
     }
+
+  private object ArrayTypeProjection {
+    def unapply(expr: Expression): Option[StructType] = expr.dataType match {
+      case ArrayType(projectedElementSchema: StructType, _) => Some(projectedElementSchema)
+      case _ => None
+    }
+  }
+
+  /**
+   * Rewrites references rooted at one bound lambda element to use its projected type and
+   * recomputes nested field ordinals against each projected struct in the access path.
+   */
+  private case class ProjectionOverLambdaVariable(
+      original: NamedLambdaVariable,
+      projected: NamedLambdaVariable) {
+    def unapply(expr: Expression): Option[Expression] = project(expr)
+
+    private def project(expr: Expression): Option[Expression] = expr match {
+      case variable: NamedLambdaVariable if variable.semanticEquals(original) =>
+        Some(projected)
+      case GetStructFieldObject(child, field: StructField) =>
+        project(child).map { projection =>
+          projection.dataType match {
+            case projectedSchema: StructType =>
+              GetStructField(projection, projectedSchema.fieldIndex(field.name))
+            case dataType =>
+              throw new IllegalStateException(

Review Comment:
   The other two "shouldn't happen" branches in this file (lines 61, 76 in the 
post-PR file — `GetArrayStructFields` and `GetStructFieldObject` mismatches) 
throw `SparkException.internalError`. Suggest matching:
   
   ```suggestion
                 throw SparkException.internalError(
                   s"unmatched lambda child schema for GetStructField: ${dataType.toString}")
   ```
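   To make the ordinal recomputation in `ProjectionOverLambdaVariable` concrete, including the branch this suggestion touches, here is a toy Scala sketch. `DType`, `Struct`, `fieldIndex`, and `ToyInternalError` are hypothetical stand-ins for `StructType`, `StructType.fieldIndex`, and `SparkException.internalError`; the example assumes the projected element keeps only `company.address`.

   ```scala
   // Toy sketch of ordinal recomputation against a projected element schema.
   // All types here are hypothetical stand-ins, not Spark classes.
   object ProjectedOrdinals {
     sealed trait DType
     case object IntType extends DType
     case object StringType extends DType
     final case class Field(name: String, dataType: DType)
     final case class Struct(fields: Vector[Field]) extends DType {
       // Mirrors StructType.fieldIndex: position of `name`, or an exception.
       def fieldIndex(name: String): Int = {
         val i = fields.indexWhere(_.name == name)
         if (i < 0) throw new IllegalArgumentException(s"$name does not exist")
         i
       }
     }

     // Stand-in for SparkException.internalError.
     final class ToyInternalError(msg: String)
       extends RuntimeException(s"[INTERNAL_ERROR] $msg")

     /** Ordinals for an access path such as x.company.address, resolved by name
       * against whichever (original or projected) element schema is passed in. */
     def ordinals(schema: DType, path: List[String]): List[Int] = path match {
       case Nil => Nil
       case name :: rest =>
         schema match {
           case s: Struct =>
             val i = s.fieldIndex(name)
             i :: ordinals(s.fields(i).dataType, rest)
           case other =>
             // The "shouldn't happen" branch: report it as an internal error.
             throw new ToyInternalError(
               s"unmatched lambda child schema for GetStructField: $other")
         }
     }

     val original: DType = Struct(Vector(
       Field("id", IntType),
       Field("name", StringType),
       Field("company", Struct(Vector(
         Field("name", StringType),
         Field("address", StringType))))))

     // After pruning, only company.address survives, so both ordinals become 0.
     val projected: DType = Struct(Vector(
       Field("company", Struct(Vector(Field("address", StringType))))))
   }
   ```

   Resolving `List("company", "address")` gives `List(2, 1)` against `original` and `List(0, 0)` against `projected`, which is why the rewrite looks up each ordinal by name in the projected struct instead of reusing the original ones. Reaching a non-struct type on the path is the unreachable case, reported here as an internal error.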



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala:
##########
@@ -162,6 +175,80 @@ object SchemaPruning extends SQLConfHelper {
     }
   }
 
+  private def getArrayTransformRootField(
+      argument: Expression,
+      function: Expression,
+      elementVar: NamedLambdaVariable): Option[StructField] = {
+    argument.dataType match {
+      case ArrayType(_: StructType, containsNull) =>
+        val selectedFields = collectLambdaVariableFields(function, elementVar)
+        if (selectedFields.exists(_.nonEmpty)) {
+          val mergedElementSchema = selectedFields
+            .get
+            .map(field => StructType(Array(field)))
+            .reduceLeft(_ merge _)
+          SelectedField.withDataType(
+            argument,
+            ArrayType(mergedElementSchema, containsNull))
+        } else {
+          None
+        }
+      case _ => None
+    }
+  }
+
+  /**
+   * Collects statically identifiable nested fields read from `elementVar`.
+   *
+   * `Some(Seq.empty)` means this subtree does not reference the element variable, and
+   * `Some(fields)` means every reference can be satisfied by the listed nested fields. `None`
+   * means the full element is required somewhere (for example, `x => struct(x.a, x)`), so it is
+   * not safe to prune the element struct.
+   */
+  private def collectLambdaVariableFields(
+      expr: Expression,
+      elementVar: NamedLambdaVariable): Option[Seq[StructField]] = {
+    expr match {
+      case LambdaVariableField(field, variable) if variable.semanticEquals(elementVar) =>
+        Some(field :: Nil)
+      case variable: NamedLambdaVariable if variable.semanticEquals(elementVar) =>
+        None
+      case _ =>
+        expr.children.foldLeft(Option(Seq.empty[StructField])) {
+          case (Some(fields), child) =>
+            collectLambdaVariableFields(child, elementVar).map(fields ++ _)
+          case (None, _) => None
+        }
+    }
+  }
+
+  /**
+   * Converts a field access rooted at the lambda element into the single nested
+   * [[StructField]] shape needed by the input array schema. For example,
+   * `x.company.address` becomes `company: struct<address: ...>`.
+   */
+  private object LambdaVariableField {

Review Comment:
   The chain walker here only handles `GetStructField`, but `SelectedField`'s 
`selectField` (which this mirrors) also handles `GetArrayStructFields`, 
`GetArrayItem`, `ElementAt`, `GetMapValue`, `MapKeys`, and `MapValues`. 
As written, queries like
   
   ```sql
   SELECT transform(arr, x -> x.subArr[0].field) FROM t
   SELECT transform(arr, x -> element_at(x.subArr, 1).field) FROM t
   SELECT transform(arr, x -> x.mapField['k'].field) FROM t
   ```
   
   will read the full inner type of `subArr` / `mapField`: the chain breaks at 
the first non-`GetStructField` node, so `collectLambdaVariableFields` falls 
through to `_`, recurses into the children, and only matches at `x.subArr` / 
`x.mapField`, which it records with its full type.
   
   As an improvement, `LambdaVariableField` could mirror `SelectedField`'s case 
set (with `selectField` taking the lambda-variable leaf instead of `Attribute` 
as terminator). The two extractors would then differ only in the leaf case and 
could share a parameterized helper. `array<struct>` columns nested under 
another `array<struct>` are common, so closing this gap covers a non-trivial 
set of real queries.
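   A minimal sketch of the suggested direction, assuming a toy expression tree rather than Catalyst's: `FieldAccess`, `ItemAccess`, `MapValueAccess`, and `StructOf` are hypothetical stand-ins for `GetStructField`, `GetArrayItem`/`ElementAt`, `GetMapValue`, and `CreateNamedStruct`. The `followCollections` flag switches between the current `GetStructField`-only chain and one that continues through item and map-value extraction, while `collect` keeps the same `Some(Seq.empty)` / `Some(fields)` / `None` contract as `collectLambdaVariableFields`.

   ```scala
   // Toy model of a lambda-variable chain walker that, like SelectedField,
   // continues through array-item and map-value extraction. All types here are
   // hypothetical; they are not Spark's Expression classes.
   object LambdaFieldCollector {
     sealed trait Expr
     final case class LambdaVar(name: String) extends Expr
     final case class FieldAccess(child: Expr, field: String) extends Expr   // x.a
     final case class ItemAccess(child: Expr, ordinal: Int) extends Expr     // x.arr[0]
     final case class MapValueAccess(child: Expr, key: String) extends Expr  // x.m['k']
     final case class StructOf(children: Seq[Expr]) extends Expr             // struct(...)

     /** Nested field names read from the element, e.g. List("subArr", "field"). */
     type Path = List[String]

     // Walks from an access expression down to `v`. With followCollections = false
     // this behaves like the current GetStructField-only walker.
     private def chain(e: Expr, v: LambdaVar, followCollections: Boolean): Option[Path] =
       e match {
         case `v` => Some(Nil)
         case FieldAccess(child, field) => chain(child, v, followCollections).map(_ :+ field)
         case ItemAccess(child, _) if followCollections => chain(child, v, followCollections)
         case MapValueAccess(child, _) if followCollections => chain(child, v, followCollections)
         case _ => None
       }

     private def children(e: Expr): Seq[Expr] = e match {
       case _: LambdaVar => Nil
       case FieldAccess(c, _) => Seq(c)
       case ItemAccess(c, _) => Seq(c)
       case MapValueAccess(c, _) => Seq(c)
       case StructOf(cs) => cs
     }

     /** Some(Nil): no reference to `v`; Some(paths): every reference is covered
       * by `paths`; None: the whole element is needed, so pruning is unsafe. */
     def collect(e: Expr, v: LambdaVar, followCollections: Boolean): Option[List[Path]] =
       chain(e, v, followCollections).filter(_.nonEmpty) match {
         case Some(path) => Some(List(path))
         case None if e == v => None
         case None =>
           children(e).foldLeft(Option(List.empty[Path])) {
             case (Some(acc), c) => collect(c, v, followCollections).map(acc ++ _)
             case (None, _) => None
           }
       }
   }
   ```

   For `x.subArr[0].field`, the current-style walker records only `List("subArr")` (so the whole `subArr` element type is kept), while the extended one records `List("subArr", "field")`. This sketch tracks field names only; a real helper would also have to rebuild the `array<...>` / `map<...>` wrapper types around the pruned struct, as `SelectedField` does.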



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
