peter-toth commented on code in PR #56070:
URL: https://github.com/apache/spark/pull/56070#discussion_r3303126793
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala:
##########
@@ -162,6 +175,80 @@ object SchemaPruning extends SQLConfHelper {
}
}
+ private def getArrayTransformRootField(
+ argument: Expression,
+ function: Expression,
+ elementVar: NamedLambdaVariable): Option[StructField] = {
+ argument.dataType match {
+ case ArrayType(_: StructType, containsNull) =>
+ val selectedFields = collectLambdaVariableFields(function, elementVar)
+ if (selectedFields.exists(_.nonEmpty)) {
+ val mergedElementSchema = selectedFields
+ .get
+ .map(field => StructType(Array(field)))
+ .reduceLeft(_ merge _)
+ SelectedField.withDataType(
+ argument,
+ ArrayType(mergedElementSchema, containsNull))
+ } else {
+ None
+ }
+ case _ => None
+ }
+ }
+
+ /**
+ * Collects statically identifiable nested fields read from `elementVar`.
+ *
+ * `Some(Seq.empty)` means this subtree does not reference the element variable, and
+ * `Some(fields)` means every reference can be satisfied by the listed nested fields. `None`
+ * means the full element is required somewhere (for example, `x => struct(x.a, x)`), so it is
+ * not safe to prune the element struct.
+ */
+ private def collectLambdaVariableFields(
+ expr: Expression,
+ elementVar: NamedLambdaVariable): Option[Seq[StructField]] = {
+ expr match {
+ case LambdaVariableField(field, variable) if variable.semanticEquals(elementVar) =>
+ Some(field :: Nil)
+ case variable: NamedLambdaVariable if variable.semanticEquals(elementVar) =>
+ None
+ case _ =>
+ expr.children.foldLeft(Option(Seq.empty[StructField])) {
+ case (Some(fields), child) =>
+ collectLambdaVariableFields(child, elementVar).map(fields ++ _)
+ case (None, _) => None
+ }
+ }
+ }
+
+ /**
+ * Converts a field access rooted at the lambda element into the single nested
+ * [[StructField]] shape needed by the input array schema. For example,
+ * `x.company.address` becomes `company: struct<address: ...>`.
+ */
+ private object LambdaVariableField {
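A minimal sketch of the three-valued collection described in the hunk above, on a hypothetical toy expression tree (`Expr`, `Var`, `Field`, `Call`) and schema (`Struct`, `Leaf`) rather than Catalyst's classes. It assumes name equality stands in for `semanticEquals`, and it prunes by filtering the element schema with the collected paths instead of merging single-field structs with `StructType.merge`, so field order follows the original schema.

```scala
// Hypothetical toy model (not Catalyst): a lambda body is an Expr tree and the
// array element type is a DType. All names here are illustrative only.
sealed trait DType
case object Leaf extends DType
final case class Struct(fields: List[(String, DType)]) extends DType

sealed trait Expr { def children: List[Expr] }
final case class Var(name: String) extends Expr { def children: List[Expr] = Nil }
final case class Field(child: Expr, name: String) extends Expr {
  def children: List[Expr] = List(child)
}
final case class Call(fn: String, args: List[Expr]) extends Expr {
  def children: List[Expr] = args
}

// Plays the role of LambdaVariableField: a chain of Field nodes ending at a
// variable becomes its path, e.g. x.company.address -> List("company", "address").
object FieldChain {
  def unapply(e: Expr): Option[(List[String], Var)] = e match {
    case Field(v: Var, name) => Some((List(name), v))
    case Field(inner, name)  => unapply(inner).map { case (path, v) => (path :+ name, v) }
    case _                   => None
  }
}

// Same three-valued contract as collectLambdaVariableFields:
//   Some(Nil)   -> the subtree never references `elem`
//   Some(paths) -> every reference to `elem` is one of these field paths
//   None        -> `elem` is used whole somewhere, so it must not be pruned
def collectPaths(e: Expr, elem: Var): Option[List[List[String]]] = e match {
  case FieldChain(path, v) if v == elem => Some(List(path))
  case v: Var if v == elem              => None
  case _ =>
    e.children.foldLeft(Option(List.empty[List[String]])) {
      case (Some(acc), child) => collectPaths(child, elem).map(acc ++ _)
      case (None, _)          => None
    }
}

// Keeps only the fields reached by `paths`, in element-schema order. An empty
// remaining path means the whole field is needed, so it is kept unpruned.
def prune(t: DType, paths: List[List[String]]): DType = t match {
  case Struct(fields) if paths.nonEmpty && !paths.contains(Nil) =>
    Struct(fields.flatMap { case (name, ft) =>
      val sub = paths.collect { case `name` :: rest => rest }
      if (sub.isEmpty) None else Some(name -> prune(ft, sub))
    })
  case other => other
}
```

For `x => x.company.address` this collects `Some(List(List("company", "address")))`, which prunes the element to `company: struct<address>`; for `x => struct(x.id, x)` it returns `None`, so the element is left whole.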
Review Comment:
The chain walker here only handles `GetStructField`, but `SelectedField`'s
`selectField` (which this is parallel to) handles `GetArrayStructFields`,
`GetArrayItem`, `ElementAt`, `GetMapValue`, `MapKeys`, and `MapValues` as well.
As written, queries like
```sql
SELECT transform(arr, x -> x.subArr[0].field) FROM t
SELECT transform(arr, x -> element_at(x.subArr, 1).field) FROM t
SELECT transform(arr, x -> x.mapField['k'].field) FROM t
```
will read the full inner type for `subArr` / `mapField` because the chain is
broken at the first non-`GetStructField` node, and
`collectLambdaVariableFields` then falls through to `_` and only records the
field accessed directly on the lambda variable (`x.subArr` / `x.mapField`),
with its full type.
As a follow-up improvement PR, `LambdaVariableField` could mirror
`SelectedField`'s case set (with `selectField` taking the lambda-variable leaf
instead of `Attribute` as the terminator). The two extractors would then differ
only in the leaf case and could share a parameterized helper. `array<struct>`
columns nested under another `array<struct>` are common, so closing this gap
covers a non-trivial set of real queries.
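A minimal sketch of that parameterized helper, on a hypothetical toy expression tree rather than Catalyst classes (`GetField`, `GetItem`, `GetValue`, `fieldPath`, and the other names are illustrative). It covers only a subset of `SelectedField`'s cases (no `GetArrayStructFields`, `MapKeys`, or `MapValues`) and returns a path of field names instead of a `StructField`.

```scala
// Hypothetical toy model (not Catalyst) of the suggested shared helper.
sealed trait Expr
final case class Attr(name: String) extends Expr                  // column reference
final case class LambdaVar(name: String) extends Expr             // lambda element `x`
final case class GetField(child: Expr, name: String) extends Expr // s.name
final case class GetItem(child: Expr, index: Int) extends Expr    // a[i] or element_at(a, i)
final case class GetValue(child: Expr, key: String) extends Expr  // m['k']

// Behaviour described in the review: only struct-field hops are followed,
// so any array or map access breaks the chain.
def structOnlyPath(e: Expr, v: LambdaVar): Option[List[String]] = e match {
  case GetField(`v`, name)   => Some(List(name))
  case GetField(child, name) => structOnlyPath(child, v).map(_ :+ name)
  case _                     => None
}

// Shared walker: the case set is written once and only the leaf test varies.
// Array and map accesses select an element or value, so they add no field name.
def fieldPath(e: Expr, isLeaf: Expr => Boolean): Option[(Expr, List[String])] = e match {
  case leaf if isLeaf(leaf)  => Some((leaf, Nil))
  case GetField(child, name) => fieldPath(child, isLeaf).map { case (l, p) => (l, p :+ name) }
  case GetItem(child, _)     => fieldPath(child, isLeaf)
  case GetValue(child, _)    => fieldPath(child, isLeaf)
  case _                     => None
}

// The two extractors then differ only in their terminator.
def attributePath(e: Expr): Option[(Expr, List[String])] =
  fieldPath(e, _.isInstanceOf[Attr])
def lambdaVarPath(e: Expr, v: LambdaVar): Option[(Expr, List[String])] =
  fieldPath(e, _ == v)
```

With the shared walker, `x.subArr[0].field` resolves to the path `subArr.field` instead of breaking before `subArr`, which is what would let the element type of `subArr` be pruned as well. A real version would still need to map each path back to a `StructField` with the pruned `ArrayType` / `MapType`, as `SelectedField` does.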
--
This is an automated message from the Apache Git Service.