This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c1c4d2614ae [SPARK-50644][SQL] Read variant struct in Parquet reader
2c1c4d2614ae is described below

commit 2c1c4d2614ae1ff902c244209f7ec3c79102d3e0
Author: Chenhao Li <[email protected]>
AuthorDate: Tue Dec 24 14:54:02 2024 +0800

    [SPARK-50644][SQL] Read variant struct in Parquet reader
    
    ### What changes were proposed in this pull request?
    
    It adds support for reading variant structs in the Parquet reader. The concept
    of a variant struct was introduced in https://github.com/apache/spark/pull/49235.
    A variant struct contains all the fields extracted from a variant column that
    the query requests.
    
    ### Why are the changes needed?
    
    By producing variant structs directly in the Parquet reader, we can avoid
    reading and rebuilding the full variant, which makes variant processing more
    efficient.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49263 from chenhao-db/spark_variant_struct_reader.
    
    Authored-by: Chenhao Li <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../apache/spark/types/variant/ShreddingUtils.java |   9 +-
 .../apache/spark/types/variant/VariantSchema.java  |   6 +
 .../datasources/parquet/ParquetColumnVector.java   |  24 +-
 .../datasources/parquet/ParquetReadSupport.scala   |   9 +
 .../datasources/parquet/ParquetRowConverter.scala  |  26 +-
 .../parquet/ParquetSchemaConverter.scala           |   4 +
 .../datasources/parquet/SparkShreddingUtils.scala  | 597 ++++++++++++++++++++-
 .../apache/spark/sql/VariantShreddingSuite.scala   | 185 ++++++-
 8 files changed, 820 insertions(+), 40 deletions(-)

diff --git 
a/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java
 
b/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java
index 59e16b77ab01..6a04bf9a2b25 100644
--- 
a/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java
+++ 
b/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java
@@ -49,9 +49,8 @@ public class ShreddingUtils {
       throw malformedVariant();
     }
     byte[] metadata = row.getBinary(schema.topLevelMetadataIdx);
