This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new eccb3b696888 [SPARK-57022][SQL] Support nested column pruning for
transform over arrays of structs
eccb3b696888 is described below
commit eccb3b696888557ff4d552d35eed55e7db5e3ba6
Author: Chao Sun <[email protected]>
AuthorDate: Tue May 26 14:26:18 2026 -0700
[SPARK-57022][SQL] Support nested column pruning for transform over arrays of structs
### Why are the changes needed?
Spark can prune nested struct fields referenced directly by a query, but it
does not currently prune nested fields read through the lambda variable of
`transform` over an `array<struct>` column.
For example:
```sql
SELECT transform(rule_results, rule ->
named_struct(
'rule_public_id', rule.rule_public_id,
'rule_version', rule.rule_version))
FROM events
```
If `rule_results` contains additional fields, Spark currently retains the
full element struct in the scan schema even though only two nested fields are
required. This causes unnecessary Parquet and ORC input reads for wide array
element schemas.
This change addresses
[SPARK-57022](https://issues.apache.org/jira/browse/SPARK-57022).
### What changes were proposed in this pull request?
- Recognize statically identifiable nested field reads through the element
variable of `ArrayTransform`.
- Build a projected array element schema from exactly those referenced
fields and propagate it to the scan input.
- Rewrite the bound lambda variable type and `GetStructField` ordinals
against the projected element schema after pruning.
- Fall back to retaining the full element schema when the lambda consumes
the complete element, so pruning is applied only when it is safe.
- Add Catalyst and datasource tests covering ordinal rewrites, deep
nesting, nested input paths with null values, indexed lambdas, case-insensitive
resolution, and conservative fallback.
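The collection and fallback rules listed above can be sketched with a toy expression model. This is an illustrative sketch only: `Expr`, `ElementVar`, `GetField`, `MakeStruct`, and `collectFields` are hypothetical names rather than Catalyst classes, and only top-level field names are tracked.

```scala
// Toy model, NOT Spark's Catalyst API: every name below is hypothetical.
sealed trait Expr
case object ElementVar extends Expr                                // the lambda variable `x`
final case class GetField(child: Expr, name: String) extends Expr  // `child.name`
final case class MakeStruct(children: Seq[Expr]) extends Expr      // `struct(...)`

object CollectSketch {
  // Returns the top-level element fields that the lambda body reads.
  // Some(Seq.empty) means the subtree never touches the element;
  // None means the whole element is consumed somewhere, so pruning is unsafe.
  def collectFields(expr: Expr): Option[Seq[String]] = expr match {
    case GetField(ElementVar, name) => Some(Seq(name))
    case ElementVar                 => None
    case GetField(child, _)         => collectFields(child)
    case MakeStruct(children) =>
      children.foldLeft(Option(Seq.empty[String])) {
        case (Some(acc), child) => collectFields(child).map(acc ++ _)
        case (None, _)          => None
      }
  }
}
```

Under these assumptions, `x => struct(x.a, x.c)` yields `Some(Seq("a", "c"))`, while `x => struct(x.a, x)` yields `None` and keeps the full element schema.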
The implementation intentionally has two stages. `SchemaPruning` discovers
which fields the lambda needs from the array element. `ProjectionOverSchema`
then rewrites the lambda against the narrower element type because pruning can
change field ordinals. For example, pruning `struct<a, b, c>` to `struct<a, c>`
moves `c` from ordinal `2` to ordinal `1`.
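The ordinal shift can be reproduced with plain Scala collections. This is a minimal sketch, assuming the pruned schema keeps the surviving fields in their original order; `prune` and `ordinalOf` are hypothetical helpers, not Spark functions.

```scala
object OrdinalRemapSketch {
  // Hypothetical helper: keep only the needed fields, preserving original order.
  def prune(full: Seq[String], needed: Set[String]): Seq[String] =
    full.filter(needed.contains)

  // Hypothetical helper: resolve a field's ordinal by name in a given schema,
  // which is why a by-name lookup against the pruned schema gives the new position.
  def ordinalOf(schema: Seq[String], field: String): Int = {
    val idx = schema.indexOf(field)
    require(idx >= 0, s"field $field not found in schema")
    idx
  }
}
```

With `full = Seq("a", "b", "c")` and `needed = Set("a", "c")`, `ordinalOf(full, "c")` is `2` and `ordinalOf(prune(full, needed), "c")` is `1`, so an ordinal computed before pruning would point at the wrong slot afterwards.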
### Does this PR introduce _any_ user-facing change?
Yes. Eligible queries using `transform` over arrays of structs can read a
narrower input schema. Query results and SQL APIs are unchanged.
### How was this patch tested?
- `build/sbt "catalyst/testOnly org.apache.spark.sql.catalyst.expressions.SchemaPruningSuite"`
- `build/sbt "sql/testOnly org.apache.spark.sql.execution.datasources.parquet.ParquetV1SchemaPruningSuite org.apache.spark.sql.execution.datasources.parquet.ParquetV2SchemaPruningSuite org.apache.spark.sql.execution.datasources.orc.OrcV1SchemaPruningSuite org.apache.spark.sql.execution.datasources.orc.OrcV2SchemaPruningSuite -- -z ArrayTransform"`
- `git diff --check apache/master...HEAD`
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Codex (GPT-5)
Closes #56070 from sunchao/dev/chao/codex/spark-57022-array-transform-pruning.
Authored-by: Chao Sun <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
(cherry picked from commit e3e81dcfd59fd26d2110fff5bc12fd232718ced5)
Signed-off-by: Chao Sun <[email protected]>
---
.../expressions/ProjectionOverSchema.scala | 60 ++++++++++++
.../sql/catalyst/expressions/SchemaPruning.scala | 91 +++++++++++++++++
.../sql/catalyst/expressions/SelectedField.scala | 15 +++
.../catalyst/expressions/SchemaPruningSuite.scala | 42 ++++++++
.../execution/datasources/SchemaPruningSuite.scala | 109 +++++++++++++++++++++
5 files changed, 317 insertions(+)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index bb67c173b946..362643016d83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -68,6 +68,31 @@ case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
getProjection(child).map { projection => MapValues(projection) }
case GetMapValue(child, key) =>
getProjection(child).map { projection => GetMapValue(projection, key) }
+ case transform @ ArrayTransform(argument, lambda: LambdaFunction) =>
+ getProjection(argument).map {
+ case projection @ ArrayTypeProjection(projectedElementSchema) =>
+ lambda.arguments.headOption match {
+ case Some(elementVar: NamedLambdaVariable) =>
+ // Pruning fields changes the physical ordinal layout of the element struct.
+ // For example, pruning struct<a, b, c> to struct<a, c> moves c from ordinal 2
+ // to ordinal 1, so rewrite both the variable type and its field accesses.
+ val projectedElementVar = elementVar.copy(dataType = projectedElementSchema)
+ val lambdaProjection =
+ ProjectionOverLambdaVariable(elementVar, projectedElementVar)
+ val projectedBody = lambda.function.transformDown {
+ case lambdaProjection(expr) => expr
+ }
+ transform.copy(
+ argument = projection,
+ function = lambda.copy(
+ function = projectedBody,
+ arguments = projectedElementVar +: lambda.arguments.tail))
+ case _ =>
+ transform.copy(argument = projection)
+ }
+ case projection =>
+ transform.copy(argument = projection)
+ }
case GetStructFieldObject(child, field: StructField) =>
getProjection(child).map(p => (p, p.dataType)).map {
case (projection, projSchema: StructType) =>
@@ -82,4 +107,39 @@ case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
case _ =>
None
}
+
+ private object ArrayTypeProjection {
+ def unapply(expr: Expression): Option[StructType] = expr.dataType match {
+ case ArrayType(projectedElementSchema: StructType, _) => Some(projectedElementSchema)
+ case _ => None
+ }
+ }
+
+ /**
+ * Rewrites references rooted at one bound lambda element to use its projected type and
+ * recomputes nested field ordinals against each projected struct in the access path.
+ * This must support the same access paths collected by `SchemaPruning` for lambda variables;
+ * currently both sides support only `GetStructField` chains.
+ */
+ private case class ProjectionOverLambdaVariable(
+ original: NamedLambdaVariable,
+ projected: NamedLambdaVariable) {
+ def unapply(expr: Expression): Option[Expression] = project(expr)
+
+ private def project(expr: Expression): Option[Expression] = expr match {
+ case variable: NamedLambdaVariable if variable.semanticEquals(original) =>
+ Some(projected)
+ case GetStructFieldObject(child, field: StructField) =>
+ project(child).map { projection =>
+ projection.dataType match {
+ case projectedSchema: StructType =>
+ GetStructField(projection, projectedSchema.fieldIndex(field.name))
+ case dataType =>
+ throw SparkException.internalError(
+ s"unmatched lambda child schema for GetStructField: ${dataType.toString}")
+ }
+ }
+ case _ => None
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index dd2d6c2cb610..2f99dd54f77a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -140,6 +140,19 @@ object SchemaPruning extends SQLConfHelper {
*/
private[catalyst] def getRootFields(expr: Expression): Seq[RootField] = {
expr match {
+ case ArrayTransform(argument, lambda: LambdaFunction) =>
+ // Field accesses through the lambda variable are not directly rooted at the input
+ // attribute. Convert them into a projected type for the transform argument so that
+ // physical nested column pruning can see them.
+ val nestedRootFields = lambda.arguments.headOption.collect {
+ case elementVar: NamedLambdaVariable =>
+ getArrayTransformRootField(argument, lambda.function, elementVar)
+ }.flatten.toSeq.map(field => RootField(field, derivedFromAtt = false))
+ if (nestedRootFields.nonEmpty) {
+ nestedRootFields ++ getRootFields(lambda.function)
+ } else {
+ expr.children.flatMap(getRootFields)
+ }
case att: Attribute =>
RootField(StructField(att.name, att.dataType, att.nullable, att.metadata),
derivedFromAtt = true) :: Nil
@@ -162,6 +175,84 @@ object SchemaPruning extends SQLConfHelper {
}
}
+ private def getArrayTransformRootField(
+ argument: Expression,
+ function: Expression,
+ elementVar: NamedLambdaVariable): Option[StructField] = {
+ argument.dataType match {
+ case ArrayType(_: StructType, containsNull) =>
+ val selectedFields = collectLambdaVariableFields(function, elementVar)
+ if (selectedFields.exists(_.nonEmpty)) {
+ val mergedElementSchema = selectedFields
+ .get
+ .map(field => StructType(Array(field)))
+ .reduceLeft(_ merge _)
+ SelectedField.withDataType(
+ argument,
+ ArrayType(mergedElementSchema, containsNull))
+ } else {
+ None
+ }
+ case _ => None
+ }
+ }
+
+ /**
+ * Collects statically identifiable nested fields read from `elementVar`.
+ *
+ * `Some(Seq.empty)` means this subtree does not reference the element variable, and
+ * `Some(fields)` means every reference can be satisfied by the listed nested fields. `None`
+ * means the full element is required somewhere (for example, `x => struct(x.a, x)`), so it is
+ * not safe to prune the element struct.
+ *
+ * Currently only `GetStructField` chains rooted at `elementVar` are collected; array or map
+ * traversal within the lambda conservatively requires the full element. Keep this set of
+ * supported paths in sync with `ProjectionOverLambdaVariable` in `ProjectionOverSchema`.
+ */
+ private def collectLambdaVariableFields(
+ expr: Expression,
+ elementVar: NamedLambdaVariable): Option[Seq[StructField]] = {
+ expr match {
+ case LambdaVariableField(field, variable) if variable.semanticEquals(elementVar) =>
+ Some(field :: Nil)
+ case variable: NamedLambdaVariable if variable.semanticEquals(elementVar) =>
+ None
+ case _ =>
+ expr.children.foldLeft(Option(Seq.empty[StructField])) {
+ case (Some(fields), child) =>
+ collectLambdaVariableFields(child, elementVar).map(fields ++ _)
+ case (None, _) => None
+ }
+ }
+ }
+
+ /**
+ * Converts a field access rooted at the lambda element into the single nested
+ * [[StructField]] shape needed by the input array schema. For example,
+ * `x.company.address` becomes `company: struct<address: ...>`.
+ */
+ private object LambdaVariableField {
+ def unapply(expr: Expression): Option[(StructField, NamedLambdaVariable)] = {
+ def selectField(
+ expression: Expression,
+ dataTypeOpt: Option[DataType]): Option[(StructField, NamedLambdaVariable)] =
+ expression match {
+ case variable: NamedLambdaVariable =>
+ dataTypeOpt.collect {
+ case schema: StructType if schema.length == 1 =>
+ schema.head -> variable
+ }
+ case getStructField: GetStructField =>
+ val field = getStructField.childSchema(getStructField.ordinal)
+ val newField = field.copy(dataType = dataTypeOpt.getOrElse(field.dataType))
+ selectField(getStructField.child, Some(StructType(Array(newField))))
+ case _ => None
+ }
+
+ selectField(expr, None)
+ }
+ }
+
/**
* This represents a "root" schema field (aka top-level, no-parent). `field` is the
* `StructField` for field name and datatype. `derivedFromAtt` indicates whether it
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
index 820dc452d7e8..e36224e7d5c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
@@ -62,6 +62,21 @@ object SelectedField {
selectField(unaliased, None)
}
+ /**
+ * Builds the selected root field for `expr` while substituting a narrower projected data type.
+ * This lets a lambda field access establish the required scan type of its array argument, even
+ * though the lambda variable is not itself rooted at a data source attribute.
+ */
+ private[catalyst] def withDataType(
+ expr: Expression,
+ dataType: DataType): Option[StructField] = {
+ val unaliased = expr match {
+ case Alias(child, _) => child
+ case expression => expression
+ }
+ selectField(unaliased, Some(dataType))
+ }
+
/**
* Convert an expression into the parts of the schema (the field) it accesses.
*/
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index a968526a89f1..9426ef91349e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -144,4 +144,46 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
assert(prunedSchema.head.metadata.getString("foo") == "bar")
}
+ test("collect nested fields used by ArrayTransform lambda") {
+ val elementType = StructType.fromDDL("a int, b int, c int")
+ val eventType = StructType(Seq(
+ StructField("rules", ArrayType(elementType, containsNull = true))))
+ val event = AttributeReference("event", eventType)()
+ val element = NamedLambdaVariable("x", elementType, nullable = true)
+ val transformed = ArrayTransform(
+ GetStructField(event, 0, Some("rules")),
+ LambdaFunction(
+ CreateNamedStruct(Seq(
+ Literal("a"),
+ GetStructField(element, 0, Some("a")),
+ Literal("c"),
+ GetStructField(element, 2, Some("c")))),
+ Seq(element)))
+
+ val rootFields = SchemaPruning.getRootFields(transformed)
+ val prunedSchema = SchemaPruning.pruneSchema(
+ StructType(Seq(StructField("event", eventType))),
+ rootFields)
+
+ assert(prunedSchema === StructType.fromDDL(
+ "event struct<rules:array<struct<a:int,c:int>>>"))
+ }
+
+ test("do not collect ArrayTransform lambda fields when the whole element is used") {
+ val elementType = StructType.fromDDL("a int, b int")
+ val eventType = StructType(Seq(
+ StructField("rules", ArrayType(elementType, containsNull = true))))
+ val event = AttributeReference("event", eventType)()
+ val element = NamedLambdaVariable("x", elementType, nullable = true)
+ val transformed = ArrayTransform(
+ GetStructField(event, 0, Some("rules")),
+ LambdaFunction(element, Seq(element)))
+
+ val rootFields = SchemaPruning.getRootFields(transformed)
+
+ assert(rootFields === Seq(
+ SchemaPruning.RootField(
+ StructField("event", eventType, nullable = true),
+ derivedFromAtt = false)))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index fd8d1308e990..5213c0c5f4e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -62,6 +62,8 @@ abstract class SchemaPruningSuite
super.sparkConf.set(SQLConf.ANSI_ENABLED.key, "false")
case class Employee(id: Int, name: FullName, employer: Company)
+ case class Team(members: Array[Employer])
+ case class Organization(team: Team)
val janeDoe = FullName("Jane", "X.", "Doe")
val johnDoe = FullName("John", "Y.", "Doe")
@@ -86,6 +88,9 @@ abstract class SchemaPruningSuite
Department(2, "Operation", 4, employerWithNullCompany2) :: Nil
val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil
+ val teams = Team(Array(employer)) :: Nil
+ val organizations =
+ Organization(Team(Array[Employer](null, employerWithNullCompany, employer))) :: Nil
case class Name(first: String, last: String)
case class BriefContact(id: Int, name: Name, address: String)
@@ -361,6 +366,110 @@ abstract class SchemaPruningSuite
}
}
+ testSchemaPruning("select ArrayTransform over nested fields of array of struct") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(transform(col("friends"), friend =>
+ struct(
+ friend.getField("first").as("first"),
+ friend.getField("last").as("last"))))
+
+ checkScan(query, "struct<friends:array<struct<first:string,last:string>>>")
+ checkAnswer(query,
+ Row(Array(Row("Susan", "Smith"))) ::
+ Row(Array.empty[Row]) ::
+ Nil)
+ }
+
+ testSchemaPruning("select ArrayTransform over deep nested fields of array of struct") {
+ withDataSourceTable(teams, "teams") {
+ val query = spark.table("teams")
+ .select(transform(col("members"), member =>
+ member.getField("company").getField("address")))
+
+ checkScan(query, "struct<members:array<struct<company:struct<address:string>>>>")
+ checkAnswer(query, Row(Array("123 Business Street")) :: Nil)
+ }
+ }
+
+ testSchemaPruning("select ArrayTransform merging nested parent and child fields") {
+ withDataSourceTable(teams, "teams") {
+ val query = spark.table("teams")
+ .select(transform(col("members"), member =>
+ struct(
+ member.getField("company").as("company"),
+ member.getField("company").getField("address").as("address"))))
+
+ checkScan(query, "struct<members:array<struct<company:struct<name:string,address:string>>>>")
+ checkAnswer(query,
+ Row(Array(Row(Row("abc", "123 Business Street"), "123 Business Street"))) :: Nil)
+ }
+ }
+
+ testSchemaPruning("select ArrayTransform over nested array path with null elements") {
+ withDataSourceTable(organizations, "organizations") {
+ val query = spark.table("organizations")
+ .select(transform(col("team.members"), member =>
+ member.getField("company").getField("address")))
+
+ checkScan(query,
+ "struct<team:struct<members:array<struct<company:struct<address:string>>>>>")
+ checkAnswer(query, Row(Array(null, null, "123 Business Street")) :: Nil)
+ }
+ }
+
+ testSchemaPruning("select indexed ArrayTransform over nested fields of array of struct") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(transform(col("friends"), (friend, index) =>
+ struct(friend.getField("last").as("last"), index.as("index"))))
+
+ checkScan(query, "struct<friends:array<struct<last:string>>>")
+ checkAnswer(query,
+ Row(Array(Row("Smith", 0))) ::
+ Row(Array.empty[Row]) ::
+ Nil)
+ }
+
+ testSchemaPruning("select case-insensitive ArrayTransform nested field") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(transform(col("friends"), friend => friend.getField("LaSt")))
+
+ checkScan(query, "struct<friends:array<struct<last:string>>>")
+ checkAnswer(query,
+ Row(Array("Smith")) ::
+ Row(Array.empty[String]) ::
+ Nil)
+ }
+ }
+
+ testSchemaPruning("do not prune ArrayTransform when the whole element is used") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(transform(col("friends"), friend => friend))
+
+ checkScan(query, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
+ checkAnswer(query,
+ Row(Array(Row("Susan", "Z.", "Smith"))) ::
+ Row(Array.empty[Row]) ::
+ Nil)
+ }
+
+ testSchemaPruning("do not prune ArrayTransform when a nested field and whole element are used") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(transform(col("friends"), friend =>
+ struct(friend.getField("first").as("first"), friend.as("friend"))))
+
+ checkScan(query, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
+ checkAnswer(query,
+ Row(Array(Row("Susan", Row("Susan", "Z.", "Smith")))) ::
+ Row(Array.empty[Row]) ::
+ Nil)
+ }
+
testSchemaPruning("SPARK-34638: nested column prune on generator output") {
val query1 = spark.table("contacts")
.select(explode(col("friends")).as("friend"))