This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f9eaf8f27a7 [SPARK-57175][SQL] Extend nested column pruning to exists and forall over arrays of structs
1f9eaf8f27a7 is described below

commit 1f9eaf8f27a781a8641a90888c1e6423909064cb
Author: Chao Sun <[email protected]>
AuthorDate: Tue Jun 2 12:18:27 2026 -0700

    [SPARK-57175][SQL] Extend nested column pruning to exists and forall over arrays of structs
    
    ### Why are the changes needed?
    
    [SPARK-57175](https://issues.apache.org/jira/browse/SPARK-57175) follows
    [SPARK-57022](https://issues.apache.org/jira/browse/SPARK-57022), which added
    nested column pruning for `transform` over `array<struct>` inputs. The same
    optimization does not currently apply to the `exists` and `forall` higher-order
    array functions.
    
    For example:
    
    ```sql
    SELECT exists(rule_results, rule -> rule.rule_version > 10)
    FROM events
    ```
    
    If `rule_results` contains additional fields, Spark currently retains the
    full element struct in the scan schema even though the predicate only reads
    `rule_version`. This causes unnecessary Parquet and ORC input reads for wide
    array element schemas.
    
    ### What changes were proposed in this PR?
    
    - Share the nested-field collection path introduced for `ArrayTransform`
      with `ArrayExists` and `ArrayForAll`.
    - Rewrite the bound lambda variable type and `GetStructField` ordinals
      against the projected element schema after pruning.
    - Keep the conservative fallback when a lambda consumes the whole element.
    - Add Catalyst and datasource tests covering schema discovery, ordinal
      rewrites, predicate-path schema merging, and whole-element fallback.
    
    `ArrayFilter` and `ArraySort` remain out of scope because they return
    original input elements and require a different downstream-schema design.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Eligible queries using `exists` or `forall` over arrays of structs can
    read a narrower input schema. Query results and SQL APIs are unchanged.
    
    ### How was this patch tested?
    
    - `JAVA_HOME=/opt/homebrew/opt/openjdk17/libexec/openjdk.jdk/Contents/Home 
PATH=/opt/homebrew/opt/openjdk17/bin:$PATH build/sbt "catalyst/testOnly 
org.apache.spark.sql.catalyst.expressions.SchemaPruningSuite" "sql/testOnly 
org.apache.spark.sql.execution.datasources.parquet.ParquetV1SchemaPruningSuite 
org.apache.spark.sql.execution.datasources.parquet.ParquetV2SchemaPruningSuite 
org.apache.spark.sql.execution.datasources.orc.OrcV1SchemaPruningSuite 
org.apache.spark.sql.execution.dataso [...]
    - `git diff --check`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Codex (GPT-5)
    
    Closes #56226 from sunchao/dev/chao/codex/spark-array-predicate-pruning.
    
    Authored-by: Chao Sun <[email protected]>
    Signed-off-by: Chao Sun <[email protected]>
---
 .../expressions/ProjectionOverSchema.scala         | 63 ++++++++++++++--------
 .../sql/catalyst/expressions/SchemaPruning.scala   | 37 ++++++++-----
 .../catalyst/expressions/SchemaPruningSuite.scala  | 39 ++++++++++++++
 .../execution/datasources/SchemaPruningSuite.scala | 48 +++++++++++++++++
 4 files changed, 151 insertions(+), 36 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 362643016d83..27e014ecef62 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -69,29 +69,16 @@ case class ProjectionOverSchema(schema: StructType, output: 
AttributeSet) {
       case GetMapValue(child, key) =>
         getProjection(child).map { projection => GetMapValue(projection, key) }
       case transform @ ArrayTransform(argument, lambda: LambdaFunction) =>
-        getProjection(argument).map {
-          case projection @ ArrayTypeProjection(projectedElementSchema) =>
-            lambda.arguments.headOption match {
-              case Some(elementVar: NamedLambdaVariable) =>
-                // Pruning fields changes the physical ordinal layout of the 
element struct.
-                // For example, pruning struct<a, b, c> to struct<a, c> moves 
c from ordinal 2
-                // to ordinal 1, so rewrite both the variable type and its 
field accesses.
-                val projectedElementVar = elementVar.copy(dataType = 
projectedElementSchema)
-                val lambdaProjection =
-                  ProjectionOverLambdaVariable(elementVar, projectedElementVar)
-                val projectedBody = lambda.function.transformDown {
-                  case lambdaProjection(expr) => expr
-                }
-                transform.copy(
-                  argument = projection,
-                  function = lambda.copy(
-                    function = projectedBody,
-                    arguments = projectedElementVar +: lambda.arguments.tail))
-              case _ =>
-                transform.copy(argument = projection)
-            }
-          case projection =>
-            transform.copy(argument = projection)
+        projectArrayHigherOrderFunction(argument, lambda) { (projection, 
projectedLambda) =>
+          transform.copy(argument = projection, function = projectedLambda)
+        }
+      case exists @ ArrayExists(argument, lambda: LambdaFunction, _) =>
+        projectArrayHigherOrderFunction(argument, lambda) { (projection, 
projectedLambda) =>
+          exists.copy(argument = projection, function = projectedLambda)
+        }
+      case forall @ ArrayForAll(argument, lambda: LambdaFunction) =>
+        projectArrayHigherOrderFunction(argument, lambda) { (projection, 
projectedLambda) =>
+          forall.copy(argument = projection, function = projectedLambda)
         }
       case GetStructFieldObject(child, field: StructField) =>
         getProjection(child).map(p => (p, p.dataType)).map {
@@ -108,6 +95,36 @@ case class ProjectionOverSchema(schema: StructType, output: 
AttributeSet) {
         None
     }
 
+  private def projectArrayHigherOrderFunction(
+      argument: Expression,
+      lambda: LambdaFunction)(
+      rebuild: (Expression, LambdaFunction) => Expression): Option[Expression] 
= {
+    getProjection(argument).map {
+      case projection @ ArrayTypeProjection(projectedElementSchema) =>
+        lambda.arguments.headOption match {
+          case Some(elementVar: NamedLambdaVariable) =>
+            // Pruning fields changes the physical ordinal layout of the 
element struct.
+            // For example, pruning struct<a, b, c> to struct<a, c> moves c 
from ordinal 2
+            // to ordinal 1, so rewrite both the variable type and its field 
accesses.
+            val projectedElementVar = elementVar.copy(dataType = 
projectedElementSchema)
+            val lambdaProjection =
+              ProjectionOverLambdaVariable(elementVar, projectedElementVar)
+            val projectedBody = lambda.function.transformDown {
+              case lambdaProjection(expr) => expr
+            }
+            rebuild(
+              projection,
+              lambda.copy(
+                function = projectedBody,
+                arguments = projectedElementVar +: lambda.arguments.tail))
+          case _ =>
+            rebuild(projection, lambda)
+        }
+      case projection =>
+        rebuild(projection, lambda)
+    }
+  }
+
   private object ArrayTypeProjection {
     def unapply(expr: Expression): Option[StructType] = expr.dataType match {
       case ArrayType(projectedElementSchema: StructType, _) => 
Some(projectedElementSchema)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index 2f99dd54f77a..e8aa722bbe23 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -141,18 +141,11 @@ object SchemaPruning extends SQLConfHelper {
   private[catalyst] def getRootFields(expr: Expression): Seq[RootField] = {
     expr match {
       case ArrayTransform(argument, lambda: LambdaFunction) =>
-        // Field accesses through the lambda variable are not directly rooted 
at the input
-        // attribute. Convert them into a projected type for the transform 
argument so that
-        // physical nested column pruning can see them.
-        val nestedRootFields = lambda.arguments.headOption.collect {
-          case elementVar: NamedLambdaVariable =>
-            getArrayTransformRootField(argument, lambda.function, elementVar)
-        }.flatten.toSeq.map(field => RootField(field, derivedFromAtt = false))
-        if (nestedRootFields.nonEmpty) {
-          nestedRootFields ++ getRootFields(lambda.function)
-        } else {
-          expr.children.flatMap(getRootFields)
-        }
+        getArrayHigherOrderFunctionRootFields(expr, argument, lambda)
+      case ArrayExists(argument, lambda: LambdaFunction, _) =>
+        getArrayHigherOrderFunctionRootFields(expr, argument, lambda)
+      case ArrayForAll(argument, lambda: LambdaFunction) =>
+        getArrayHigherOrderFunctionRootFields(expr, argument, lambda)
       case att: Attribute =>
         RootField(StructField(att.name, att.dataType, att.nullable, 
att.metadata),
           derivedFromAtt = true) :: Nil
@@ -175,7 +168,25 @@ object SchemaPruning extends SQLConfHelper {
     }
   }
 
-  private def getArrayTransformRootField(
+  private def getArrayHigherOrderFunctionRootFields(
+      expr: Expression,
+      argument: Expression,
+      lambda: LambdaFunction): Seq[RootField] = {
+    // Field accesses through the lambda variable are not directly rooted at 
the input
+    // attribute. Convert them into a projected type for the array argument so 
that
+    // physical nested column pruning can see them.
+    val nestedRootFields = lambda.arguments.headOption.collect {
+      case elementVar: NamedLambdaVariable =>
+        getArrayHigherOrderFunctionRootField(argument, lambda.function, 
elementVar)
+    }.flatten.toSeq.map(field => RootField(field, derivedFromAtt = false))
+    if (nestedRootFields.nonEmpty) {
+      nestedRootFields ++ getRootFields(lambda.function)
+    } else {
+      expr.children.flatMap(getRootFields)
+    }
+  }
+
+  private def getArrayHigherOrderFunctionRootField(
       argument: Expression,
       function: Expression,
       elementVar: NamedLambdaVariable): Option[StructField] = {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index 9426ef91349e..af64da7e3820 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -186,4 +186,43 @@ class SchemaPruningSuite extends SparkFunSuite with 
SQLHelper {
         StructField("event", eventType, nullable = true),
         derivedFromAtt = false)))
   }
+
+  test("collect nested fields used by ArrayExists and ArrayForAll lambdas") {
+    val elementType = StructType.fromDDL("a int, b int, c int")
+    val eventType = StructType(Seq(
+      StructField("rules", ArrayType(elementType, containsNull = true))))
+    val event = AttributeReference("event", eventType)()
+    val argument = GetStructField(event, 0, Some("rules"))
+    val element = NamedLambdaVariable("x", elementType, nullable = true)
+    val predicate = LambdaFunction(
+      GreaterThan(GetStructField(element, 2, Some("c")), Literal(0)),
+      Seq(element))
+
+    Seq(ArrayExists(argument, predicate), ArrayForAll(argument, 
predicate)).foreach { function =>
+      val rootFields = SchemaPruning.getRootFields(function)
+      val prunedSchema = SchemaPruning.pruneSchema(
+        StructType(Seq(StructField("event", eventType))),
+        rootFields)
+
+      assert(prunedSchema === StructType.fromDDL(
+        "event struct<rules:array<struct<c:int>>>"))
+    }
+  }
+
+  test("do not collect ArrayExists and ArrayForAll lambda fields when the 
whole element is used") {
+    val elementType = StructType.fromDDL("a int, b int")
+    val eventType = StructType(Seq(
+      StructField("rules", ArrayType(elementType, containsNull = true))))
+    val event = AttributeReference("event", eventType)()
+    val argument = GetStructField(event, 0, Some("rules"))
+    val element = NamedLambdaVariable("x", elementType, nullable = true)
+    val predicate = LambdaFunction(IsNotNull(element), Seq(element))
+
+    Seq(ArrayExists(argument, predicate), ArrayForAll(argument, 
predicate)).foreach { function =>
+      assert(SchemaPruning.getRootFields(function) === Seq(
+        SchemaPruning.RootField(
+          StructField("event", eventType, nullable = true),
+          derivedFromAtt = false)))
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 7f4b83ee342e..2aebf08286e1 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -470,6 +470,54 @@ abstract class SchemaPruningSuite
       Nil)
   }
 
+  testSchemaPruning("select ArrayExists over nested fields of array of 
struct") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(org.apache.spark.sql.functions.exists(
+        col("friends"), friend => friend.getField("last") === "Smith"))
+
+    checkScan(query, "struct<friends:array<struct<last:string>>>")
+    checkAnswer(query, Row(true) :: Row(false) :: Nil)
+  }
+
+  testSchemaPruning("select ArrayForAll over nested fields of array of 
struct") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(forall(col("friends"), friend => friend.getField("last") === 
"Smith"))
+
+    checkScan(query, "struct<friends:array<struct<last:string>>>")
+    checkAnswer(query, Row(true) :: Row(true) :: Nil)
+  }
+
+  testSchemaPruning("select nested field with ArrayExists predicate") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .where(org.apache.spark.sql.functions.exists(
+        col("friends"), friend => friend.getField("last") === "Smith"))
+      .select(col("friends").getField("first"))
+
+    checkScan(query, "struct<friends:array<struct<first:string,last:string>>>")
+    checkAnswer(query, Row(Array("Susan")) :: Nil)
+  }
+
+  testSchemaPruning("do not prune ArrayExists when the whole element is used") 
{
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(org.apache.spark.sql.functions.exists(col("friends"), friend => 
friend.isNotNull))
+
+    checkScan(query, 
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+    checkAnswer(query, Row(true) :: Row(false) :: Nil)
+  }
+
+  testSchemaPruning("do not prune ArrayForAll when the whole element is used") 
{
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(forall(col("friends"), friend => friend.isNotNull))
+
+    checkScan(query, 
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+    checkAnswer(query, Row(true) :: Row(true) :: Nil)
+  }
+
   testSchemaPruning("SPARK-34638: nested column prune on generator output") {
     val query1 = spark.table("contacts")
       .select(explode(col("friends")).as("friend"))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to