This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new 18fdd026bd2e [SPARK-54523][SQL] Fix default resolution during variant pushdown
18fdd026bd2e is described below

commit 18fdd026bd2e201e43297845a6c28968b51429b8
Author: Harsh Motwani <[email protected]>
AuthorDate: Wed Nov 26 14:51:21 2025 +0800

    [SPARK-54523][SQL] Fix default resolution during variant pushdown
    
    ### What changes were proposed in this pull request?
    
    [This PR](https://github.com/apache/spark/pull/53164) enables the shredding
    and variant logical type annotation configs by default. However, some test
    suites assume the old behavior. This PR fixes those tests so that they also
    pass with the new default configs.
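
    The pattern used throughout the test changes below is to pin the flipped
    configs back to their old values with `withSQLConf`. A minimal sketch of
    that pattern (the test name and body are placeholders, not code from this
    commit):

    ```scala
    test("round trip without shredding inference") {
      // Pin the configs that are now on by default, so this test still
      // exercises the old (unshredded, unannotated) Parquet write path.
      withSQLConf(
        SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false",
        SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key -> "false") {
        // ... write variant data to Parquet and read it back ...
      }
    }
    ```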
    
    This PR also fixes a bug we discovered in the previous PR where variant
    default resolution would fail when pushVariantIntoScan was enabled.
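
    Concretely, with pushVariantIntoScan enabled, the scan can rewrite a VARIANT
    column into a shredded struct, so resolving a column default has to cast a
    VariantType default expression to that struct type, which the old
    `Cast.canUpCast` check rejected. The new `Cast.canAssignDefaultValue` (see
    the diff below) accepts this case. A hedged sketch of the failing shape
    (table, column, and default expression are hypothetical, not taken from
    this commit):

    ```scala
    withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "true") {
      sql("CREATE TABLE t (id INT) USING parquet")
      sql("INSERT INTO t VALUES (1)")
      // The existence default is stored as a VariantType value in the column
      // metadata; reading it back must cast it to whatever type the scan
      // exposes for `v`, which under pushdown can be a shredded StructType.
      sql("ALTER TABLE t ADD COLUMN v VARIANT DEFAULT parse_json('1')")
      sql("SELECT v FROM t").collect()  // previously failed with pushdown on
    }
    ```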
    
    ### Why are the changes needed?
    
    To fix the variant default resolution bug and to keep the existing test
    suites passing with the new default configs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53224 from harshmotw-db/harshmotw-db/shredding_fixes.
    
    Lead-authored-by: Harsh Motwani <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit d36bd625332cacbd33102de4ab2d0c9574d6de12)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 25 +++++++++
 .../catalyst/util/ResolveDefaultColumnsUtil.scala  |  2 +-
 .../scala/org/apache/spark/sql/VariantSuite.scala  | 64 ++++++++++++----------
 .../parquet/ParquetVariantShreddingSuite.scala     | 12 ++--
 .../parquet/VariantInferShreddingSuite.scala       |  3 +
 5 files changed, 73 insertions(+), 33 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1f2805ec2789..1162a5394221 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -360,6 +360,31 @@ object Cast extends QueryErrorsBase {
    */
   def canUpCast(from: DataType, to: DataType): Boolean = UpCastRule.canUpCast(from, to)
 
+  /**
+   * Returns true iff it is safe to provide a default value of `from` type typically defined in the
+   * data source metadata to the `to` type typically in the read schema of a query.
+   */
+  def canAssignDefaultValue(from: DataType, to: DataType): Boolean = {
+    def isVariantStruct(st: StructType): Boolean = {
+      st.fields.length > 0 && st.fields.forall(_.metadata.contains("__VARIANT_METADATA_KEY"))
+    }
+    (from, to) match {
+      case (s1: StructType, s2: StructType) =>
+        s1.length == s2.length && s1.fields.zip(s2.fields).forall {
+          case (f1, f2) => resolvableNullability(f1.nullable, f2.nullable) &&
+            canAssignDefaultValue(f1.dataType, f2.dataType)
+        }
+      case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+        resolvableNullability(fn, tn) && canAssignDefaultValue(fromType, toType)
+      case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+        resolvableNullability(fn, tn) && canAssignDefaultValue(fromKey, toKey) &&
+          canAssignDefaultValue(fromValue, toValue)
+      // A VARIANT field can be read as StructType due to shredding.
+      case (VariantType, s: StructType) => isVariantStruct(s)
+      case _ => canUpCast(from, to)
+    }
+  }
+
   /**
   * Returns true iff we can cast the `from` type to `to` type as per the ANSI SQL.
   * In practice, the behavior is mostly the same as PostgreSQL. It disallows certain unreasonable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index 4bef21d0a091..488d1acf43ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -480,7 +480,7 @@ object ResolveDefaultColumns extends QueryErrorsBase
     val ret = analyzed match {
       case equivalent if equivalent.dataType == supplanted =>
         equivalent
-      case canUpCast if Cast.canUpCast(canUpCast.dataType, supplanted) =>
+      case _ if Cast.canAssignDefaultValue(analyzed.dataType, supplanted) =>
         Cast(analyzed, supplanted, Some(conf.sessionLocalTimeZone))
       case other =>
         defaultValueFromWiderTypeLiteral(other, supplanted, colName).getOrElse(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index ac6a4e435709..16be9558409c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -197,36 +197,39 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
   }
 
   test("round trip tests") {
-    val rand = new Random(42)
-    val input = Seq.fill(50) {
-      if (rand.nextInt(10) == 0) {
-        null
-      } else {
-        val value = new Array[Byte](rand.nextInt(50))
-        rand.nextBytes(value)
-        val metadata = new Array[Byte](rand.nextInt(50))
-        rand.nextBytes(metadata)
-        // Generate a valid metadata, otherwise the shredded reader will fail.
-        new VariantVal(value, Array[Byte](VERSION, 0, 0) ++ metadata)
+    withSQLConf(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false") {
+      val rand = new Random(42)
+      val input = Seq.fill(50) {
+        if (rand.nextInt(10) == 0) {
+          null
+        } else {
+          val value = new Array[Byte](rand.nextInt(50))
+          rand.nextBytes(value)
+          val metadata = new Array[Byte](rand.nextInt(50))
+          rand.nextBytes(metadata)
+          // Generate a valid metadata, otherwise the shredded reader will fail.
+          new VariantVal(value, Array[Byte](VERSION, 0, 0) ++ metadata)
+        }
       }
-    }
 
-    val df = spark.createDataFrame(
-      spark.sparkContext.parallelize(input.map(Row(_))),
-      StructType.fromDDL("v variant")
-    )
-    val result = df.collect().map(_.get(0).asInstanceOf[VariantVal])
+      val df = spark.createDataFrame(
+        spark.sparkContext.parallelize(input.map(Row(_))),
+        StructType.fromDDL("v variant")
+      )
+      val result = df.collect().map(_.get(0).asInstanceOf[VariantVal])
 
-    def prepareAnswer(values: Seq[VariantVal]): Seq[String] = {
-      values.map(v => if (v == null) "null" else v.debugString()).sorted
-    }
-    assert(prepareAnswer(input) == prepareAnswer(result.toImmutableArraySeq))
+      def prepareAnswer(values: Seq[VariantVal]): Seq[String] = {
+        values.map(v => if (v == null) "null" else v.debugString()).sorted
+      }
+      assert(prepareAnswer(input) == prepareAnswer(result.toImmutableArraySeq))
 
-    withTempDir { dir =>
-      val tempDir = new File(dir, "files").getCanonicalPath
-      df.write.parquet(tempDir)
-      val readResult = spark.read.parquet(tempDir).collect().map(_.get(0).asInstanceOf[VariantVal])
-      assert(prepareAnswer(input) == prepareAnswer(readResult.toImmutableArraySeq))
+      withTempDir { dir =>
+        val tempDir = new File(dir, "files").getCanonicalPath
+        df.write.parquet(tempDir)
+        val readResult = spark.read.parquet(tempDir).collect()
+          .map(_.get(0).asInstanceOf[VariantVal])
+        assert(prepareAnswer(input) == prepareAnswer(readResult.toImmutableArraySeq))
+      }
     }
   }
 
@@ -383,14 +386,19 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
     )
     cases.foreach { case (structDef, condition, parameters) =>
       Seq(false, true).foreach { vectorizedReader =>
-        withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReader.toString) {
+        withSQLConf(
+          SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReader.toString,
+          // Invalid variant binary fails during shredding schema inference.
+          SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false"
+        ) {
           withTempDir { dir =>
             val file = new File(dir, "dir").getCanonicalPath
             val df = spark.sql(s"select $structDef as v from range(10)")
             df.write.parquet(file)
             val schema = StructType(Seq(StructField("v", VariantType)))
             val result = spark.read.schema(schema).parquet(file).selectExpr("to_json(v)")
-            val e = withSQLConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "false") {
+            val e = withSQLConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "false",
+              SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") {
               intercept[org.apache.spark.SparkException](result.collect())
             }
             checkError(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
index 77140c1a91ee..1f06ddb29bd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
@@ -48,7 +48,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
 
   test("timestamp physical type") {
     ParquetOutputTimestampType.values.foreach { timestampParquetType =>
-      withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> timestampParquetType.toString) {
+      withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> timestampParquetType.toString,
+        SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> "true") {
         withTempDir { dir =>
           val schema = "t timestamp, st struct<t timestamp>, at array<timestamp>"
           val fullSchema = "v struct<metadata binary, value binary, typed_value struct<" +
@@ -232,7 +233,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
   test("variant logical type annotation - ignore variant annotation") {
     Seq(true, false).foreach { ignoreVariantAnnotation =>
       withSQLConf(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key -> "true",
-        SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> ignoreVariantAnnotation.toString
+        SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> ignoreVariantAnnotation.toString,
+        SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false"
       ) {
         withTempDir { dir =>
           // write parquet file
@@ -302,7 +304,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
       "c struct<value binary, typed_value decimal(15, 1)>>>"
     withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
       SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> true.toString,
-      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
+      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema,
+      SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> true.toString) {
       df.write.mode("overwrite").parquet(dir.getAbsolutePath)
 
 
@@ -441,7 +444,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
       "m map<string, struct<metadata binary, value binary>>"
     withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
       SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> true.toString,
-      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
+      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema,
+      SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> true.toString) {
       df.write.mode("overwrite").parquet(dir.getAbsolutePath)
 
       // Verify that we can read the full variant.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
index cdaf6c488dc2..49a43fffafb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
@@ -41,6 +41,9 @@ class VariantInferShreddingSuite extends QueryTest with SharedSparkSession with
     super.sparkConf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true")
       .set(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key, "true")
       .set(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key, "true")
+      // We cannot check the physical shredding schemas if the variant logical type annotation is
+      // used
+      .set(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key, "false")
   }
 
   private def withTempTable(tableNames: String*)(f: => Unit): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
