This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5c634845f084 [SPARK-50614][SQL] Add Variant shredding support for 
Parquet
5c634845f084 is described below

commit 5c634845f0849b3d5bf6a98c4ecddff26b71572b
Author: cashmand <[email protected]>
AuthorDate: Thu Jan 2 12:39:07 2025 +0800

    [SPARK-50614][SQL] Add Variant shredding support for Parquet
    
    ### What changes were proposed in this pull request?
    
    Adds support for shredding in the Parquet writer code. Currently, the only 
way to enable shredding is through a SQLConf that provides the schema to use 
for shredding. This doesn't make sense as a user API, and is added only for 
testing. The exact API for Spark to determine a shredding schema is still TBD, 
but likely candidates are to infer it at the task level by inspecting the first 
few rows of data, or add an API to specify the schema for a given column. 
Either way, the code in this [...]
    
    ### Why are the changes needed?
    
    Needed for Variant shredding support.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, the feature is new in Spark 4.0, and is currently disabled, and only 
usable as a test feature.
    
    ### How was this patch tested?
    
    Added a unit test suite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49234 from cashmand/SPARK-50614.
    
    Authored-by: cashmand <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  16 ++
 .../parquet/ParquetSchemaConverter.scala           |   8 +
 .../datasources/parquet/ParquetUtils.scala         |  34 ++-
 .../datasources/parquet/ParquetWriteSupport.scala  |  65 +++++-
 .../datasources/parquet/SparkShreddingUtils.scala  |  23 +++
 .../parquet/ParquetVariantShreddingSuite.scala     | 229 +++++++++++++++++++++
 6 files changed, 362 insertions(+), 13 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5e630577638a..6a45380b7a99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4659,6 +4659,22 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val VARIANT_WRITE_SHREDDING_ENABLED =
+    buildConf("spark.sql.variant.writeShredding.enabled")
+      .internal()
+      .doc("When true, the Parquet writer is allowed to write shredded 
variant. ")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST =
+    buildConf("spark.sql.variant.forceShreddingSchemaForTest")
+      .internal()
+      .doc("FOR INTERNAL TESTING ONLY. Sets shredding schema for Variant.")
+      .version("4.0.0")
+      .stringConf
+      .createWithDefault("")
+
   val LEGACY_CSV_ENABLE_DATE_TIME_PARSING_FALLBACK =
     buildConf("spark.sql.legacy.csv.enableDateTimeParsingFallback")
       .internal()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 64c2a3126ca9..daeb8e88a924 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -756,6 +756,14 @@ class SparkToParquetSchemaConverter(
           .addField(convertField(StructField("metadata", BinaryType, nullable 
= false)))
           .named(field.name)
 
+      case s: StructType if SparkShreddingUtils.isVariantShreddingStruct(s) =>
+        // Variant struct takes a Variant and writes to Parquet as a shredded 
schema.
+        val group = Types.buildGroup(repetition)
+        s.fields.foreach { f =>
+          group.addField(convertField(f))
+        }
+        group.named(field.name)
+
       case StructType(fields) =>
         fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) =>
           builder.addField(convertField(field))
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index 3e111252bc6f..a609a4e0a25f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -45,7 +45,7 @@ import 
org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, Outpu
 import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.internal.SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED
-import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, MapType, 
StructField, StructType, UserDefinedType}
+import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, MapType, 
StructField, StructType, UserDefinedType, VariantType}
 import org.apache.spark.util.ArrayImplicits._
 
 object ParquetUtils extends Logging {
@@ -420,6 +420,22 @@ object ParquetUtils extends Logging {
     statistics.getNumNulls;
   }
 
+  // Replaces each VariantType in the schema with the corresponding type in 
the shredding schema.
+  // Used for testing, where we force a single shredding schema for all 
Variant fields.
+  // Does not touch Variant fields nested in arrays, maps, or UDTs.
+  private def replaceVariantTypes(schema: StructType, shreddingSchema: 
StructType): StructType = {
+    val newFields = schema.fields.zip(shreddingSchema.fields).map {
+      case (field, shreddingField) =>
+        field.dataType match {
+          case s: StructType =>
+            field.copy(dataType = replaceVariantTypes(s, shreddingSchema))
+          case VariantType => field.copy(dataType = shreddingSchema)
+          case _ => field
+        }
+    }
+    StructType(newFields)
+  }
+
   def prepareWrite(
       sqlConf: SQLConf,
       job: Job,
@@ -454,8 +470,22 @@ object ParquetUtils extends Logging {
 
     ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport])
 
+    val shreddingSchema = if 
(sqlConf.getConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED) &&
+        
!sqlConf.getConf(SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST).isEmpty) {
+      // Convert the schema to a shredding schema, and replace it anywhere 
that there is a
+      // VariantType in the original schema.
+      val simpleShreddingSchema = DataType.fromDDL(
+        sqlConf.getConf(SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST)
+      )
+      val oneShreddingSchema = 
SparkShreddingUtils.variantShreddingSchema(simpleShreddingSchema)
+      val schemaWithMetadata = 
SparkShreddingUtils.addWriteShreddingMetadata(oneShreddingSchema)
+      Some(replaceVariantTypes(dataSchema, schemaWithMetadata))
+    } else {
+      None
+    }
+
     // This metadata is useful for keeping UDTs like Vector/Matrix.
