This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e3e81dcfd59f [SPARK-57022][SQL] Support nested column pruning for 
transform over arrays of structs
e3e81dcfd59f is described below

commit e3e81dcfd59fd26d2110fff5bc12fd232718ced5
Author: Chao Sun <[email protected]>
AuthorDate: Tue May 26 14:26:18 2026 -0700

    [SPARK-57022][SQL] Support nested column pruning for transform over arrays 
of structs
    
    ### Why are the changes needed?
    
    Spark can prune nested struct fields referenced directly by a query, but it 
does not currently prune nested fields read through the lambda variable of 
`transform` over an `array<struct>` column.
    
    For example:
    
    ```sql
    SELECT transform(rule_results, rule ->
      named_struct(
        'rule_public_id', rule.rule_public_id,
        'rule_version', rule.rule_version))
    FROM events
    ```
    
    If `rule_results` contains additional fields, Spark currently retains the 
full element struct in the scan schema even though only two nested fields are 
required. This causes unnecessary Parquet and ORC input reads for wide array 
element schemas.
    
    This change addresses 
[SPARK-57022](https://issues.apache.org/jira/browse/SPARK-57022).
    
    ### What changes were proposed in this pull request?
    
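    - Recognize statically identifiable nested field reads through the element 
variable of `ArrayTransform`. Currently only `GetStructField` chains rooted at 
the element variable are recognized; array or map traversal within the lambda 
conservatively requires the full element.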
    - Build a projected array element schema from exactly those referenced 
fields and propagate it to the scan input.
    - Rewrite the bound lambda variable type and `GetStructField` ordinals 
against the projected element schema after pruning.
    - Fall back to retaining the full element schema when the lambda consumes 
the complete element (for example, `x -> struct(x.a, x)`), so pruning is 
applied only when it is safe.
    - Add Catalyst and datasource tests covering ordinal rewrites, deep 
nesting, nested input paths with null values, indexed lambdas, case-insensitive 
resolution, and conservative fallback.
    
    The implementation intentionally has two stages. `SchemaPruning` discovers 
which fields the lambda needs from the array element. `ProjectionOverSchema` 
then rewrites the lambda against the narrower element type because pruning can 
change field ordinals. For example, pruning `struct<a, b, c>` to `struct<a, c>` 
moves `c` from ordinal `2` to ordinal `1`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Eligible queries using `transform` over arrays of structs can read a 
narrower input schema. Query results and SQL APIs are unchanged.
    
    ### How was this patch tested?
    
    - `build/sbt "catalyst/testOnly 
org.apache.spark.sql.catalyst.expressions.SchemaPruningSuite"`
    - `build/sbt "sql/testOnly 
org.apache.spark.sql.execution.datasources.parquet.ParquetV1SchemaPruningSuite 
org.apache.spark.sql.execution.datasources.parquet.ParquetV2SchemaPruningSuite 
org.apache.spark.sql.execution.datasources.orc.OrcV1SchemaPruningSuite 
org.apache.spark.sql.execution.datasources.orc.OrcV2SchemaPruningSuite -- -z 
ArrayTransform"`
    - `git diff --check apache/master...HEAD`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Codex (GPT-5)
    
    Closes #56070 from 
sunchao/dev/chao/codex/spark-57022-array-transform-pruning.
    
    Authored-by: Chao Sun <[email protected]>
    Signed-off-by: Chao Sun <[email protected]>
---
 .../expressions/ProjectionOverSchema.scala         |  60 ++++++++++++
 .../sql/catalyst/expressions/SchemaPruning.scala   |  91 +++++++++++++++++
 .../sql/catalyst/expressions/SelectedField.scala   |  15 +++
 .../catalyst/expressions/SchemaPruningSuite.scala  |  42 ++++++++
 .../execution/datasources/SchemaPruningSuite.scala | 109 +++++++++++++++++++++
 5 files changed, 317 insertions(+)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index bb67c173b946..362643016d83 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -68,6 +68,31 @@ case class ProjectionOverSchema(schema: StructType, output: 
AttributeSet) {
         getProjection(child).map { projection => MapValues(projection) }
       case GetMapValue(child, key) =>
         getProjection(child).map { projection => GetMapValue(projection, key) }
+      case transform @ ArrayTransform(argument, lambda: LambdaFunction) =>
+        getProjection(argument).map {
+          case projection @ ArrayTypeProjection(projectedElementSchema) =>
+            lambda.arguments.headOption match {
+              case Some(elementVar: NamedLambdaVariable) =>
+                // Pruning fields changes the physical ordinal layout of the 
element struct.
+                // For example, pruning struct<a, b, c> to struct<a, c> moves 
c from ordinal 2
+                // to ordinal 1, so rewrite both the variable type and its 
field accesses.
+                val projectedElementVar = elementVar.copy(dataType = 
projectedElementSchema)
+                val lambdaProjection =
+                  ProjectionOverLambdaVariable(elementVar, projectedElementVar)
+                val projectedBody = lambda.function.transformDown {
+                  case lambdaProjection(expr) => expr
+                }
+                transform.copy(
+                  argument = projection,
+                  function = lambda.copy(
+                    function = projectedBody,
+                    arguments = projectedElementVar +: lambda.arguments.tail))
+              case _ =>
+                transform.copy(argument = projection)
+            }
+          case projection =>
+            transform.copy(argument = projection)
+        }
       case GetStructFieldObject(child, field: StructField) =>
         getProjection(child).map(p => (p, p.dataType)).map {
           case (projection, projSchema: StructType) =>
@@ -82,4 +107,39 @@ case class ProjectionOverSchema(schema: StructType, output: 
AttributeSet) {
       case _ =>
         None
     }
+
+  private object ArrayTypeProjection {
+    def unapply(expr: Expression): Option[StructType] = expr.dataType match {
+      case ArrayType(projectedElementSchema: StructType, _) => 
Some(projectedElementSchema)
+      case _ => None
+    }
+  }
+
+  /**
+   * Rewrites references rooted at one bound lambda element to use its 
projected type and
+   * recomputes nested field ordinals against each projected struct in the 
access path.
+   * This must support the same access paths collected by `SchemaPruning` for 
lambda variables;
+   * currently both sides support only `GetStructField` chains.
+   */
+  private case class ProjectionOverLambdaVariable(
+      original: NamedLambdaVariable,
+      projected: NamedLambdaVariable) {
+    def unapply(expr: Expression): Option[Expression] = project(expr)
+
+    private def project(expr: Expression): Option[Expression] = expr match {
+      case variable: NamedLambdaVariable if variable.semanticEquals(original) 
=>
+        Some(projected)
+      case GetStructFieldObject(child, field: StructField) =>
+        project(child).map { projection =>
+          projection.dataType match {
+            case projectedSchema: StructType =>
+              GetStructField(projection, 
projectedSchema.fieldIndex(field.name))
+            case dataType =>
+              throw SparkException.internalError(
+                s"unmatched lambda child schema for GetStructField: 
${dataType.toString}")
+          }
+        }
+      case _ => None
+    }
+  }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index dd2d6c2cb610..2f99dd54f77a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -140,6 +140,19 @@ object SchemaPruning extends SQLConfHelper {
    */
   private[catalyst] def getRootFields(expr: Expression): Seq[RootField] = {
     expr match {
+      case ArrayTransform(argument, lambda: LambdaFunction) =>
+        // Field accesses through the lambda variable are not directly rooted 
at the input
+        // attribute. Convert them into a projected type for the transform 
argument so that
+        // physical nested column pruning can see them.
+        val nestedRootFields = lambda.arguments.headOption.collect {
+          case elementVar: NamedLambdaVariable =>
+            getArrayTransformRootField(argument, lambda.function, elementVar)
+        }.flatten.toSeq.map(field => RootField(field, derivedFromAtt = false))
+        if (nestedRootFields.nonEmpty) {
+          nestedRootFields ++ getRootFields(lambda.function)
+        } else {
+          expr.children.flatMap(getRootFields)
+        }
       case att: Attribute =>
         RootField(StructField(att.name, att.dataType, att.nullable, 
att.metadata),
           derivedFromAtt = true) :: Nil
@@ -162,6 +175,84 @@ object SchemaPruning extends SQLConfHelper {
     }
   }
 
+  private def getArrayTransformRootField(
+      argument: Expression,
+      function: Expression,
+      elementVar: NamedLambdaVariable): Option[StructField] = {
+    argument.dataType match {
+      case ArrayType(_: StructType, containsNull) =>
+        val selectedFields = collectLambdaVariableFields(function, elementVar)
+        if (selectedFields.exists(_.nonEmpty)) {
+          val mergedElementSchema = selectedFields
+            .get
+            .map(field => StructType(Array(field)))
+            .reduceLeft(_ merge _)
+          SelectedField.withDataType(
+            argument,
+            ArrayType(mergedElementSchema, containsNull))
+        } else {
+          None
+        }
+      case _ => None
+    }
+  }
+
+  /**
+   * Collects statically identifiable nested fields read from `elementVar`.
+   *
+   * `Some(Seq.empty)` means this subtree does not reference the element 
variable, and
+   * `Some(fields)` means every reference can be satisfied by the listed 
nested fields. `None`
+   * means the full element is required somewhere (for example, `x => 
struct(x.a, x)`), so it is
+   * not safe to prune the element struct.
+   *
+   * Currently only `GetStructField` chains rooted at `elementVar` are 
collected; array or map
+   * traversal within the lambda conservatively requires the full element. 
Keep this set of
+   * supported paths in sync with `ProjectionOverLambdaVariable` in 
`ProjectionOverSchema`.
+   */
+  private def collectLambdaVariableFields(
+      expr: Expression,
+      elementVar: NamedLambdaVariable): Option[Seq[StructField]] = {
+    expr match {
+      case LambdaVariableField(field, variable) if 
variable.semanticEquals(elementVar) =>
+        Some(field :: Nil)
+      case variable: NamedLambdaVariable if 
variable.semanticEquals(elementVar) =>
+        None
+      case _ =>
+        expr.children.foldLeft(Option(Seq.empty[StructField])) {
+          case (Some(fields), child) =>
+            collectLambdaVariableFields(child, elementVar).map(fields ++ _)
+          case (None, _) => None
+        }
+    }
+  }
+
+  /**
+   * Converts a field access rooted at the lambda element into the single 
nested
+   * [[StructField]] shape needed by the input array schema. For example,
+   * `x.company.address` becomes `company: struct<address: ...>`.
+   */
+  private object LambdaVariableField {
+    def unapply(expr: Expression): Option[(StructField, NamedLambdaVariable)] 
= {
+      def selectField(
+          expression: Expression,
+          dataTypeOpt: Option[DataType]): Option[(StructField, 
NamedLambdaVariable)] =
+        expression match {
+        case variable: NamedLambdaVariable =>
+          dataTypeOpt.collect {
+            case schema: StructType if schema.length == 1 =>
+              schema.head -> variable
+          }
+        case getStructField: GetStructField =>
+          val field = getStructField.childSchema(getStructField.ordinal)
+          val newField = field.copy(dataType = 
dataTypeOpt.getOrElse(field.dataType))
+          selectField(getStructField.child, Some(StructType(Array(newField))))
+        case _ => None
+      }
+
+      selectField(expr, None)
+    }
+  }
+
   /**
    * This represents a "root" schema field (aka top-level, no-parent). `field` 
is the
    * `StructField` for field name and datatype. `derivedFromAtt` indicates 
whether it
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
index 820dc452d7e8..e36224e7d5c1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
@@ -62,6 +62,21 @@ object SelectedField {
     selectField(unaliased, None)
   }
 
+  /**
+   * Builds the selected root field for `expr` while substituting a narrower 
projected data type.
+   * This lets a lambda field access establish the required scan type of its 
array argument, even
+   * though the lambda variable is not itself rooted at a data source 
attribute.
+   */
+  private[catalyst] def withDataType(
+      expr: Expression,
+      dataType: DataType): Option[StructField] = {
+    val unaliased = expr match {
+      case Alias(child, _) => child
+      case expression => expression
+    }
+    selectField(unaliased, Some(dataType))
+  }
+
   /**
    * Convert an expression into the parts of the schema (the field) it 
accesses.
    */
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index a968526a89f1..9426ef91349e 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -144,4 +144,46 @@ class SchemaPruningSuite extends SparkFunSuite with 
SQLHelper {
     assert(prunedSchema.head.metadata.getString("foo") == "bar")
   }
 
+  test("collect nested fields used by ArrayTransform lambda") {
+    val elementType = StructType.fromDDL("a int, b int, c int")
+    val eventType = StructType(Seq(
+      StructField("rules", ArrayType(elementType, containsNull = true))))
+    val event = AttributeReference("event", eventType)()
+    val element = NamedLambdaVariable("x", elementType, nullable = true)
+    val transformed = ArrayTransform(
+      GetStructField(event, 0, Some("rules")),
+      LambdaFunction(
+        CreateNamedStruct(Seq(
+          Literal("a"),
+          GetStructField(element, 0, Some("a")),
+          Literal("c"),
+          GetStructField(element, 2, Some("c")))),
+        Seq(element)))
+
+    val rootFields = SchemaPruning.getRootFields(transformed)
+    val prunedSchema = SchemaPruning.pruneSchema(
+      StructType(Seq(StructField("event", eventType))),
+      rootFields)
+
+    assert(prunedSchema === StructType.fromDDL(
+      "event struct<rules:array<struct<a:int,c:int>>>"))
+  }
+
+  test("do not collect ArrayTransform lambda fields when the whole element is 
used") {
+    val elementType = StructType.fromDDL("a int, b int")
+    val eventType = StructType(Seq(
+      StructField("rules", ArrayType(elementType, containsNull = true))))
+    val event = AttributeReference("event", eventType)()
+    val element = NamedLambdaVariable("x", elementType, nullable = true)
+    val transformed = ArrayTransform(
+      GetStructField(event, 0, Some("rules")),
+      LambdaFunction(element, Seq(element)))
+
+    val rootFields = SchemaPruning.getRootFields(transformed)
+
+    assert(rootFields === Seq(
+      SchemaPruning.RootField(
+        StructField("event", eventType, nullable = true),
+        derivedFromAtt = false)))
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index fd8d1308e990..5213c0c5f4e2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -62,6 +62,8 @@ abstract class SchemaPruningSuite
     super.sparkConf.set(SQLConf.ANSI_ENABLED.key, "false")
 
   case class Employee(id: Int, name: FullName, employer: Company)
+  case class Team(members: Array[Employer])
+  case class Organization(team: Team)
 
   val janeDoe = FullName("Jane", "X.", "Doe")
   val johnDoe = FullName("John", "Y.", "Doe")
@@ -86,6 +88,9 @@ abstract class SchemaPruningSuite
     Department(2, "Operation", 4, employerWithNullCompany2) :: Nil
 
   val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, 
company) :: Nil
+  val teams = Team(Array(employer)) :: Nil
+  val organizations =
+    Organization(Team(Array[Employer](null, employerWithNullCompany, 
employer))) :: Nil
 
   case class Name(first: String, last: String)
   case class BriefContact(id: Int, name: Name, address: String)
@@ -361,6 +366,110 @@ abstract class SchemaPruningSuite
     }
   }
 
+  testSchemaPruning("select ArrayTransform over nested fields of array of 
struct") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(transform(col("friends"), friend =>
+        struct(
+          friend.getField("first").as("first"),
+          friend.getField("last").as("last"))))
+
+    checkScan(query, "struct<friends:array<struct<first:string,last:string>>>")
+    checkAnswer(query,
+      Row(Array(Row("Susan", "Smith"))) ::
+      Row(Array.empty[Row]) ::
+      Nil)
+  }
+
+  testSchemaPruning("select ArrayTransform over deep nested fields of array of 
struct") {
+    withDataSourceTable(teams, "teams") {
+      val query = spark.table("teams")
+        .select(transform(col("members"), member =>
+          member.getField("company").getField("address")))
+
+      checkScan(query, 
"struct<members:array<struct<company:struct<address:string>>>>")
+      checkAnswer(query, Row(Array("123 Business Street")) :: Nil)
+    }
+  }
+
+  testSchemaPruning("select ArrayTransform merging nested parent and child 
fields") {
+    withDataSourceTable(teams, "teams") {
+      val query = spark.table("teams")
+        .select(transform(col("members"), member =>
+          struct(
+            member.getField("company").as("company"),
+            member.getField("company").getField("address").as("address"))))
+
+      checkScan(query, 
"struct<members:array<struct<company:struct<name:string,address:string>>>>")
+      checkAnswer(query,
+        Row(Array(Row(Row("abc", "123 Business Street"), "123 Business 
Street"))) :: Nil)
+    }
+  }
+
+  testSchemaPruning("select ArrayTransform over nested array path with null 
elements") {
+    withDataSourceTable(organizations, "organizations") {
+      val query = spark.table("organizations")
+        .select(transform(col("team.members"), member =>
+          member.getField("company").getField("address")))
+
+      checkScan(query,
+        
"struct<team:struct<members:array<struct<company:struct<address:string>>>>>")
+      checkAnswer(query, Row(Array(null, null, "123 Business Street")) :: Nil)
+    }
+  }
+
+  testSchemaPruning("select indexed ArrayTransform over nested fields of array 
of struct") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(transform(col("friends"), (friend, index) =>
+        struct(friend.getField("last").as("last"), index.as("index"))))
+
+    checkScan(query, "struct<friends:array<struct<last:string>>>")
+    checkAnswer(query,
+      Row(Array(Row("Smith", 0))) ::
+      Row(Array.empty[Row]) ::
+      Nil)
+  }
+
+  testSchemaPruning("select case-insensitive ArrayTransform nested field") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      val query = spark.table("contacts")
+        .where("p = 1")
+        .select(transform(col("friends"), friend => friend.getField("LaSt")))
+
+      checkScan(query, "struct<friends:array<struct<last:string>>>")
+      checkAnswer(query,
+        Row(Array("Smith")) ::
+        Row(Array.empty[String]) ::
+        Nil)
+    }
+  }
+
+  testSchemaPruning("do not prune ArrayTransform when the whole element is 
used") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(transform(col("friends"), friend => friend))
+
+    checkScan(query, 
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+    checkAnswer(query,
+      Row(Array(Row("Susan", "Z.", "Smith"))) ::
+      Row(Array.empty[Row]) ::
+      Nil)
+  }
+
+  testSchemaPruning("do not prune ArrayTransform when a nested field and whole 
element are used") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(transform(col("friends"), friend =>
+        struct(friend.getField("first").as("first"), friend.as("friend"))))
+
+    checkScan(query, 
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+    checkAnswer(query,
+      Row(Array(Row("Susan", Row("Susan", "Z.", "Smith")))) ::
+      Row(Array.empty[Row]) ::
+      Nil)
+  }
+
   testSchemaPruning("SPARK-34638: nested column prune on generator output") {
     val query1 = spark.table("contacts")
       .select(explode(col("friends")).as("friend"))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to