peter-toth commented on code in PR #56070:
URL: https://github.com/apache/spark/pull/56070#discussion_r3303126793


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala:
##########
@@ -162,6 +175,80 @@ object SchemaPruning extends SQLConfHelper {
     }
   }
 
+  private def getArrayTransformRootField(
+      argument: Expression,
+      function: Expression,
+      elementVar: NamedLambdaVariable): Option[StructField] = {
+    argument.dataType match {
+      case ArrayType(_: StructType, containsNull) =>
+        val selectedFields = collectLambdaVariableFields(function, elementVar)
+        if (selectedFields.exists(_.nonEmpty)) {
+          val mergedElementSchema = selectedFields
+            .get
+            .map(field => StructType(Array(field)))
+            .reduceLeft(_ merge _)
+          SelectedField.withDataType(
+            argument,
+            ArrayType(mergedElementSchema, containsNull))
+        } else {
+          None
+        }
+      case _ => None
+    }
+  }
+
+  /**
+   * Collects statically identifiable nested fields read from `elementVar`.
+   *
+   * `Some(Seq.empty)` means this subtree does not reference the element variable, and
+   * `Some(fields)` means every reference can be satisfied by the listed nested fields. `None`
+   * means the full element is required somewhere (for example, `x => struct(x.a, x)`), so it is
+   * not safe to prune the element struct.
+   */
+  private def collectLambdaVariableFields(
+      expr: Expression,
+      elementVar: NamedLambdaVariable): Option[Seq[StructField]] = {
+    expr match {
+      case LambdaVariableField(field, variable) if variable.semanticEquals(elementVar) =>
+        Some(field :: Nil)
+      case variable: NamedLambdaVariable if variable.semanticEquals(elementVar) =>
+        None
+      case _ =>
+        expr.children.foldLeft(Option(Seq.empty[StructField])) {
+          case (Some(fields), child) =>
+            collectLambdaVariableFields(child, elementVar).map(fields ++ _)
+          case (None, _) => None
+        }
+    }
+  }
+
+  /**
+   * Converts a field access rooted at the lambda element into the single nested
+   * [[StructField]] shape needed by the input array schema. For example,
+   * `x.company.address` becomes `company: struct<address: ...>`.
+   */
+  private object LambdaVariableField {

Review Comment:
   The chain walker here only handles `GetStructField`, but `SelectedField`'s
   `selectField`, which this extractor parallels, also handles
   `GetArrayStructFields`, `GetArrayItem`, `ElementAt`, `GetMapValue`,
   `MapKeys`, and `MapValues`. As written, queries like
   
   ```sql
   SELECT transform(arr, x -> x.subArr[0].field) FROM t
   SELECT transform(arr, x -> element_at(x.subArr, 1).field) FROM t
   SELECT transform(arr, x -> x.mapField['k'].field) FROM t
   ```
   
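   will read the full inner type for `subArr` / `mapField`, because the chain
   is broken at the first non-`GetStructField` node:
   `collectLambdaVariableFields` falls through to the `_` case, recurses into
   the children, and only matches the inner `x.subArr` / `x.mapField` access,
   so the whole field is selected instead of just `field`.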
   
   In a follow-up PR, `LambdaVariableField` could mirror `SelectedField`'s case
   set, with `selectField` taking the lambda-variable leaf instead of
   `Attribute` as the terminator. The two extractors would then differ only in
   the leaf case and could share a parameterized helper. `array<struct>`
   columns nested under another `array<struct>` are common, so closing this
   gap covers a non-trivial set of real queries.
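
   To make the suggestion concrete, here is a rough sketch on a toy expression
   model, not on Catalyst's classes. Every name in it (`Expr`, `GetItem`,
   `GetValue`, `Shape`, `select`) is hypothetical and only stands in for the
   real nodes. It shows a walker that wraps the needed shape at each accessor
   instead of stopping at the first non-struct node.

   ```scala
   object ChainWalkerSketch {
     // Toy model only: these are NOT Spark's Catalyst expression classes.
     sealed trait Expr
     case object LambdaVar extends Expr                                 // the element variable `x`
     final case class Literal(value: Int) extends Expr                  // anything unrelated to `x`
     final case class GetField(child: Expr, name: String) extends Expr  // child.name
     final case class GetItem(child: Expr) extends Expr                 // child[i] or element_at(child, i)
     final case class GetValue(child: Expr) extends Expr                // child['k']

     // The part of the lambda element that has to be read.
     sealed trait Shape
     case object Full extends Shape
     final case class StructShape(field: String, inner: Shape) extends Shape
     final case class ArrayShape(element: Shape) extends Shape
     final case class MapShape(value: Shape) extends Shape

     // Walks from the outermost accessor down to the lambda variable, wrapping
     // the needed shape at every step. Returns None when the chain is not
     // rooted at the lambda variable.
     def select(expr: Expr, needed: Shape = Full): Option[Shape] = expr match {
       case LambdaVar             => Some(needed)
       case GetField(child, name) => select(child, StructShape(name, needed))
       case GetItem(child)        => select(child, ArrayShape(needed))
       case GetValue(child)       => select(child, MapShape(needed))
       case Literal(_)            => None
     }

     def main(args: Array[String]): Unit = {
       // Toy equivalent of `x.subArr[0].field`.
       val expr = GetField(GetItem(GetField(LambdaVar, "subArr")), "field")
       println(select(expr))
     }
   }
   ```

   For the toy `x.subArr[0].field`, `select` yields a struct shape for `subArr`
   that contains an array of structs with only `field`, rather than the full
   `subArr` type. In the real code the same wrapping would have to be done with
   `StructField` / `ArrayType` / `MapType`, as `selectField` does today.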



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