-    ParquetWriteSupport.setSchema(dataSchema, conf)
+    ParquetWriteSupport.setSchema(dataSchema, conf, shreddingSchema)
 
     // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to 
Parquet
     // schema and writes actual rows to Parquet files.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
index 89a1cd5d4375..9402f5638094 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.DataSourceUtils
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
+import org.apache.spark.types.variant.Variant
 
 /**
  * A Parquet [[WriteSupport]] implementation that writes Catalyst 
[[InternalRow]]s as Parquet
@@ -59,6 +60,10 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] 
with Logging {
   // Schema of the `InternalRow`s to be written
   private var schema: StructType = _
 
+  // Schema of the `InternalRow`s to be written, with VariantType replaced 
with its shredding
+  // schema, if appropriate.
+  private var shreddedSchema: StructType = _
+
   // `ValueWriter`s for all fields of the schema
   private var rootFieldWriters: Array[ValueWriter] = _
 
@@ -95,7 +100,16 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] 
with Logging {
 
   override def init(configuration: Configuration): WriteContext = {
     val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA)
+    val shreddedSchemaString = 
configuration.get(ParquetWriteSupport.SPARK_VARIANT_SHREDDING_SCHEMA)
     this.schema = StructType.fromString(schemaString)
+    // If shreddedSchemaString is provided, we use that everywhere in the 
writer, except for
+    // setting the spark schema in the Parquet metadata. If it isn't provided, 
it means that there
+    // are no shredded Variant columns, so it is identical to this.schema.
+    this.shreddedSchema = if (shreddedSchemaString == null) {
+      this.schema
+    } else {
+      StructType.fromString(shreddedSchemaString)
+    }
     this.writeLegacyParquetFormat = {
       // `SQLConf.PARQUET_WRITE_LEGACY_FORMAT` should always be explicitly set 
in ParquetRelation
       assert(configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key) != 
null)
@@ -108,9 +122,9 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] 
with Logging {
       SQLConf.ParquetOutputTimestampType.withName(configuration.get(key))
     }
 
-    this.rootFieldWriters = 
schema.map(_.dataType).map(makeWriter).toArray[ValueWriter]
+    this.rootFieldWriters = 
shreddedSchema.map(_.dataType).map(makeWriter).toArray[ValueWriter]
 
-    val messageType = new 
SparkToParquetSchemaConverter(configuration).convert(schema)
+    val messageType = new 
SparkToParquetSchemaConverter(configuration).convert(shreddedSchema)
     val metadata = Map(
       SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT,
       ParquetReadSupport.SPARK_METADATA_KEY -> schemaString
@@ -132,13 +146,23 @@ class ParquetWriteSupport extends 
WriteSupport[InternalRow] with Logging {
       }
     }
 
-    logDebug(
-      s"""Initialized Parquet WriteSupport with Catalyst schema:
-         |${schema.prettyJson}
-         |and corresponding Parquet message type:
-         |$messageType
-       """.stripMargin)
-
+    if (shreddedSchemaString == null) {
+      logDebug(
+        s"""Initialized Parquet WriteSupport with Catalyst schema:
+           |${schema.prettyJson}
+           |and corresponding Parquet message type:
+           |$messageType
+         """.stripMargin)
+    } else {
+      logDebug(
+        s"""Initialized Parquet WriteSupport with Catalyst schema:
+           |${schema.prettyJson}
+           |and shredding schema:
+           |${shreddedSchema.prettyJson}
+           |and corresponding Parquet message type:
+           |$messageType
+         """.stripMargin)
+    }
     new WriteContext(messageType, metadata.asJava)
   }
 
@@ -148,7 +172,7 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] 
with Logging {
 
   override def write(row: InternalRow): Unit = {
     consumeMessage {
-      writeFields(row, schema, rootFieldWriters)
+      writeFields(row, shreddedSchema, rootFieldWriters)
     }
   }
 
@@ -250,6 +274,17 @@ class ParquetWriteSupport extends 
WriteSupport[InternalRow] with Logging {
             }
           }
 
+      case s: StructType if SparkShreddingUtils.isVariantShreddingStruct(s) =>
+        val fieldWriters = 
s.map(_.dataType).map(makeWriter).toArray[ValueWriter]
+        val variantShreddingSchema = SparkShreddingUtils.buildVariantSchema(s)
+        (row: SpecializedGetters, ordinal: Int) =>
+          val v = row.getVariant(ordinal)
+          val variant = new Variant(v.getValue, v.getMetadata)
+          val shreddedValues = SparkShreddingUtils.castShredded(variant, 
variantShreddingSchema)
+          consumeGroup {
+            writeFields(shreddedValues, s, fieldWriters)
+          }
+
       case t: StructType =>
         val fieldWriters = 
t.map(_.dataType).map(makeWriter).toArray[ValueWriter]
         (row: SpecializedGetters, ordinal: Int) =>
@@ -499,11 +534,19 @@ class ParquetWriteSupport extends 
WriteSupport[InternalRow] with Logging {
 
 object ParquetWriteSupport {
   val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes"
+  // A version of `SPARK_ROW_SCHEMA`, where one or more Variant attributes 
have been replaced with a
+  // shredded struct schema.
+  val SPARK_VARIANT_SHREDDING_SCHEMA: String =
+    "org.apache.spark.sql.parquet.variant.shredding.attributes"
 
-  def setSchema(schema: StructType, configuration: Configuration): Unit = {
+  def setSchema(schema: StructType, configuration: Configuration,
+      shreddingSchema: Option[StructType]): Unit = {
     configuration.set(SPARK_ROW_SCHEMA, schema.json)
     configuration.setIfUnset(
       ParquetOutputFormat.WRITER_VERSION,
       ParquetProperties.WriterVersion.PARQUET_1_0.toString)
+    shreddingSchema.foreach { s =>
+      configuration.set(SPARK_VARIANT_SHREDDING_SCHEMA, s.json)
+    }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index 34c167aea363..c0c490034415 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -448,6 +448,8 @@ case object SparkShreddingUtils {
   val TypedValueFieldName = "typed_value";
   val MetadataFieldName = "metadata";
 
+  val VARIANT_WRITE_SHREDDING_KEY: String = "__VARIANT_WRITE_SHREDDING_KEY"
+
   def buildVariantSchema(schema: DataType): VariantSchema = {
     schema match {
       case s: StructType => buildVariantSchema(s, topLevel = true)
@@ -512,6 +514,27 @@ case object SparkShreddingUtils {
     }
   }
 
+  /**
+   * Given a schema that represents a valid shredding schema (e.g. constructed 
by
+   * SparkShreddingUtils.variantShreddingSchema), add metadata to the 
top-level fields to mark it
+   * as a shredding schema for writers.
+   */
+  def addWriteShreddingMetadata(schema: StructType): StructType = {
+    val newFields = schema.fields.map { f =>
+      f.copy(metadata = new
+          MetadataBuilder()
+            .withMetadata(f.metadata)
+            .putNull(VARIANT_WRITE_SHREDDING_KEY).build())
+    }
+    StructType(newFields)
+  }
+
+  // Check if the struct is marked with metadata set by 
addWriteShreddingMetadata - i.e. it
+  // represents a Variant converted to a shredding schema for writing.
+  def isVariantShreddingStruct(s: StructType): Boolean = {
+    s.fields.length > 0 && 
s.fields.forall(_.metadata.contains(VARIANT_WRITE_SHREDDING_KEY))
+  }
+
   /*
    * Given a Spark schema that represents a valid shredding schema (e.g. 
constructed by
    * SparkShreddingUtils.variantShreddingSchema), return the corresponding 
VariantSchema.
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
new file mode 100644
index 000000000000..8bb5a4b1d0bc
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.io.File
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.unsafe.types.VariantVal
+
+/**
+ * Test shredding Variant values in the Parquet reader/writer.
+ */
+class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with 
SharedSparkSession {
+
+  private def testWithTempDir(name: String)(block: File => Unit): Unit = 
test(name) {
+    withTempDir { dir =>
+      block(dir)
+    }
+  }
+
+  testWithTempDir("write shredded variant basic") { dir =>
+    val schema = "a int, b string, c decimal(15, 1)"
+    val df = spark.sql(
+      """
+        | select case
+        | when id = 0 then parse_json('{"a": 1, "b": "2", "c": 3.3, "d": 4.4}')
+        | when id = 1 then parse_json('{"a": [1,2,3], "b": "hello", "c": {"x": 
0}}')
+        | when id = 2 then parse_json('{"A": 1, "c": 1.23}')
+        | end v from range(3)
+        |""".stripMargin)
+    val fullSchema = "v struct<metadata binary, value binary, typed_value 
struct<" +
+      "a struct<value binary, typed_value int>, b struct<value binary, 
typed_value string>," +
+      "c struct<value binary, typed_value decimal(15, 1)>>>"
+    withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
+      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
+      df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+
+      // Verify that we can read the full variant. The exact binary layout can 
change before and
+      // after shredding, so just check that the JSON representation matches.
+      checkAnswer(
+        spark.read.parquet(dir.getAbsolutePath).selectExpr("to_json(v)"),
+        df.selectExpr("to_json(v)").collect()
+      )
+
+      // Verify that it was shredded to the expected fields.
+
+      val shreddedDf = 
spark.read.schema(fullSchema).parquet(dir.getAbsolutePath)
+      // Metadata should be unchanged.
+      checkAnswer(shreddedDf.selectExpr("v.metadata"),
+        df.collect().map(v => 
Row(v.get(0).asInstanceOf[VariantVal].getMetadata))
+      )
+
+      // Check typed values.
+      // Second row is not an integer, and third is A, not a
+      checkAnswer(
+        shreddedDf.selectExpr("v.typed_value.a.typed_value"),
+        Seq(Row(1), Row(null), Row(null)))
+      // b is missing from third row.
+      checkAnswer(
+        shreddedDf.selectExpr("v.typed_value.b.typed_value"),
+        Seq(Row("2"), Row("hello"), Row(null)))
+      // Second row is an object, third is the wrong scale. (Note: we may 
eventually allow the
+      // latter, in which case this test should be updated.)
+      checkAnswer(
+        shreddedDf.selectExpr("v.typed_value.c.typed_value"),
+        Seq(Row(3.3), Row(null), Row(null)))
+
+      // Untyped values are more awkward to check, so for now just check their 
nullness. We
+      // can do more thorough checking once the reader is ready.
+      checkAnswer(
+        shreddedDf.selectExpr("v.value is null"),
+        // First row has "d" and third has "A".
+        Seq(Row(false), Row(true), Row(false)))
+      checkAnswer(
+        shreddedDf.selectExpr("v.typed_value.a.value is null"),
+        // First row is fully shredded, third is missing.
+        Seq(Row(true), Row(false), Row(true)))
+      checkAnswer(
+        shreddedDf.selectExpr("v.typed_value.b.value is null"),
+        // b is always fully shredded or missing.
+        Seq(Row(true), Row(true), Row(true)))
+      checkAnswer(
+        shreddedDf.selectExpr("v.typed_value.c.value is null"),
+        Seq(Row(true), Row(false), Row(false)))
+      // The a/b/c levels are not null, even if the field is missing.
+      checkAnswer(
+        shreddedDf.selectExpr(
+          "v.typed_value.a is null or v.typed_value.b is null or 
v.typed_value.c is null"),
+        Seq(Row(false), Row(false), Row(false)))
+    }
+  }
+
+  testWithTempDir("write shredded variant array") { dir =>
+    val schema = "array<int>"
+    val df = spark.sql(
+      """
+        | select case
+        | when id = 0 then parse_json('[1, "2", 3.5, null, 5]')
+        | when id = 1 then parse_json('{"a": [1, 2, 3]}')
+        | when id = 2 then parse_json('1')
+        | when id = 3 then parse_json('null')
+        | end v from range(4)
+        |""".stripMargin)
+    val fullSchema = "v struct<metadata binary, value binary, typed_value 
array<" +
+      "struct<value binary, typed_value int>>>"
+    withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
+      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
+      df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+      // Verify that we can read the full variant.
+      checkAnswer(
+        spark.read.parquet(dir.getAbsolutePath).selectExpr("to_json(v)"),
+        df.selectExpr("to_json(v)").collect()
+      )
+
+      // Verify that it was shredded to the expected fields.
+
+      val shreddedDf = 
spark.read.schema(fullSchema).parquet(dir.getAbsolutePath)
+      // Metadata should be unchanged.
+      checkAnswer(shreddedDf.selectExpr("v.metadata"),
+        df.collect().map(v => 
Row(v.get(0).asInstanceOf[VariantVal].getMetadata))
+      )
+
+      // Check typed values.
+      checkAnswer(
+        shreddedDf.selectExpr("v.typed_value.typed_value"),
+        Seq(Row(Array(1, null, null, null, 5)), Row(null), Row(null), 
Row(null)))
+
+      // All the other array elements should have non-null value.
+      checkAnswer(
+        shreddedDf.selectExpr("transform(v.typed_value.value, v -> v is 
null)"),
+        Seq(Row(Array(true, false, false, false, true)), Row(null), Row(null), 
Row(null)))
+
+      // The non-arrays should have non-null top-level value.
+      checkAnswer(
+        shreddedDf.selectExpr("v.value is null"),
+        Seq(Row(true), Row(false), Row(false), Row(false)))
+    }
+  }
+
+  testWithTempDir("write no shredding schema") { dir =>
+    // Check that we can write and read normally when shredding is enabled if
+    // we don't provide a shredding schema.
+    withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString) {
+      val df = spark.sql(
+        """
+          | select parse_json('{"a": ' || id || ', "b": 2}') as v,
+          | array(parse_json('{"c": 3}'), 123::variant) as a
+          | from range(1, 3, 1, 1)
+          |""".stripMargin)
+      df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+      checkAnswer(
+        spark.read.parquet(dir.getAbsolutePath), df.collect()
+      )
+    }
+  }
+
+  testWithTempDir("arrays and maps ignore shredding schema") { dir =>
+    // Check that we don't try to shred array or map elements, even if a 
shredding schema
+    // is specified.
+    val schema = "a int"
+    val df = spark.sql(
+      """ select v, array(v) as arr, map('myKey', v) as m from
+        | (select parse_json('{"a":' || id || '}') v from range(3))
+        |""".stripMargin)
+    val fullSchema = "v struct<metadata binary, value binary, typed_value 
struct<" +
+      "a struct<value binary, typed_value int>>>, " +
+      "arr array<struct<metadata binary, value binary>>, " +
+      "m map<string, struct<metadata binary, value binary>>"
+    withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
+      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
+      df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+      // Verify that we can read the full variant.
+      checkAnswer(
+        spark.read.parquet(dir.getAbsolutePath).selectExpr("to_json(v)"),
+        df.selectExpr("to_json(v)").collect()
+      )
+
+      // Verify that it was shredded to the expected fields.
+
+      val shreddedDf = 
spark.read.schema(fullSchema).parquet(dir.getAbsolutePath)
+      // Metadata should be unchanged.
+      checkAnswer(shreddedDf.selectExpr("v.metadata"),
+        df.selectExpr("v").collect().map(v => 
Row(v.get(0).asInstanceOf[VariantVal].getMetadata))
+      )
+      checkAnswer(shreddedDf.selectExpr("arr[0].metadata"),
+        df.selectExpr("arr[0]").collect().map(v =>
+          Row(v.get(0).asInstanceOf[VariantVal].getMetadata))
+      )
+      checkAnswer(shreddedDf.selectExpr("m['myKey'].metadata"),
+        df.selectExpr("m['myKey']").collect().map(
+          v => Row(v.get(0).asInstanceOf[VariantVal].getMetadata))
+      )
+
+      // v should be fully shredded, but the array and map should not be.
+      checkAnswer(
+        shreddedDf.selectExpr(
+          "v.value is null"),
+        Seq(Row(true), Row(true), Row(true)))
+      checkAnswer(
+        shreddedDf.selectExpr(
+          "arr[0].value is null"),
+        Seq(Row(false), Row(false), Row(false)))
+      checkAnswer(
+        shreddedDf.selectExpr(
+          "m['myKey'].value is null"),
+        Seq(Row(false), Row(false), Row(false)))
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to