-    if (schema.variantIdx >= 0 && schema.typedIdx < 0) {
-      // The variant is unshredded. We are not required to do anything 
special, but we can have an
-      // optimization to avoid `rebuild`.
+    if (schema.isUnshredded()) {
+      // `rebuild` is unnecessary for unshredded variant.
       if (row.isNullAt(schema.variantIdx)) {
         throw malformedVariant();
       }
@@ -65,8 +64,8 @@ public class ShreddingUtils {
   // Rebuild a variant value from the shredded data according to the 
reconstruction algorithm in
   // https://github.com/apache/parquet-format/blob/master/VariantShredding.md.
   // Append the result to `builder`.
-  private static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema 
schema,
-                              VariantBuilder builder) {
+  public static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema 
schema,
+                             VariantBuilder builder) {
     int typedIdx = schema.typedIdx;
     int variantIdx = schema.variantIdx;
     if (typedIdx >= 0 && !row.isNullAt(typedIdx)) {
diff --git 
a/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java
 
b/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java
index 551e46214859..d1e6cc3a727f 100644
--- 
a/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java
+++ 
b/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java
@@ -138,6 +138,12 @@ public class VariantSchema {
     this.arraySchema = arraySchema;
   }
 
+  // Return whether the variant column is unshredded. The user is not required to do anything
+  // special, but can apply certain optimizations for an unshredded variant.
+  public boolean isUnshredded() {
+    return topLevelMetadataIdx >= 0 && variantIdx >= 0 && typedIdx < 0;
+  }
+
   @Override
   public String toString() {
     return "VariantSchema{" +
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
index 0b9a25fc46a0..7fb8be7caf28 100644
--- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
@@ -35,7 +35,6 @@ import org.apache.spark.sql.types.MapType;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.types.VariantType;
 import org.apache.spark.types.variant.VariantSchema;
-import org.apache.spark.unsafe.types.VariantVal;
 
 /**
  * Contains necessary information representing a Parquet column, either of 
primitive or nested type.
@@ -49,6 +48,9 @@ final class ParquetColumnVector {
   // contains only one child that reads the underlying file content. This 
`ParquetColumnVector`
   // should assemble Spark variant values from the file content.
   private VariantSchema variantSchema;
+  // Only meaningful if `variantSchema` is not null. See 
`SparkShreddingUtils.getFieldsToExtract`
+  // for its meaning.
+  private FieldToExtract[] fieldsToExtract;
 
   /**
    * Repetition & Definition levels
@@ -117,6 +119,7 @@ final class ParquetColumnVector {
           fileContent, capacity, memoryMode, missingColumns, false, null);
       children.add(contentVector);
       variantSchema = 
SparkShreddingUtils.buildVariantSchema(fileContentCol.sparkType());
+      fieldsToExtract = 
SparkShreddingUtils.getFieldsToExtract(column.sparkType(), variantSchema);
       repetitionLevels = contentVector.repetitionLevels;
       definitionLevels = contentVector.definitionLevels;
     } else if (isPrimitive) {
@@ -188,20 +191,11 @@ final class ParquetColumnVector {
     if (variantSchema != null) {
       children.get(0).assemble();
       WritableColumnVector fileContent = children.get(0).getValueVector();
-      int numRows = fileContent.getElementsAppended();
-      vector.reset();
-      vector.reserve(numRows);
-      WritableColumnVector valueChild = vector.getChild(0);
-      WritableColumnVector metadataChild = vector.getChild(1);
-      for (int i = 0; i < numRows; ++i) {
-        if (fileContent.isNullAt(i)) {
-          vector.appendStruct(true);
-        } else {
-          vector.appendStruct(false);
-          VariantVal v = SparkShreddingUtils.rebuild(fileContent.getStruct(i), 
variantSchema);
-          valueChild.appendByteArray(v.getValue(), 0, v.getValue().length);
-          metadataChild.appendByteArray(v.getMetadata(), 0, 
v.getMetadata().length);
-        }
+      if (fieldsToExtract == null) {
+        SparkShreddingUtils.assembleVariantBatch(fileContent, vector, 
variantSchema);
+      } else {
+        SparkShreddingUtils.assembleVariantStructBatch(fileContent, vector, 
variantSchema,
+            fieldsToExtract);
       }
       return;
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
index 8dde02a4673f..af0bf0d51f07 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
@@ -35,6 +35,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.VariantMetadata
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
 
@@ -221,6 +222,9 @@ object ParquetReadSupport extends Logging {
         clipParquetMapType(
           parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive, 
useFieldId)
 
+      case t: StructType if VariantMetadata.isVariantStruct(t) =>
+        clipVariantSchema(parquetType.asGroupType(), t)
+
       case t: StructType =>
         clipParquetGroup(parquetType.asGroupType(), t, caseSensitive, 
useFieldId)
 
@@ -390,6 +394,11 @@ object ParquetReadSupport extends Logging {
       .named(parquetRecord.getName)
   }
 
+  private def clipVariantSchema(parquetType: GroupType, variantStruct: 
StructType): GroupType = {
+    // TODO(SHREDDING): clip `parquetType` to retain the necessary columns.
+    parquetType
+  }
+
   /**
    * Clips a Parquet [[GroupType]] which corresponds to a Catalyst 
[[StructType]].
    *
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 3ed7fe37ccd9..550c2af43a70 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -40,7 +40,7 @@ import 
org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.execution.datasources.{DataSourceUtils, 
VariantMetadata}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
@@ -498,6 +498,9 @@ private[parquet] class ParquetRowConverter(
       case t: MapType =>
         new ParquetMapConverter(parquetType.asGroupType(), t, updater)
 
+      case t: StructType if VariantMetadata.isVariantStruct(t) =>
+        new ParquetVariantConverter(t, parquetType.asGroupType(), updater)
+
       case t: StructType =>
         val wrappedUpdater = {
           // SPARK-30338: avoid unnecessary InternalRow copying for nested 
structs:
@@ -536,12 +539,7 @@ private[parquet] class ParquetRowConverter(
 
       case t: VariantType =>
         if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) {
-          // Infer a Spark type from `parquetType`. This piece of code is 
copied from
-          // `ParquetArrayConverter`.
-          val messageType = 
Types.buildMessage().addField(parquetType).named("foo")
-          val column = new ColumnIOFactory().getColumnIO(messageType)
-          val parquetSparkType = 
schemaConverter.convertField(column.getChild(0)).sparkType
-          new ParquetVariantConverter(parquetType.asGroupType(), 
parquetSparkType, updater)
+          new ParquetVariantConverter(t, parquetType.asGroupType(), updater)
         } else {
           new ParquetUnshreddedVariantConverter(parquetType.asGroupType(), 
updater)
         }
@@ -909,13 +907,14 @@ private[parquet] class ParquetRowConverter(
 
   /** Parquet converter for Variant (shredded or unshredded) */
   private final class ParquetVariantConverter(
-     parquetType: GroupType,
-     parquetSparkType: DataType,
-     updater: ParentContainerUpdater)
+     targetType: DataType, parquetType: GroupType, updater: 
ParentContainerUpdater)
     extends ParquetGroupConverter(updater) {
 
     private[this] var currentRow: Any = _
+    private[this] val parquetSparkType = 
SparkShreddingUtils.parquetTypeToSparkType(parquetType)
     private[this] val variantSchema = 
SparkShreddingUtils.buildVariantSchema(parquetSparkType)
+    private[this] val fieldsToExtract =
+      SparkShreddingUtils.getFieldsToExtract(targetType, variantSchema)
     // A struct converter that reads the underlying file data.
     private[this] val fileConverter = new ParquetRowConverter(
       schemaConverter,
@@ -932,7 +931,12 @@ private[parquet] class ParquetRowConverter(
 
     override def end(): Unit = {
       fileConverter.end()
-      val v = 
SparkShreddingUtils.rebuild(currentRow.asInstanceOf[InternalRow], variantSchema)
+      val row = currentRow.asInstanceOf[InternalRow]
+      val v = if (fieldsToExtract == null) {
+        SparkShreddingUtils.assembleVariant(row, variantSchema)
+      } else {
+        SparkShreddingUtils.assembleVariantStruct(row, variantSchema, 
fieldsToExtract)
+      }
       updater.set(v)
     }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 7f1b49e73790..64c2a3126ca9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -28,6 +28,7 @@ import org.apache.parquet.schema.Type.Repetition._
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.VariantMetadata
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -185,6 +186,9 @@ class ParquetToSparkSchemaConverter(
         } else {
           convertVariantField(groupColumn)
         }
+      case groupColumn: GroupColumnIO if 
targetType.exists(VariantMetadata.isVariantStruct) =>
+        val col = convertGroupField(groupColumn)
+        col.copy(sparkType = targetType.get, variantFileType = Some(col))
       case groupColumn: GroupColumnIO => convertGroupField(groupColumn, 
targetType)
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index f38e188ed042..a83ca78455fa 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -17,12 +17,23 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
+import org.apache.parquet.io.ColumnIOFactory
+import org.apache.parquet.schema.{Type => ParquetType, Types => ParquetTypes}
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.expressions.variant._
+import 
org.apache.spark.sql.catalyst.expressions.variant.VariantPathParser.PathSegment
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, 
DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
+import org.apache.spark.sql.execution.RowToColumnConverter
+import org.apache.spark.sql.execution.datasources.VariantMetadata
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
 import org.apache.spark.sql.types._
 import org.apache.spark.types.variant._
+import org.apache.spark.types.variant.VariantUtil.Type
 import org.apache.spark.unsafe.types._
 
 case class SparkShreddedRow(row: SpecializedGetters) extends 
ShreddingUtils.ShreddedRow {
@@ -45,6 +56,369 @@ case class SparkShreddedRow(row: SpecializedGetters) 
extends ShreddingUtils.Shre
   override def numElements(): Int = row.asInstanceOf[ArrayData].numElements()
 }
 
+// The search result of a `PathSegment` in a `VariantSchema`.
+case class SchemaPathSegment(
+    rawPath: PathSegment,
+    // Whether this path segment is an object or array extraction.
+    isObject: Boolean,
+    // `schema.typedIdx`, if the path exists in the schema (for object 
extraction, the schema
+    // should contain an object `typed_value` containing the requested field; 
similar for array
+    // extraction). Negative otherwise.
+    typedIdx: Int,
+    // For object extraction, it is the index of the desired field in 
`schema.objectSchema`. If the
+    // requested field doesn't exist, both `extractionIdx/typedIdx` are set to 
negative.
+    // For array extraction, it is the array index. The information is already 
stored in `rawPath`,
+    // but accessing a raw int should be more efficient than `rawPath`, which 
is an `Either`.
+    extractionIdx: Int)
+
+// Represents a single field in a variant struct (see `VariantMetadata` for the definition), that
+// is, a single requested field that the scan should produce by extracting it from the variant column.
+case class FieldToExtract(path: Array[SchemaPathSegment], reader: 
ParquetVariantReader)
+
+// A helper class to cast from a scalar `typed_value` into a scalar `dataType`. A custom expression
+// is needed because its error reporting differs from that of `Cast`.
+case class ScalarCastHelper(
+    child: Expression,
+    dataType: DataType,
+    castArgs: VariantCastArgs) extends UnaryExpression {
+  // The expression is only for the internal use of `ScalarReader`, which can 
guarantee the child
+  // is not nullable.
+  assert(!child.nullable)
+
+  // If `cast` is null, it means the cast always fails because the type 
combination is not allowed.
+  private val cast = if (Cast.canAnsiCast(child.dataType, dataType)) {
+    Cast(child, dataType, castArgs.zoneStr, EvalMode.TRY)
+  } else {
+    null
+  }
+  // Cast the input to string. Only used for reporting an invalid cast.
+  private val castToString = Cast(child, StringType, castArgs.zoneStr, 
EvalMode.ANSI)
+
+  override def nullable: Boolean = !castArgs.failOnError
+  override def withNewChildInternal(newChild: Expression): UnaryExpression = 
copy(child = newChild)
+
+  // No need to define the interpreted version of `eval`: the codegen must 
succeed.
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
+    // Throw an error or do nothing, depending on `castArgs.failOnError`.
+    val invalidCastCode = if (castArgs.failOnError) {
+      val castToStringCode = castToString.genCode(ctx)
+      val typeObj = ctx.addReferenceObj("dataType", dataType)
+      val cls = classOf[ScalarCastHelper].getName
+      s"""
+        ${castToStringCode.code}
+        $cls.throwInvalidVariantCast(${castToStringCode.value}, $typeObj);
+      """
+    } else {
+      ""
+    }
+    if (cast != null) {
+      val castCode = cast.genCode(ctx)
+      val code = code"""
+        ${castCode.code}
+        boolean ${ev.isNull} = ${castCode.isNull};
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${castCode.value};
+        if (${ev.isNull}) { $invalidCastCode }
+      """
+      ev.copy(code = code)
+    } else {
+      val code = code"""
+        boolean ${ev.isNull} = true;
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
+        if (${ev.isNull}) { $invalidCastCode }
+      """
+      ev.copy(code = code)
+    }
+  }
+}
+
+object ScalarCastHelper {
+  // A helper function for codegen. The java compiler doesn't allow throwing a 
`Throwable` in a
+  // method without `throws` annotation.
+  def throwInvalidVariantCast(value: UTF8String, dataType: DataType): Any =
+    throw QueryExecutionErrors.invalidVariantCast(value.toString, dataType)
+}
+
+// The base class to read Parquet variant values into a Spark type.
+// For convenience, we also allow creating an instance of the base class 
itself. None of its
+// functions can be used, but it can serve as a container of `targetType` and 
`castArgs`.
+class ParquetVariantReader(
+    val schema: VariantSchema, val targetType: DataType, val castArgs: 
VariantCastArgs) {
+  // Read from a row containing a Parquet variant value (shredded or 
unshredded) and return a value
+  // of `targetType`. The row schema is described by `schema`.
+  // This function throws MALFORMED_VARIANT if the variant is missing. If the 
variant can be
+  // legally missing (the only possible situation is struct fields in object 
`typed_value`), the
+  // caller should check for it and avoid calling this function if the variant 
is missing.
+  def read(row: InternalRow, topLevelMetadata: Array[Byte]): Any = {
+    if (schema.typedIdx < 0 || row.isNullAt(schema.typedIdx)) {
+      if (schema.variantIdx < 0 || row.isNullAt(schema.variantIdx)) {
+        // Both `typed_value` and `value` are null, meaning the variant is 
missing.
+        throw QueryExecutionErrors.malformedVariant()
+      }
+      val v = new Variant(row.getBinary(schema.variantIdx), topLevelMetadata)
+      VariantGet.cast(v, targetType, castArgs)
+    } else {
+      readFromTyped(row, topLevelMetadata)
+    }
+  }
+
+  // Subclasses should override it to produce the read result when 
`typed_value` is not null.
+  protected def readFromTyped(row: InternalRow, topLevelMetadata: 
Array[Byte]): Any =
+    throw QueryExecutionErrors.unreachableError()
+
+  // A util function to rebuild the variant in binary format from a Parquet 
variant value.
+  protected final def rebuildVariant(row: InternalRow, topLevelMetadata: 
Array[Byte]): Variant = {
+    val builder = new VariantBuilder(false)
+    ShreddingUtils.rebuild(SparkShreddedRow(row), topLevelMetadata, schema, 
builder)
+    builder.result()
+  }
+
+  // A util function to throw error or return null when an invalid cast 
happens.
+  protected final def invalidCast(row: InternalRow, topLevelMetadata: 
Array[Byte]): Any = {
+    if (castArgs.failOnError) {
+      throw QueryExecutionErrors.invalidVariantCast(
+        rebuildVariant(row, topLevelMetadata).toJson(castArgs.zoneId), 
targetType)
+    } else {
+      null
+    }
+  }
+}
+
+object ParquetVariantReader {
+  // Create a reader for `targetType`. If `schema` is null, meaning that the 
extraction path doesn't
+  // exist in `typed_value`, it returns an instance of `ParquetVariantReader`. 
As described in the
+  // class comment, the reader is only a container of `targetType` and 
`castArgs` in this case.
+  def apply(schema: VariantSchema, targetType: DataType, castArgs: 
VariantCastArgs,
+            isTopLevelUnshredded: Boolean = false): ParquetVariantReader = 
targetType match {
+    case _ if schema == null => new ParquetVariantReader(schema, targetType, 
castArgs)
+    case s: StructType => new StructReader(schema, s, castArgs)
+    case a: ArrayType => new ArrayReader(schema, a, castArgs)
+    case m@MapType(_: StringType, _, _) => new MapReader(schema, m, castArgs)
+    case v: VariantType => new VariantReader(schema, v, castArgs, 
isTopLevelUnshredded)
+    case s: AtomicType => new ScalarReader(schema, s, castArgs)
+    case _ =>
+      // Type check should have rejected map with non-string type.
+      throw QueryExecutionErrors.unreachableError(s"Invalid target type: 
`${targetType.sql}`")
+  }
+}
+
+// Read Parquet variant values into a Spark struct type. It reads unshredded 
fields (fields that are
+// not in the typed object) from the `value`, and reads the shredded fields 
from the object
+// `typed_value`.
+// `value` must not contain any shredded field according to the shredding 
spec, but this requirement
+// is not enforced. If `value` does contain a shredded field, no error will 
occur, and the field in
+// object `typed_value` will be the final result.
+private[this] final class StructReader(
+  schema: VariantSchema, targetType: StructType, castArgs: VariantCastArgs)
+  extends ParquetVariantReader(schema, targetType, castArgs) {
+  // For each field in `targetType`, store the index of the field with the 
same name in object
+  // `typed_value`, or -1 if it doesn't exist in object `typed_value`.
+  private[this] val fieldInputIndices: Array[Int] = targetType.fields.map { f 
=>
+    val inputIdx = if (schema.objectSchemaMap != null) 
schema.objectSchemaMap.get(f.name) else null
+    if (inputIdx != null) inputIdx.intValue() else -1
+  }
+  // For each field in `targetType`, store the reader from the corresponding 
field in object
+  // `typed_value`, or null if it doesn't exist in object `typed_value`.
+  private[this] val fieldReaders: Array[ParquetVariantReader] =
+    targetType.fields.zip(fieldInputIndices).map { case (f, inputIdx) =>
+      if (inputIdx >= 0) {
+        val fieldSchema = schema.objectSchema(inputIdx).schema
+        ParquetVariantReader(fieldSchema, f.dataType, castArgs)
+      } else {
+        null
+      }
+    }
+  // If all fields in `targetType` can be found in object `typed_value`, then 
the reader doesn't
+  // need to read from `value`.
+  private[this] val needUnshreddedObject: Boolean = fieldInputIndices.exists(_ 
< 0)
+
+  override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): 
Any = {
+    if (schema.objectSchema == null) return invalidCast(row, topLevelMetadata)
+    val obj = row.getStruct(schema.typedIdx, schema.objectSchema.length)
+    val result = new GenericInternalRow(fieldInputIndices.length)
+    var unshreddedObject: Variant = null
+    if (needUnshreddedObject && schema.variantIdx >= 0 && 
!row.isNullAt(schema.variantIdx)) {
+      unshreddedObject = new Variant(row.getBinary(schema.variantIdx), 
topLevelMetadata)
+      if (unshreddedObject.getType != Type.OBJECT) throw 
QueryExecutionErrors.malformedVariant()
+    }
+    val numFields = fieldInputIndices.length
+    var i = 0
+    while (i < numFields) {
+      val inputIdx = fieldInputIndices(i)
+      if (inputIdx >= 0) {
+        // Shredded field must not be null.
+        if (obj.isNullAt(inputIdx)) throw 
QueryExecutionErrors.malformedVariant()
+        val fieldSchema = schema.objectSchema(inputIdx).schema
+        val fieldInput = obj.getStruct(inputIdx, fieldSchema.numFields)
+        // Only read from the shredded field if it is not missing.
+        if ((fieldSchema.typedIdx >= 0 && 
!fieldInput.isNullAt(fieldSchema.typedIdx)) ||
+          (fieldSchema.variantIdx >= 0 && 
!fieldInput.isNullAt(fieldSchema.variantIdx))) {
+          result.update(i, fieldReaders(i).read(fieldInput, topLevelMetadata))
+        }
+      } else if (unshreddedObject != null) {
+        val fieldName = targetType.fields(i).name
+        val fieldType = targetType.fields(i).dataType
+        val unshreddedField = unshreddedObject.getFieldByKey(fieldName)
+        if (unshreddedField != null) {
+          result.update(i, VariantGet.cast(unshreddedField, fieldType, 
castArgs))
+        }
+      }
+      i += 1
+    }
+    result
+  }
+}
+
+// Read Parquet variant values into a Spark array type.
+private[this] final class ArrayReader(
+    schema: VariantSchema, targetType: ArrayType, castArgs: VariantCastArgs)
+  extends ParquetVariantReader(schema, targetType, castArgs) {
+  private[this] val elementReader = if (schema.arraySchema != null) {
+    ParquetVariantReader(schema.arraySchema, targetType.elementType, castArgs)
+  } else {
+    null
+  }
+
+  override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): 
Any = {
+    if (schema.arraySchema == null) return invalidCast(row, topLevelMetadata)
+    val elementNumFields = schema.arraySchema.numFields
+    val arr = row.getArray(schema.typedIdx)
+    val size = arr.numElements()
+    val result = new Array[Any](size)
+    var i = 0
+    while (i < size) {
+      // Shredded array element must not be null.
+      if (arr.isNullAt(i)) throw QueryExecutionErrors.malformedVariant()
+      result(i) = elementReader.read(arr.getStruct(i, elementNumFields), 
topLevelMetadata)
+      i += 1
+    }
+    new GenericArrayData(result)
+  }
+}
+
+// Read Parquet variant values into a Spark map type with string key type. The 
input must be object
+// for a valid cast. The resulting map contains shredded fields from object 
`typed_value` and
+// unshredded fields from object `value`.
+// `value` must not contain any shredded field according to the shredding 
spec. Unlike
+// `StructReader`, this requirement is enforced in `MapReader`. If `value` 
does contain a shredded
+// field, throw a MALFORMED_VARIANT error. The purpose is to avoid duplicate 
map keys.
+private[this] final class MapReader(
+    schema: VariantSchema, targetType: MapType, castArgs: VariantCastArgs)
+  extends ParquetVariantReader(schema, targetType, castArgs) {
+  // Readers that convert each shredded field into the map value type.
+  private[this] val valueReaders = if (schema.objectSchema != null) {
+    schema.objectSchema.map { f =>
+      ParquetVariantReader(f.schema, targetType.valueType, castArgs)
+    }
+  } else {
+    null
+  }
+  // `UTF8String` representation of shredded field names. Do the `String -> 
UTF8String` once, so
+  // that `readFromTyped` doesn't need to do it repeatedly.
+  private[this] val shreddedFieldNames = if (schema.objectSchema != null) {
+    schema.objectSchema.map { f => UTF8String.fromString(f.fieldName) }
+  } else {
+    null
+  }
+
+  override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): 
Any = {
+    if (schema.objectSchema == null) return invalidCast(row, topLevelMetadata)
+    val obj = row.getStruct(schema.typedIdx, schema.objectSchema.length)
+    val numShreddedFields = valueReaders.length
+    var unshreddedObject: Variant = null
+    if (schema.variantIdx >= 0 && !row.isNullAt(schema.variantIdx)) {
+      unshreddedObject = new Variant(row.getBinary(schema.variantIdx), 
topLevelMetadata)
+      if (unshreddedObject.getType != Type.OBJECT) throw 
QueryExecutionErrors.malformedVariant()
+    }
+    val numUnshreddedFields = if (unshreddedObject != null) 
unshreddedObject.objectSize() else 0
+    var keyArray = new Array[UTF8String](numShreddedFields + 
numUnshreddedFields)
+    var valueArray = new Array[Any](numShreddedFields + numUnshreddedFields)
+    var mapLength = 0
+    var i = 0
+    while (i < numShreddedFields) {
+      // Shredded field must not be null.
+      if (obj.isNullAt(i)) throw QueryExecutionErrors.malformedVariant()
+      val fieldSchema = schema.objectSchema(i).schema
+      val fieldInput = obj.getStruct(i, fieldSchema.numFields)
+      // Only add the shredded field to map if it is not missing.
+      if ((fieldSchema.typedIdx >= 0 && 
!fieldInput.isNullAt(fieldSchema.typedIdx)) ||
+        (fieldSchema.variantIdx >= 0 && 
!fieldInput.isNullAt(fieldSchema.variantIdx))) {
+        keyArray(mapLength) = shreddedFieldNames(i)
+        valueArray(mapLength) = valueReaders(i).read(fieldInput, 
topLevelMetadata)
+        mapLength += 1
+      }
+      i += 1
+    }
+    i = 0
+    while (i < numUnshreddedFields) {
+      val field = unshreddedObject.getFieldAtIndex(i)
+      if (schema.objectSchemaMap.containsKey(field.key)) {
+        throw QueryExecutionErrors.malformedVariant()
+      }
+      keyArray(mapLength) = UTF8String.fromString(field.key)
+      valueArray(mapLength) = VariantGet.cast(field.value, 
targetType.valueType, castArgs)
+      mapLength += 1
+      i += 1
+    }
+    // Need to shrink the arrays if there are missing shredded fields.
+    if (mapLength < keyArray.length) {
+      keyArray = keyArray.slice(0, mapLength)
+      valueArray = valueArray.slice(0, mapLength)
+    }
+    ArrayBasedMapData(keyArray, valueArray)
+  }
+}
+
+// Read Parquet variant values into a Spark variant type (the binary format).
+private[this] final class VariantReader(
+    schema: VariantSchema, targetType: DataType, castArgs: VariantCastArgs,
+    // An optional optimization: the user can set it to true if the Parquet variant column is
+    // unshredded and the extraction path is empty. We are not required to do anything special, but
+    // we can avoid rebuilding the variant for optimization purposes.
+    private[this] val isTopLevelUnshredded: Boolean)
+  extends ParquetVariantReader(schema, targetType, castArgs) {
+  override def read(row: InternalRow, topLevelMetadata: Array[Byte]): Any = {
+    if (isTopLevelUnshredded) {
+      if (row.isNullAt(schema.variantIdx)) throw 
QueryExecutionErrors.malformedVariant()
+      return new VariantVal(row.getBinary(schema.variantIdx), topLevelMetadata)
+    }
+    val v = rebuildVariant(row, topLevelMetadata)
+    new VariantVal(v.getValue, v.getMetadata)
+  }
+}
+
+// Read Parquet variant values into a Spark scalar type. When `typed_value` is not null but not a
+// scalar, every target type except string yields an invalid cast; the string target type can
+// still build a string from an array/object `typed_value`. For a scalar `typed_value`, the reader
+// relies on `ScalarCastHelper` to perform the cast.
+// According to the shredding spec, scalar `typed_value` and `value` must not 
be non-null at the
+// same time. The requirement is not enforced in this reader. If they are both 
non-null, no error
+// will occur, and the reader will read from `typed_value`.
+private[this] final class ScalarReader(
+    schema: VariantSchema, targetType: DataType, castArgs: VariantCastArgs)
+  extends ParquetVariantReader(schema, targetType, castArgs) {
+  private[this] val castProject = if (schema.scalarSchema != null) {
+    val scalarType = 
SparkShreddingUtils.scalarSchemaToSparkType(schema.scalarSchema)
+    // Read the cast input from ordinal `schema.typedIdx` in the input row. 
The cast input is never
+    // null, because `readFromTyped` is only called when `typed_value` is not 
null.
+    val input = BoundReference(schema.typedIdx, scalarType, nullable = false)
+    MutableProjection.create(Seq(ScalarCastHelper(input, targetType, 
castArgs)))
+  } else {
+    null
+  }
+
+  override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): 
Any = {
+    if (castProject == null) {
+      return if (targetType.isInstanceOf[StringType]) {
+        UTF8String.fromString(rebuildVariant(row, 
topLevelMetadata).toJson(castArgs.zoneId))
+      } else {
+        invalidCast(row, topLevelMetadata)
+      }
+    }
+    val result = castProject(row)
+    if (result.isNullAt(0)) null else result.get(0, targetType)
+  }
+}
+
 case object SparkShreddingUtils {
   val VariantValueFieldName = "value";
   val TypedValueFieldName = "typed_value";
@@ -126,6 +500,11 @@ case object SparkShreddingUtils {
     var objectSchema: Array[VariantSchema.ObjectField] = null
     var arraySchema: VariantSchema = null
 
+    // The struct must not be empty or contain duplicate field names. The 
latter is enforced in the
+    // loop below (`if (typedIdx != -1)` and other similar checks).
+    if (schema.fields.isEmpty) {
+      throw QueryCompilationErrors.invalidVariantShreddingSchema(schema)
+    }
     schema.fields.zipWithIndex.foreach { case (f, i) =>
       f.name match {
         case TypedValueFieldName =>
@@ -135,8 +514,11 @@ case object SparkShreddingUtils {
           typedIdx = i
           f.dataType match {
             case StructType(fields) =>
-              objectSchema =
-                  new Array[VariantSchema.ObjectField](fields.length)
+              // The struct must not be empty or contain duplicate field names.
+              if (fields.isEmpty || fields.map(_.name).distinct.length != 
fields.length) {
+                throw 
QueryCompilationErrors.invalidVariantShreddingSchema(schema)
+              }
+              objectSchema = new 
Array[VariantSchema.ObjectField](fields.length)
               fields.zipWithIndex.foreach { case (field, fieldIdx) =>
                 field.dataType match {
                   case s: StructType =>
@@ -188,6 +570,32 @@ case object SparkShreddingUtils {
       scalarSchema, objectSchema, arraySchema)
   }
 
+  // Convert a scalar variant shredding schema into the corresponding Spark scalar type.
+  // Integral schemas map to the Spark integral type of the same width; decimals keep their
+  // precision and scale.
+  def scalarSchemaToSparkType(scalar: VariantSchema.ScalarType): DataType = scalar match {
+    case _: VariantSchema.StringType => StringType
+    case it: VariantSchema.IntegralType => it.size match {
+      case VariantSchema.IntegralSize.BYTE => ByteType
+      case VariantSchema.IntegralSize.SHORT => ShortType
+      case VariantSchema.IntegralSize.INT => IntegerType
+      case VariantSchema.IntegralSize.LONG => LongType
+    }
+    case _: VariantSchema.FloatType => FloatType
+    case _: VariantSchema.DoubleType => DoubleType
+    case _: VariantSchema.BooleanType => BooleanType
+    case _: VariantSchema.BinaryType => BinaryType
+    case dt: VariantSchema.DecimalType => DecimalType(dt.precision, dt.scale)
+    case _: VariantSchema.DateType => DateType
+    case _: VariantSchema.TimestampType => TimestampType
+    case _: VariantSchema.TimestampNTZType => TimestampNTZType
+  }
+
+  // Convert a Parquet type into a Spark data type.
+  // The type is wrapped as the only field of a throwaway message (named "foo"), so that the
+  // existing `ParquetToSparkSchemaConverter` can convert it through its column IO.
+  def parquetTypeToSparkType(parquetType: ParquetType): DataType = {
+    val messageType = ParquetTypes.buildMessage().addField(parquetType).named("foo")
+    val column = new ColumnIOFactory().getColumnIO(messageType)
+    new ParquetToSparkSchemaConverter().convertField(column.getChild(0)).sparkType
+  }
+
   class SparkShreddedResult(schema: VariantSchema) extends 
VariantShreddingWriter.ShreddedResult {
     // Result is stored as an InternalRow.
     val row = new GenericInternalRow(schema.numFields)
@@ -243,8 +651,187 @@ case object SparkShreddingUtils {
         .row
   }
 
-  def rebuild(row: InternalRow, schema: VariantSchema): VariantVal = {
+  // Return a list of fields to extract. `targetType` must be either variant or variant struct.
+  // If it is variant, return null because the target is the full variant and there is no field to
+  // extract. If it is variant struct, return a list of fields matching the variant struct fields.
+  // Any other target type is rejected with `QueryExecutionErrors.unreachableError`.
+  def getFieldsToExtract(targetType: DataType, inputSchema: VariantSchema): Array[FieldToExtract] =
+    targetType match {
+      case _: VariantType => null
+      case s: StructType if VariantMetadata.isVariantStruct(s) =>
+        s.fields.map { f =>
+          val metadata = VariantMetadata.fromMetadata(f.metadata)
+          val rawPath = metadata.parsedPath()
+          val schemaPath = new Array[SchemaPathSegment](rawPath.length)
+          var schema = inputSchema
+          // Search `rawPath` in `schema` to produce `schemaPath`. If a raw path segment cannot be
+          // found at a certain level of the file type, then `typedIdx` will be -1 starting from
+          // this position, and the final `schema` will be null.
+          for (i <- rawPath.indices) {
+            val isObject = rawPath(i).isLeft
+            var typedIdx = -1
+            var extractionIdx = -1
+            rawPath(i) match {
+              case scala.util.Left(key) if schema != null && schema.objectSchema != null =>
+                val fieldIdx = schema.objectSchemaMap.get(key)
+                if (fieldIdx != null) {
+                  typedIdx = schema.typedIdx
+                  extractionIdx = fieldIdx
+                  schema = schema.objectSchema(fieldIdx).schema
+                } else {
+                  schema = null
+                }
+              case scala.util.Right(index) if schema != null && schema.arraySchema != null =>
+                typedIdx = schema.typedIdx
+                extractionIdx = index
+                schema = schema.arraySchema
+              case _ =>
+                schema = null
+            }
+            schemaPath(i) = SchemaPathSegment(rawPath(i), isObject, typedIdx, extractionIdx)
+          }
+          // `isTopLevelUnshredded` is true only for an empty path over an unshredded input schema.
+          val reader = ParquetVariantReader(schema, f.dataType, VariantCastArgs(
+            metadata.failOnError,
+            Some(metadata.timeZoneId),
+            DateTimeUtils.getZoneId(metadata.timeZoneId)),
+            isTopLevelUnshredded = schemaPath.isEmpty && inputSchema.isUnshredded)
+          FieldToExtract(schemaPath, reader)
+        }
+      case _ =>
+        throw QueryExecutionErrors.unreachableError(s"Invalid target type: `${targetType.sql}`")
+    }
+
+  // Extract a single variant struct field from a Parquet variant value. It steps into `inputRow`
+  // according to the variant extraction path, and reads the extracted value as the target type.
+  // Returns null when the path does not exist in the value. Throws
+  // `QueryExecutionErrors.malformedVariant()` when a shredded object field or array element on the
+  // path is null, which the shredding spec does not allow.
+  private def extractField(
+      inputRow: InternalRow,
+      topLevelMetadata: Array[Byte],
+      inputSchema: VariantSchema,
+      pathList: Array[SchemaPathSegment],
+      reader: ParquetVariantReader): Any = {
+    var pathIdx = 0
+    val pathLen = pathList.length
+    var row = inputRow
+    var schema = inputSchema
+    while (pathIdx < pathLen) {
+      val path = pathList(pathIdx)
+
+      if (path.typedIdx < 0) {
+        // The extraction doesn't exist in `typed_value`. Try to extract the remaining part of the
+        // path in `value`.
+        val variantIdx = schema.variantIdx
+        if (variantIdx < 0 || row.isNullAt(variantIdx)) return null
+        var v = new Variant(row.getBinary(variantIdx), topLevelMetadata)
+        while (pathIdx < pathLen) {
+          v = pathList(pathIdx).rawPath match {
+            case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
+            case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
+            case _ => null
+          }
+          if (v == null) return null
+          pathIdx += 1
+        }
+        return VariantGet.cast(v, reader.targetType, reader.castArgs)
+      }
+
+      if (row.isNullAt(path.typedIdx)) return null
+      if (path.isObject) {
+        val obj = row.getStruct(path.typedIdx, schema.objectSchema.length)
+        // Object field must not be null.
+        if (obj.isNullAt(path.extractionIdx)) throw QueryExecutionErrors.malformedVariant()
+        schema = schema.objectSchema(path.extractionIdx).schema
+        row = obj.getStruct(path.extractionIdx, schema.numFields)
+        // Return null if the field is missing (both `typed_value` and `value` are absent/null).
+        if ((schema.typedIdx < 0 || row.isNullAt(schema.typedIdx)) &&
+          (schema.variantIdx < 0 || row.isNullAt(schema.variantIdx))) {
+          return null
+        }
+      } else {
+        val arr = row.getArray(path.typedIdx)
+        // Return null if the extraction index is out of bound. A negative index is treated as out
+        // of bound too, instead of being passed to `arr.isNullAt` / `arr.getStruct`.
+        if (path.extractionIdx < 0 || path.extractionIdx >= arr.numElements()) return null
+        // Array element must not be null.
+        if (arr.isNullAt(path.extractionIdx)) throw QueryExecutionErrors.malformedVariant()
+        schema = schema.arraySchema
+        row = arr.getStruct(path.extractionIdx, schema.numFields)
+      }
+      pathIdx += 1
+    }
+    reader.read(row, topLevelMetadata)
+  }
+
+  // Assemble a variant (binary format) from a Parquet variant value.
+  // `row` is the Parquet variant value interpreted with `schema`; the complete variant (value and
+  // metadata binaries) is rebuilt through `ShreddingUtils.rebuild`.
+  def assembleVariant(row: InternalRow, schema: VariantSchema): VariantVal = {
     val v = ShreddingUtils.rebuild(SparkShreddedRow(row), schema)
     new VariantVal(v.getValue, v.getMetadata)
   }
+
+  // Assemble a variant struct, in which each field is extracted from the Parquet variant value.
+  // The result row has one column per entry of `fields`, in the same order.
+  // Throws `QueryExecutionErrors.malformedVariant()` if the top-level metadata is null.
+  def assembleVariantStruct(
+      inputRow: InternalRow,
+      schema: VariantSchema,
+      fields: Array[FieldToExtract]): InternalRow = {
+    if (inputRow.isNullAt(schema.topLevelMetadataIdx)) {
+      throw QueryExecutionErrors.malformedVariant()
+    }
+    val topLevelMetadata = inputRow.getBinary(schema.topLevelMetadataIdx)
+    val numFields = fields.length
+    val resultRow = new GenericInternalRow(numFields)
+    var fieldIdx = 0
+    while (fieldIdx < numFields) {
+      val field = fields(fieldIdx)
+      resultRow.update(fieldIdx,
+        extractField(inputRow, topLevelMetadata, schema, field.path, field.reader))
+      fieldIdx += 1
+    }
+    resultRow
+  }
+
+  // Assemble a batch of variants (binary format) from a batch of Parquet variant values.
+  // `output` is reset and refilled with one struct per input row. Null input rows become null
+  // structs; other rows append their rebuilt value and metadata to child columns 0 and 1.
+  def assembleVariantBatch(
+      input: WritableColumnVector,
+      output: WritableColumnVector,
+      schema: VariantSchema): Unit = {
+    val numRows = input.getElementsAppended
+    output.reset()
+    output.reserve(numRows)
+    val valueChild = output.getChild(0)
+    val metadataChild = output.getChild(1)
+    var i = 0
+    while (i < numRows) {
+      if (input.isNullAt(i)) {
+        output.appendStruct(true)
+      } else {
+        output.appendStruct(false)
+        val v = assembleVariant(input.getStruct(i), schema)
+        val value = v.getValue
+        val metadata = v.getMetadata
+        valueChild.appendByteArray(value, 0, value.length)
+        metadataChild.appendByteArray(metadata, 0, metadata.length)
+      }
+      i += 1
+    }
+  }
+
+  // Assemble a batch of variant structs from a batch of Parquet variant values.
+  // `output` is reset and refilled with one row per input row; null input rows produce null rows.
+  def assembleVariantStructBatch(
+      input: WritableColumnVector,
+      output: WritableColumnVector,
+      schema: VariantSchema,
+      fields: Array[FieldToExtract]): Unit = {
+    val numRows = input.getElementsAppended
+    output.reset()
+    output.reserve(numRows)
+    // `RowToColumnConverter` converts whole rows, so treat `output` as the only column of a
+    // single-field row schema and feed it one single-field row per input value.
+    val converter = new RowToColumnConverter(StructType(Array(StructField("", output.dataType()))))
+    val converterVectors = Array(output)
+    val converterRow = new GenericInternalRow(1)
+    var i = 0
+    while (i < numRows) {
+      if (input.isNullAt(i)) {
+        converterRow.update(0, null)
+      } else {
+        converterRow.update(0, assembleVariantStruct(input.getStruct(i), schema, fields))
+      }
+      converter.convert(converterRow, converterVectors)
+      i += 1
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
index 5d5c44105255..b6623bb57a71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
@@ -22,13 +22,21 @@ import java.sql.{Date, Timestamp}
 import java.time.LocalDateTime
 
 import org.apache.spark.SparkThrowable
+import org.apache.spark.sql.catalyst.InternalRow
+import 
org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetTest, 
SparkShreddingUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.types.variant._
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 
 class VariantShreddingSuite extends QueryTest with SharedSparkSession with 
ParquetTest {
+  // Parse a JSON string into a `VariantVal` holding both the value and metadata binaries.
+  def parseJson(s: String): VariantVal = {
+    val parsed = VariantBuilder.parseJson(s, false)
+    val valueBytes = parsed.getValue
+    val metadataBytes = parsed.getMetadata
+    new VariantVal(valueBytes, metadataBytes)
+  }
+
   // Make a variant value binary by parsing a JSON string.
+  // Only the value binary is returned; the metadata produced by the parse is discarded, and
+  // callers in this suite supply metadata separately (e.g. `metadata(Nil)`).
   def value(s: String): Array[Byte] = VariantBuilder.parseJson(s, 
false).getValue
 
@@ -53,9 +61,21 @@ class VariantShreddingSuite extends QueryTest with 
SharedSparkSession with Parqu
+  // Build a one-column write schema whose column `v` has the shredding schema that
+  // `SparkShreddingUtils.variantShreddingSchema` produces for `schema`.
   def writeSchema(schema: DataType): StructType =
     StructType(Array(StructField("v", 
SparkShreddingUtils.variantShreddingSchema(schema))))
 
+  // Run `fn` once for each value in `pushConfigs` of `SQLConf.PUSH_VARIANT_INTO_SCAN` (enabled and
+  // then disabled by default), so the body is exercised with and without variant push-down.
+  def withPushConfigs(pushConfigs: Seq[Boolean] = Seq(true, false))(fn: => Unit): Unit = {
+    for (push <- pushConfigs) {
+      withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> push.toString) {
+        fn
+      }
+    }
+  }
+
+  // Whether variant push-down into the scan is enabled in the current SQL conf.
+  def isPushEnabled: Boolean = SQLConf.get.getConf(SQLConf.PUSH_VARIANT_INTO_SCAN)
+
+  // Register test `name`; `block` runs once per push config, each time with a new temporary path.
   def testWithTempPath(name: String)(block: File => Unit): Unit = test(name) {
-    withTempPath { path =>
-      block(path)
+    withPushConfigs() {
+      // `withTempPath` is inside the loop, so every push config gets its own temporary path.
+      withTempPath { path =>
+        block(path)
+      }
     }
   }
 
@@ -63,6 +83,9 @@ class VariantShreddingSuite extends QueryTest with 
SharedSparkSession with Parqu
     spark.createDataFrame(spark.sparkContext.parallelize(rows.map(Row(_)), 
numSlices = 1), schema)
       .write.mode("overwrite").parquet(path.getAbsolutePath)
 
+  // Overload of `writeRows` that takes the schema as a DDL string.
+  def writeRows(path: File, schema: String, rows: Row*): Unit = {
+    val structType = StructType.fromDDL(schema)
+    writeRows(path, structType, rows: _*)
+  }
+
+  // Read the Parquet files under `path` with the user schema `v variant`, so column `v` is read
+  // back as a variant regardless of how it was shredded on disk.
   def read(path: File): DataFrame =
     spark.read.schema("v variant").parquet(path.getAbsolutePath)
 
@@ -150,10 +173,13 @@ class VariantShreddingSuite extends QueryTest with 
SharedSparkSession with Parqu
     // Top-level variant must not be missing.
     writeRows(path, writeSchema(IntegerType), Row(metadata(Nil), null, null))
     checkException(path, "v", "MALFORMED_VARIANT")
+
     // Array-element variant must not be missing.
     writeRows(path, writeSchema(ArrayType(IntegerType)),
       Row(metadata(Nil), null, Array(Row(null, null))))
     checkException(path, "v", "MALFORMED_VARIANT")
+    checkException(path, "variant_get(v, '$[0]')", "MALFORMED_VARIANT")
+
     // Shredded field must not be null.
     // Construct the schema manually, because 
SparkShreddingUtils.variantShreddingSchema will make
     // `a` non-nullable, which would prevent us from writing the file.
@@ -164,12 +190,163 @@ class VariantShreddingSuite extends QueryTest with 
SharedSparkSession with Parqu
         StructField("a", StructType(Seq(
           StructField("value", BinaryType),
           StructField("typed_value", BinaryType))))))))))))
-    writeRows(path, schema,
-      Row(metadata(Seq("a")), null, Row(null)))
+    writeRows(path, schema, Row(metadata(Seq("a")), null, Row(null)))
     checkException(path, "v", "MALFORMED_VARIANT")
+    checkException(path, "variant_get(v, '$.a')", "MALFORMED_VARIANT")
+
     // `value` must not contain any shredded field.
     writeRows(path, writeSchema(StructType.fromDDL("a int")),
       Row(metadata(Seq("a")), value("""{"a": 1}"""), Row(Row(null, null))))
     checkException(path, "v", "MALFORMED_VARIANT")
+    checkException(path, "cast(v as map<string, int>)", "MALFORMED_VARIANT")
+    if (isPushEnabled) {
+      checkExpr(path, "cast(v as struct<a int>)", Row(null))
+      checkExpr(path, "variant_get(v, '$.a', 'int')", null)
+    } else {
+      checkException(path, "cast(v as struct<a int>)", "MALFORMED_VARIANT")
+      checkException(path, "variant_get(v, '$.a', 'int')", "MALFORMED_VARIANT")
+    }
+
+    // Scalar reader reads from `typed_value` if both `value` and 
`typed_value` are not null.
+    // Cast from `value` succeeds, cast from `typed_value` fails.
+    writeRows(path, "v struct<metadata binary, value binary, typed_value 
string>",
+      Row(metadata(Nil), value("1"), "invalid"))
+    checkException(path, "cast(v as int)", "INVALID_VARIANT_CAST")
+    checkExpr(path, "try_cast(v as int)", null)
+
+    // Cast from `value` fails, cast from `typed_value` succeeds.
+    writeRows(path, "v struct<metadata binary, value binary, typed_value 
string>",
+      Row(metadata(Nil), value("\"invalid\""), "1"))
+    checkExpr(path, "cast(v as int)", 1)
+    checkExpr(path, "try_cast(v as int)", 1)
+  }
+
+  // Extraction, casts and full reconstruction over a shredded object with fields `a`, `b` and `c`,
+  // plus a top-level unshredded array stored entirely in `value`.
+  testWithTempPath("extract from shredded object") { path =>
+    val keys1 = Seq("a", "b", "c", "d")
+    val keys2 = Seq("a", "b", "c", "e", "f")
+    writeRows(path, "v struct<metadata binary, value binary, typed_value struct<" +
+      "a struct<value binary, typed_value int>, b struct<value binary>," +
+      "c struct<typed_value decimal(20, 10)>>>",
+      // {"a":1,"b":"2","c":3.3,"d":4.4}, d is not shredded and stays in the leftover value.
+      Row(metadata(keys1), shreddedValue("""{"d": 4.4}""", keys1),
+        Row(Row(null, 1), Row(value("\"2\"")), Row(Decimal("3.3")))),
+      // {"a":5.4,"b":-6,"e":{"f":[true]}}, e is in the leftover value and c is missing.
+      Row(metadata(keys2), shreddedValue("""{"e": {"f": [true]}}""", keys2),
+        Row(Row(value("5.4"), null), Row(value("-6")), Row(null))),
+      // [{"a":1}], the unshredded array at the top-level is put into `value` as a whole.
+      Row(metadata(Seq("a")), value("""[{"a": 1}]"""), null))
+
+    checkAnswer(read(path).selectExpr("variant_get(v, '$.a', 'int')",
+      "variant_get(v, '$.b', 'long')", "variant_get(v, '$.c', 'double')",
+      "variant_get(v, '$.d', 'decimal(9, 4)')"),
+      Seq(Row(1, 2L, 3.3, BigDecimal("4.4")), Row(5, -6L, null, null),
+        Row(null, null, null, null)))
+    checkExpr(path, "variant_get(v, '$.e.f[0]', 'boolean')", null, true, null)
+    checkExpr(path, "variant_get(v, '$[0].a', 'boolean')", null, null, true)
+    checkExpr(path, "try_cast(v as struct<a float, e variant>)",
+      Row(1.0F, null), Row(5.4F, parseJson("""{"f": [true]}""")), null)
+
+    // String "2" cannot be cast into boolean.
+    checkException(path, "variant_get(v, '$.b', 'boolean')", "INVALID_VARIANT_CAST")
+    // Decimal cannot be cast into date.
+    checkException(path, "variant_get(v, '$.c', 'date')", "INVALID_VARIANT_CAST")
+    // The value of `c` doesn't fit into `decimal(1, 1)`.
+    checkException(path, "variant_get(v, '$.c', 'decimal(1, 1)')", "INVALID_VARIANT_CAST")
+    checkExpr(path, "try_variant_get(v, '$.b', 'boolean')", null, true, null)
+    // Scalar cannot be cast into struct.
+    checkException(path, "variant_get(v, '$.a', 'struct<a int>')", "INVALID_VARIANT_CAST")
+    checkExpr(path, "try_variant_get(v, '$.a', 'struct<a int>')", null, null, null)
+
+    checkExpr(path, "try_cast(v as map<string, double>)",
+      Map("a" -> 1.0, "b" -> 2.0, "c" -> 3.3, "d" -> 4.4),
+      Map("a" -> 5.4, "b" -> -6.0, "e" -> null), null)
+    checkExpr(path, "try_cast(v as array<string>)", null, null, Seq("""{"a":1}"""))
+
+    val strings = Seq("""{"a":1,"b":"2","c":3.3,"d":4.4}""",
+      """{"a":5.4,"b":-6,"e":{"f":[true]}}""", """[{"a":1}]""")
+    checkExpr(path, "cast(v as string)", strings: _*)
+    checkExpr(path, "v",
+      VariantExpressionEvalUtils.castToVariant(
+        InternalRow(1, UTF8String.fromString("2"), Decimal("3.3000000000"), Decimal("4.4")),
+        StructType.fromDDL("a int, b string, c decimal(20, 10), d decimal(2, 1)")
+      ),
+      parseJson(strings(1)),
+      parseJson(strings(2))
+    )
+  }
+
+  // Extraction, casts and full reconstruction over a shredded array whose elements are shredded
+  // objects with a string field `a`.
+  testWithTempPath("extract from shredded array") { path =>
+    val keys = Seq("a", "b")
+    writeRows(path, "v struct<metadata binary, value binary, typed_value array<" +
+      "struct<value binary, typed_value struct<a struct<value binary, typed_value string>>>>>",
+      // [{"a":"2000-01-01"},{"a":"1000-01-01","b":[7]}], b is in the leftover value.
+      Row(metadata(keys), null, Array(
+        Row(null, Row(Row(null, "2000-01-01"))),
+        Row(shreddedValue("""{"b": [7]}""", keys), Row(Row(null, "1000-01-01"))))),
+      // [null,{"a":null},{"a":"null"},{}]
+      Row(metadata(keys), null, Array(
+        Row(value("null"), null),
+        Row(null, Row(Row(value("null"), null))),
+        Row(null, Row(Row(null, "null"))),
+        Row(null, Row(Row(null, null))))))
+
+    val date1 = Date.valueOf("2000-01-01")
+    val date2 = Date.valueOf("1000-01-01")
+    checkExpr(path, "variant_get(v, '$[0].a', 'date')", date1, null)
+    // try_variant_get succeeds.
+    checkExpr(path, "try_variant_get(v, '$[1].a', 'date')", date2, null)
+    // The first array returns null because of out-of-bound index.
+    // The second array returns "null".
+    checkExpr(path, "variant_get(v, '$[2].a', 'string')", null, "null")
+    // Return null because of invalid cast.
+    checkExpr(path, "try_variant_get(v, '$[1].a', 'int')", null, null)
+
+    checkExpr(path, "variant_get(v, '$[0].b[0]', 'int')", null, null)
+    checkExpr(path, "variant_get(v, '$[1].b[0]', 'int')", 7, null)
+    // Validate that timestamp-related casts use the session time zone correctly.
+    Seq("Etc/UTC", "America/Los_Angeles").foreach { tz =>
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        val expected = sql("select timestamp'1000-01-01', timestamp_ntz'1000-01-01'").head()
+        checkAnswer(read(path).selectExpr("variant_get(v, '$[1].a', 'timestamp')",
+          "variant_get(v, '$[1].a', 'timestamp_ntz')"), Seq(expected, Row(null, null)))
+      }
+    }
+    checkException(path, "variant_get(v, '$[0]', 'int')", "INVALID_VARIANT_CAST")
+    // An out-of-bound array access produces null. It never causes an invalid cast.
+    checkExpr(path, "variant_get(v, '$[4]', 'int')", null, null)
+
+    checkExpr(path, "cast(v as array<struct<a string, b array<int>>>)",
+      Seq(Row("2000-01-01", null), Row("1000-01-01", Seq(7))),
+      Seq(null, Row(null, null), Row("null", null), Row(null, null)))
+    checkExpr(path, "cast(v as array<map<string, string>>)",
+      Seq(Map("a" -> "2000-01-01"), Map("a" -> "1000-01-01", "b" -> "[7]")),
+      Seq(null, Map("a" -> null), Map("a" -> "null"), Map()))
+    checkExpr(path, "try_cast(v as array<map<string, date>>)",
+      Seq(Map("a" -> date1), Map("a" -> date2, "b" -> null)),
+      Seq(null, Map("a" -> null), Map("a" -> null), Map()))
+
+    val strings = Seq("""[{"a":"2000-01-01"},{"a":"1000-01-01","b":[7]}]""",
+      """[null,{"a":null},{"a":"null"},{}]""")
+    checkExpr(path, "cast(v as string)", strings: _*)
+    checkExpr(path, "v", strings.map(parseJson): _*)
+  }
+
+  // A shredded field whose `value` and `typed_value` are both null is missing from the object,
+  // while a `value` holding variant null is a present field whose value is null.
+  testWithTempPath("missing fields") { path =>
+    writeRows(path, "v struct<metadata binary, typed_value struct<" +
+      "a struct<value binary, typed_value int>, b struct<typed_value int>>>",
+      Row(metadata(Nil), Row(Row(null, null), Row(null))),
+      Row(metadata(Nil), Row(Row(value("null"), null), Row(null))),
+      Row(metadata(Nil), Row(Row(null, 1), Row(null))),
+      Row(metadata(Nil), Row(Row(null, null), Row(2))),
+      Row(metadata(Nil), Row(Row(value("null"), null), Row(2))),
+      Row(metadata(Nil), Row(Row(null, 3), Row(4))))
+
+    val strings = Seq("{}", """{"a":null}""", """{"a":1}""", """{"b":2}""", """{"a":null,"b":2}""",
+      """{"a":3,"b":4}""")
+    checkExpr(path, "cast(v as string)", strings: _*)
+    checkExpr(path, "v", strings.map(parseJson): _*)
+
+    checkExpr(path, "variant_get(v, '$.a', 'string')", null, null, "1", null, null, "3")
+    checkExpr(path, "variant_get(v, '$.a')", null, parseJson("null"), parseJson("1"), null,
+      parseJson("null"), parseJson("3"))
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to