This is an automated email from the ASF dual-hosted git repository.

pvary pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new 9ec1b933ed Spark: Backport support writing shredded variant in Iceberg-Spark (#16241)
9ec1b933ed is described below

commit 9ec1b933ed1d9253c6019773a624ba9b3d4d3c0a
Author: pvary <[email protected]>
AuthorDate: Thu May 7 16:48:58 2026 +0200

    Spark: Backport support writing shredded variant in Iceberg-Spark (#16241)
    
    backports #14297
---
 .../apache/iceberg/spark/SparkSQLProperties.java   |    8 +
 .../org/apache/iceberg/spark/SparkWriteConf.java   |   30 +
 .../apache/iceberg/spark/SparkWriteOptions.java    |    6 +
 .../iceberg/spark/source/SparkFormatModels.java    |    4 +-
 .../source/SparkVariantShreddingAnalyzer.java      |   69 ++
 .../apache/iceberg/spark/TestSparkWriteConf.java   |   85 ++
 .../spark/variant/TestVariantShredding.java        | 1101 ++++++++++++++++++++
 7 files changed, 1302 insertions(+), 1 deletion(-)

diff --git 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
index b5b8602145..336aadd73c 100644
--- 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
+++ 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
@@ -111,4 +111,12 @@ public class SparkSQLProperties {
 
   // Prefix for custom snapshot properties
   public static final String SNAPSHOT_PROPERTY_PREFIX = 
"spark.sql.iceberg.snapshot-property.";
+
+  // Controls whether to shred variant columns during write operations
+  public static final String SHRED_VARIANTS = 
"spark.sql.iceberg.shred-variants";
+
+  // Controls the buffer size for variant schema inference during writes
+  // This determines how many rows are buffered before inferring shredded 
schema
+  public static final String VARIANT_INFERENCE_BUFFER_SIZE =
+      "spark.sql.iceberg.variant-inference-buffer-size";
 }
diff --git 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
index aba7e4dda0..add12e6040 100644
--- 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
+++ 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
@@ -33,6 +33,8 @@ import static 
org.apache.iceberg.TableProperties.ORC_COMPRESSION;
 import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY;
 import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION;
 import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL;
+import static org.apache.iceberg.TableProperties.PARQUET_SHRED_VARIANTS;
+import static org.apache.iceberg.TableProperties.PARQUET_VARIANT_BUFFER_SIZE;
 import static 
org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE;
 
 import java.util.Locale;
@@ -529,6 +531,14 @@ public class SparkWriteConf {
         if (parquetCompressionLevel != null) {
           writeProperties.put(PARQUET_COMPRESSION_LEVEL, 
parquetCompressionLevel);
         }
+        boolean shouldShredVariants = shredVariants();
+        writeProperties.put(PARQUET_SHRED_VARIANTS, 
String.valueOf(shouldShredVariants));
+
+        // Add variant shredding configuration properties
+        if (shouldShredVariants) {
+          writeProperties.put(
+              PARQUET_VARIANT_BUFFER_SIZE, 
String.valueOf(variantInferenceBufferSize()));
+        }
         break;
 
       case AVRO:
@@ -749,4 +759,24 @@ public class SparkWriteConf {
         .defaultValue(DeleteGranularity.FILE)
         .parse();
   }
+
+  public boolean shredVariants() {
+    return confParser
+        .booleanConf()
+        .option(SparkWriteOptions.SHRED_VARIANTS)
+        .sessionConf(SparkSQLProperties.SHRED_VARIANTS)
+        .tableProperty(TableProperties.PARQUET_SHRED_VARIANTS)
+        .defaultValue(TableProperties.PARQUET_SHRED_VARIANTS_DEFAULT)
+        .parse();
+  }
+
+  public int variantInferenceBufferSize() {
+    return confParser
+        .intConf()
+        .option(SparkWriteOptions.VARIANT_INFERENCE_BUFFER_SIZE)
+        .sessionConf(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE)
+        .tableProperty(TableProperties.PARQUET_VARIANT_BUFFER_SIZE)
+        .defaultValue(TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT)
+        .parse();
+  }
 }
diff --git 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
index 1be02feaf0..6c76b5c873 100644
--- 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
+++ 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
@@ -86,4 +86,10 @@ public class SparkWriteOptions {
 
   // Overrides the delete granularity
   public static final String DELETE_GRANULARITY = "delete-granularity";
+
+  // Controls whether to shred variant columns during write operations
+  public static final String SHRED_VARIANTS = "shred-variants";
+
+  // Controls the buffer size for variant schema inference during writes
+  public static final String VARIANT_INFERENCE_BUFFER_SIZE = 
"variant-inference-buffer-size";
 }
diff --git 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java
 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java
index 23fbe54a4b..5b7862116a 100644
--- 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java
+++ 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java
@@ -51,7 +51,9 @@ public class SparkFormatModels {
             StructType.class,
             SparkParquetWriters::buildWriter,
             (icebergSchema, fileSchema, engineSchema, idToConstant) ->
-                SparkParquetReaders.buildReader(icebergSchema, fileSchema, 
idToConstant)));
+                SparkParquetReaders.buildReader(icebergSchema, fileSchema, 
idToConstant),
+            new SparkVariantShreddingAnalyzer(),
+            InternalRow::copy));
 
     FormatModelRegistry.register(
         ParquetFormatModel.create(
diff --git 
a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java
 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java
new file mode 100644
index 0000000000..2c08c662c9
--- /dev/null
+++ 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.List;
+import org.apache.iceberg.parquet.VariantShreddingAnalyzer;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.variants.VariantMetadata;
+import org.apache.iceberg.variants.VariantValue;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.types.VariantVal;
+
+/**
+ * Spark-specific implementation that extracts variant values from {@link 
InternalRow} instances.
+ */
+class SparkVariantShreddingAnalyzer extends 
VariantShreddingAnalyzer<InternalRow, StructType> {
+
+  SparkVariantShreddingAnalyzer() {}
+
+  @Override
+  protected int resolveColumnIndex(StructType sparkSchema, String columnName) {
+    try {
+      return sparkSchema.fieldIndex(columnName);
+    } catch (IllegalArgumentException e) {
+      return -1;
+    }
+  }
+
+  @Override
+  protected List<VariantValue> extractVariantValues(
+      List<InternalRow> bufferedRows, int variantFieldIndex) {
+    List<VariantValue> values = Lists.newArrayList();
+
+    for (InternalRow row : bufferedRows) {
+      if (!row.isNullAt(variantFieldIndex)) {
+        VariantVal variantVal = row.getVariant(variantFieldIndex);
+        if (variantVal != null) {
+          VariantValue variantValue =
+              VariantValue.from(
+                  VariantMetadata.from(
+                      
ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)),
+                  
ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN));
+          values.add(variantValue);
+        }
+      }
+    }
+
+    return values;
+  }
+}
diff --git 
a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java
 
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java
index c83b1b6e26..c5cfbe62b1 100644
--- 
a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java
+++ 
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java
@@ -34,6 +34,7 @@ import static 
org.apache.iceberg.TableProperties.ORC_COMPRESSION;
 import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY;
 import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION;
 import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL;
+import static org.apache.iceberg.TableProperties.PARQUET_SHRED_VARIANTS;
 import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE;
 import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE;
 import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH;
@@ -61,6 +62,7 @@ import org.apache.iceberg.deletes.DeleteGranularity;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.spark.sql.internal.SQLConf;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.TestTemplate;
@@ -340,6 +342,8 @@ public class TestSparkWriteConf extends TestBaseWithCatalog 
{
                     TableProperties.DELETE_PARQUET_COMPRESSION,
                     "snappy"),
                 ImmutableMap.of(
+                    PARQUET_SHRED_VARIANTS,
+                    "false",
                     DELETE_PARQUET_COMPRESSION,
                     "zstd",
                     PARQUET_COMPRESSION,
@@ -461,6 +465,8 @@ public class TestSparkWriteConf extends TestBaseWithCatalog 
{
                     PARQUET_COMPRESSION_LEVEL,
                     "5"),
                 ImmutableMap.of(
+                    PARQUET_SHRED_VARIANTS,
+                    "false",
                     DELETE_PARQUET_COMPRESSION,
                     "zstd",
                     PARQUET_COMPRESSION,
@@ -532,6 +538,8 @@ public class TestSparkWriteConf extends TestBaseWithCatalog 
{
                     DELETE_PARQUET_COMPRESSION_LEVEL,
                     "6"),
                 ImmutableMap.of(
+                    PARQUET_SHRED_VARIANTS,
+                    "false",
                     DELETE_PARQUET_COMPRESSION,
                     "zstd",
                     PARQUET_COMPRESSION,
@@ -686,4 +694,81 @@ public class TestSparkWriteConf extends 
TestBaseWithCatalog {
     
assertThat(writeConf.copyOnWriteDistributionMode(MERGE)).isEqualTo(expectedMode);
     
assertThat(writeConf.positionDeltaDistributionMode(MERGE)).isEqualTo(expectedMode);
   }
+
+  @TestTemplate
+  public void testShredVariantsDefault() {
+    Table table = validationCatalog.loadTable(tableIdent);
+    SparkWriteConf writeConf = new SparkWriteConf(spark, table, 
ImmutableMap.of());
+    assertThat(writeConf.shredVariants()).isFalse();
+  }
+
+  @TestTemplate
+  public void testVariantInferenceBufferSizeDefault() {
+    Table table = validationCatalog.loadTable(tableIdent);
+    SparkWriteConf writeConf = new SparkWriteConf(spark, table, 
ImmutableMap.of());
+    assertThat(writeConf.variantInferenceBufferSize())
+        .isEqualTo(TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT);
+  }
+
+  @TestTemplate
+  public void testVariantInferenceBufferSizeTableProperty() {
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    table.updateProperties().set(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, 
"500").commit();
+
+    SparkWriteConf writeConf = new SparkWriteConf(spark, table, 
ImmutableMap.of());
+    assertThat(writeConf.variantInferenceBufferSize()).isEqualTo(500);
+  }
+
+  @TestTemplate
+  public void testShredVariantsSessionOverridesTableProperty() {
+    Table table = validationCatalog.loadTable(tableIdent);
+    table.updateProperties().set(TableProperties.PARQUET_SHRED_VARIANTS, 
"false").commit();
+
+    withSQLConf(
+        ImmutableMap.of(SparkSQLProperties.SHRED_VARIANTS, "true"),
+        () -> {
+          SparkWriteConf writeConf = new SparkWriteConf(spark, table, 
ImmutableMap.of());
+          assertThat(writeConf.shredVariants()).isTrue();
+        });
+  }
+
+  @TestTemplate
+  public void testShredVariantsWriteOptionOverridesSessionConf() {
+    withSQLConf(
+        ImmutableMap.of(SparkSQLProperties.SHRED_VARIANTS, "false"),
+        () -> {
+          Table table = validationCatalog.loadTable(tableIdent);
+          SparkWriteConf writeConf =
+              new SparkWriteConf(
+                  spark,
+                  table,
+                  new CaseInsensitiveStringMap(
+                      ImmutableMap.of(SparkWriteOptions.SHRED_VARIANTS, 
"true")));
+          assertThat(writeConf.shredVariants()).isTrue();
+        });
+  }
+
+  @TestTemplate
+  public void testVariantInferenceBufferSizeSessionConf() {
+    withSQLConf(
+        ImmutableMap.of(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, 
"250"),
+        () -> {
+          Table table = validationCatalog.loadTable(tableIdent);
+          SparkWriteConf writeConf = new SparkWriteConf(spark, table, 
ImmutableMap.of());
+          assertThat(writeConf.variantInferenceBufferSize()).isEqualTo(250);
+        });
+  }
+
+  @TestTemplate
+  public void testWritePropertiesIncludeVariantShredding() {
+    Table table = validationCatalog.loadTable(tableIdent);
+    table.updateProperties().set(TableProperties.PARQUET_SHRED_VARIANTS, 
"true").commit();
+    table.updateProperties().set(TableProperties.PARQUET_VARIANT_BUFFER_SIZE, 
"200").commit();
+
+    SparkWriteConf writeConf = new SparkWriteConf(spark, table, 
ImmutableMap.of());
+    Map<String, String> writeProperties = writeConf.writeProperties();
+    assertThat(writeProperties).containsEntry(PARQUET_SHRED_VARIANTS, "true");
+    
assertThat(writeProperties).containsEntry(TableProperties.PARQUET_VARIANT_BUFFER_SIZE,
 "200");
+  }
 }
diff --git 
a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
 
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
new file mode 100644
index 0000000000..8cdcf22e58
--- /dev/null
+++ 
b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
@@ -0,0 +1,1101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.variant;
+
+import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS;
+import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES;
+import static org.apache.parquet.schema.Types.optional;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.net.InetAddress;
+import java.util.List;
+import java.util.Map;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.Parameters;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.SparkSQLProperties;
+import org.apache.iceberg.types.Types;
+import org.apache.iceberg.variants.Variant;
+import org.apache.parquet.hadoop.ParquetFileReader;
+import org.apache.parquet.hadoop.util.HadoopInputFile;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
+import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.internal.SQLConf;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+
+public class TestVariantShredding extends CatalogTestBase {
+
+  private static final Schema SCHEMA =
+      new Schema(
+          Types.NestedField.required(1, "id", Types.IntegerType.get()),
+          Types.NestedField.optional(2, "address", Types.VariantType.get()));
+
+  private static final Schema SCHEMA2 =
+      new Schema(
+          Types.NestedField.required(1, "id", Types.IntegerType.get()),
+          Types.NestedField.optional(2, "address", Types.VariantType.get()),
+          Types.NestedField.optional(3, "metadata", Types.VariantType.get()));
+
+  @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  protected static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.HADOOP.catalogName(),
+        SparkCatalogConfig.HADOOP.implementation(),
+        SparkCatalogConfig.HADOOP.properties()
+      },
+    };
+  }
+
+  @BeforeAll
+  public static void startMetastoreAndSpark() {
+    // First call parent to initialize metastore and spark with local[2]
+    CatalogTestBase.startMetastoreAndSpark();
+
+    // Now stop and recreate spark with local[1] to write all rows to a single 
file
+    if (spark != null) {
+      spark.stop();
+    }
+
+    spark =
+        SparkSession.builder()
+            .master("local[1]") // Use one thread to write the rows to a 
single parquet file
+            .config("spark.driver.host", 
InetAddress.getLoopbackAddress().getHostAddress())
+            .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic")
+            .config("spark.hadoop." + METASTOREURIS.varname, 
hiveConf.get(METASTOREURIS.varname))
+            
.config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true")
+            .config(DISABLE_UI)
+            .enableHiveSupport()
+            .getOrCreate();
+
+    sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
+  }
+
+  @BeforeEach
+  public void before() {
+    super.before();
+    validationCatalog.createTable(
+        tableIdent, SCHEMA, null, Map.of(TableProperties.FORMAT_VERSION, "3"));
+  }
+
+  @AfterEach
+  public void after() {
+    spark.conf().unset(SparkSQLProperties.SHRED_VARIANTS);
+    spark.conf().unset(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE);
+    validationCatalog.dropTable(tableIdent, true);
+  }
+
+  @TestTemplate
+  public void testVariantShreddingDisabled() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false");
+
+    String values = "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}')), 
(2, null)";
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType address = variant("address", 2, Type.Repetition.OPTIONAL);
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testExcludingNullValue() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        """
+            (1, parse_json('{"name": "Alice", "age": 30, "dummy": null}')),\
+             (2, parse_json('{"name": "Bob", "age": 25}')),\
+             (3, parse_json('{"name": "Charlie", "age": 35}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType name =
+        field(
+            "name",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, 
LogicalTypeAnnotation.stringType()));
+    GroupType age =
+        field(
+            "age",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, 
LogicalTypeAnnotation.intType(8, true)));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(age, name));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testInconsistentType() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        """
+            (1, parse_json('{"age": "25"}')),\
+             (2, parse_json('{"age": 30}')),\
+             (3, parse_json('{"age": "35"}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType age =
+        field(
+            "age",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, 
LogicalTypeAnnotation.stringType()));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(age));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+
+    List<Object[]> rows =
+        sql("SELECT variant_get(address, '$.age', 'int') FROM %s WHERE id = 
2", tableName);
+    assertThat(rows).hasSize(1);
+    assertThat(rows.get(0)[0]).isEqualTo(30);
+  }
+
+  @TestTemplate
+  public void testPrimitiveType() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values = "(1, parse_json('123')), (2, parse_json('456')), (3, 
parse_json('789'))";
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType address =
+        variant(
+            "address",
+            2,
+            Type.Repetition.REQUIRED,
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, 
LogicalTypeAnnotation.intType(16, true)));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testPrimitiveDecimalType() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        "(1, parse_json('123.56')), (2, parse_json('\"abc\"')), (3, 
parse_json('12.56'))";
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType address =
+        variant(
+            "address",
+            2,
+            Type.Repetition.REQUIRED,
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, 
LogicalTypeAnnotation.decimalType(2, 5)));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testBooleanType() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        """
+            (1, parse_json('{"active": true}')),\
+             (2, parse_json('{"active": false}')),\
+             (3, parse_json('{"active": true}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType active = field("active", 
shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(active));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testDecimalTypeWithInconsistentScales() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        """
+            (1, parse_json('{"price": 123.456789}')),\
+             (2, parse_json('{"price": 678.90}')),\
+             (3, parse_json('{"price": 999.99}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType price =
+        field(
+            "price",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, 
LogicalTypeAnnotation.decimalType(6, 9)));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(price));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testDecimalTypeWithConsistentScales() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        """
+            (1, parse_json('{"price": 123.45}')),\
+             (2, parse_json('{"price": 678.90}')),\
+             (3, parse_json('{"price": 999.99}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType price =
+        field(
+            "price",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, 
LogicalTypeAnnotation.decimalType(2, 5)));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(price));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testArrayType() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        """
+            (1, parse_json('["java", "scala", "python"]')),\
+             (2, parse_json('["rust", "go"]')),\
+             (3, parse_json('["javascript"]'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType arr =
+        list(
+            element(
+                shreddedPrimitive(
+                    PrimitiveType.PrimitiveTypeName.BINARY, 
LogicalTypeAnnotation.stringType())));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, arr);
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testNestedArrayType() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        """
+            (1, parse_json('{"tags": ["java", "scala", "python"]}')),\
+             (2, parse_json('{"tags": ["rust", "go"]}')),\
+             (3, parse_json('{"tags": ["javascript"]}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType tags =
+        field(
+            "tags",
+            list(
+                element(
+                    shreddedPrimitive(
+                        PrimitiveType.PrimitiveTypeName.BINARY,
+                        LogicalTypeAnnotation.stringType()))));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(tags));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testNestedObjectType() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        """
+            (1, parse_json('{"location": {"city": "Seattle", "zip": 98101}, 
"tags": ["java", "scala", "python"]}')),\
+             (2, parse_json('{"location": {"city": "Portland", "zip": 
97201}}')),\
+             (3, parse_json('{"location": {"city": "NYC", "zip": 10001}}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType city =
+        field(
+            "city",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, 
LogicalTypeAnnotation.stringType()));
+    GroupType zip =
+        field(
+            "zip",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, 
LogicalTypeAnnotation.intType(32, true)));
+    GroupType location = field("location", objectFields(city, zip));
+    GroupType tags =
+        field(
+            "tags",
+            list(
+                element(
+                    shreddedPrimitive(
+                        PrimitiveType.PrimitiveTypeName.BINARY,
+                        LogicalTypeAnnotation.stringType()))));
+
+    GroupType address =
+        variant("address", 2, Type.Repetition.REQUIRED, objectFields(location, 
tags));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  /**
+   * Writes seven rows with the inference buffer size set to five, then checks the shredded
+   * Parquet schema of the written files and that all seven rows are counted on read.
+   */
+  @TestTemplate
+  public void testLazyInitializationWithBufferedRows() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+    spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "5");
+
+    // More rows (7) than the configured buffer size (5)
+    String values =
+        """
+            (1, parse_json('{"name": "Alice", "age": 30}')),\
+             (2, parse_json('{"name": "Bob", "age": 25}')),\
+             (3, parse_json('{"name": "Charlie", "age": 35}')),\
+             (4, parse_json('{"name": "David", "age": 28}')),\
+             (5, parse_json('{"name": "Eve", "age": 32}')),\
+             (6, parse_json('{"name": "Frank", "age": 40}')),\
+             (7, parse_json('{"name": "Grace", "age": 27}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType name =
+        field(
+            "name",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()));
+    GroupType age =
+        field(
+            "age",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true)));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+
+    long rowCount = spark.read().format("iceberg").load(tableName).count();
+    assertThat(rowCount).isEqualTo(7);
+  }
+
+  @TestTemplate
+  public void testMultipleRowGroups() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+    spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3");
+
+    int numRows = 1000;
+    StringBuilder valuesBuilder = new StringBuilder();
+    for (int i = 1; i <= numRows; i++) {
+      if (i > 1) {
+        valuesBuilder.append(", ");
+      }
+      valuesBuilder.append(
+          String.format("(%d, parse_json('{\"name\": \"User%d\", \"age\": 
%d}'))", i, i, 20 + i));
+    }
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')",
+        tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 1024);
+    sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString());
+
+    GroupType name =
+        field(
+            "name",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, 
LogicalTypeAnnotation.stringType()));
+    GroupType age =
+        field(
+            "age",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, 
LogicalTypeAnnotation.intType(8, true)));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(age, name));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+
+    long rowCount = spark.read().format("iceberg").load(tableName).count();
+    assertThat(rowCount).isEqualTo(numRows);
+  }
+
+  /**
+   * Sets {@code parquet.columnindex.truncate.length} to 10 on the table, writes ten rows whose
+   * {@code description} is a 20-character string, and checks the shredded Parquet schema and the
+   * row count read back.
+   */
+  @TestTemplate
+  public void testColumnIndexTruncateLength() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+    spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3");
+
+    int customTruncateLength = 10;
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')",
+        tableName, "parquet.columnindex.truncate.length", customTruncateLength);
+
+    // The same value is used for every row, so build it once; it is longer than the truncate length
+    String longValue = "A".repeat(20);
+    StringBuilder valuesBuilder = new StringBuilder();
+    for (int i = 1; i <= 10; i++) {
+      if (i > 1) {
+        valuesBuilder.append(", ");
+      }
+      valuesBuilder.append(
+          String.format(
+              "(%d, parse_json('{\"description\": \"%s\", \"id\": %d}'))", i, longValue, i));
+    }
+    sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString());
+
+    GroupType description =
+        field(
+            "description",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()));
+    GroupType id =
+        field(
+            "id",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true)));
+    GroupType address =
+        variant("address", 2, Type.Repetition.REQUIRED, objectFields(description, id));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+
+    long rowCount = spark.read().format("iceberg").load(tableName).count();
+    assertThat(rowCount).isEqualTo(10);
+  }
+
+  /**
+   * Writes integer values of increasing magnitude for the same field and expects the field to be
+   * shredded as a signed 64-bit integer.
+   */
+  @TestTemplate
+  public void testIntegerFamilyPromotion() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    // Mix of INT8, INT16, INT32, INT64 - should promote to INT64
+    String values =
+        """
+            (1, parse_json('{"value": 10}')),\
+             (2, parse_json('{"value": 1000}')),\
+             (3, parse_json('{"value": 100000}')),\
+             (4, parse_json('{"value": 10000000000}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType value =
+        field(
+            "value",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT64, LogicalTypeAnnotation.intType(64, true)));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(value));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  /**
+   * Writes decimal values of growing precision and scale for the same field and expects a single
+   * 16-byte fixed-length decimal column with scale 6 and precision 21.
+   */
+  @TestTemplate
+  public void testDecimalFamilyPromotion() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    // Test that they get promoted to the most capable decimal type observed
+    String values =
+        """
+            (1, parse_json('{"value": 1.5}')),\
+             (2, parse_json('{"value": 123.456789}')),\
+             (3, parse_json('{"value": 123456789123456.789}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    // decimalType takes (scale, precision): up to 6 fraction digits + 15 integer digits = 21
+    GroupType value =
+        field(
+            "value",
+            optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY)
+                .length(16)
+                .as(LogicalTypeAnnotation.decimalType(6, 21))
+                .named("typed_value"));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(value));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  /**
+   * Writes three object variants with shredding enabled, checks the shredded Parquet schema, and
+   * reads {@code name} and {@code age} back through {@code variant_get}.
+   */
+  @TestTemplate
+  public void testDataRoundTripWithShredding() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        """
+            (1, parse_json('{"name": "Alice", "age": 30}')),\
+             (2, parse_json('{"name": "Bob", "age": 25}')),\
+             (3, parse_json('{"name": "Charlie", "age": 35}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType name =
+        field(
+            "name",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()));
+    GroupType age =
+        field(
+            "age",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true)));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+
+    // Verify that we can read the data back correctly
+    List<Object[]> rows =
+        sql(
+            "SELECT id, variant_get(address, '$.name', 'string'),"
+                + " variant_get(address, '$.age', 'int')"
+                + " FROM %s ORDER BY id",
+            tableName);
+    assertThat(rows).hasSize(3);
+    assertThat(rows.get(0)[0]).isEqualTo(1);
+    assertThat(rows.get(0)[1]).isEqualTo("Alice");
+    assertThat(rows.get(0)[2]).isEqualTo(30);
+    assertThat(rows.get(1)[0]).isEqualTo(2);
+    assertThat(rows.get(1)[1]).isEqualTo("Bob");
+    assertThat(rows.get(1)[2]).isEqualTo(25);
+    assertThat(rows.get(2)[0]).isEqualTo(3);
+    assertThat(rows.get(2)[1]).isEqualTo("Charlie");
+    assertThat(rows.get(2)[2]).isEqualTo(35);
+  }
+
+  /**
+   * Recreates the table with two variant columns ({@code address} and {@code metadata}) and
+   * expects each column to be shredded with its own object field.
+   */
+  @TestTemplate
+  public void testMultipleVariantsWithShredding() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    // Recreate table with SCHEMA2 (address + metadata variant columns)
+    validationCatalog.dropTable(tableIdent, true);
+    validationCatalog.createTable(
+        tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3"));
+
+    String values =
+        """
+            (1, parse_json('{"city": "NYC"}'), parse_json('{"source": "web"}')),\
+             (2, parse_json('{"city": "LA"}'), parse_json('{"source": "app"}')),\
+             (3, parse_json('{"city": "SF"}'), parse_json('{"source": "api"}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType city =
+        field(
+            "city",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(city));
+
+    GroupType source =
+        field(
+            "source",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()));
+    GroupType metadata = variant("metadata", 3, Type.Repetition.REQUIRED, objectFields(source));
+    MessageType expectedSchema = parquetSchema(address, metadata);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  /**
+   * Writes rows whose variant is the JSON literal {@code null} and expects the column to keep the
+   * plain metadata/value layout with no {@code typed_value}.
+   */
+  @TestTemplate
+  public void testVariantWithNullValues() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    // Each row carries parse_json('null'), not a SQL NULL
+    String jsonNullRows =
+        """
+            (1, parse_json('null')),\
+             (2, parse_json('null')),\
+             (3, parse_json('null'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, jsonNullRows);
+
+    MessageType unshreddedSchema = parquetSchema(variant("address", 2, Type.Repetition.REQUIRED));
+    verifyParquetSchema(validationCatalog.loadTable(tableIdent), unshreddedSchema);
+  }
+
+  /**
+   * Writes arrays that contain only JSON nulls and expects the variant column to be written
+   * without a {@code typed_value}.
+   */
+  @TestTemplate
+  public void testArrayOfNullElementsWithShredding() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    sql(
+        "INSERT INTO %s VALUES (1, parse_json('[null, null, null]')), "
+            + "(2, parse_json('[null]'))",
+        tableName);
+
+    // Array elements are all null, element type is null, falls back to unshredded
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED);
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  /**
+   * Writes two object variants and one SQL NULL, then expects an OPTIONAL shredded variant column
+   * and all three rows to be counted on read.
+   */
+  @TestTemplate
+  public void testMixedNullAndNonNullVariantValues() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        """
+            (1, parse_json('{"name": "Alice", "age": 30}')),\
+             (2, null),\
+             (3, parse_json('{"name": "Charlie", "age": 35}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType name =
+        field(
+            "name",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()));
+    GroupType age =
+        field(
+            "age",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true)));
+    // The variant group itself is expected to be OPTIONAL in this file
+    GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(age, name));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+
+    long rowCount = spark.read().format("iceberg").load(tableName).count();
+    assertThat(rowCount).isEqualTo(3);
+  }
+
+  @TestTemplate
+  public void testWriteOptionOverridesSessionConfig() throws IOException, 
NoSuchTableException {
+    // Disable shredding at session level
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false");
+
+    // Enable shredding via per-write option
+    String query =
+        "SELECT 1 as id, parse_json('{\"name\": \"Alice\", \"age\": 30}') as 
address"
+            + " UNION ALL SELECT 2, parse_json('{\"name\": \"Bob\", \"age\": 
25}')"
+            + " UNION ALL SELECT 3, parse_json('{\"name\": \"Charlie\", 
\"age\": 35}')";
+    spark.sql(query).writeTo(tableName).option("shred-variants", 
"true").append();
+
+    GroupType name =
+        field(
+            "name",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, 
LogicalTypeAnnotation.stringType()));
+    GroupType age =
+        field(
+            "age",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, 
LogicalTypeAnnotation.intType(8, true)));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(age, name));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  /**
+   * Writes eleven rows where only the first carries {@code rare_field} and expects only {@code
+   * name} to be shredded.
+   */
+  @TestTemplate
+  public void testInfrequentFieldPruning() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+    spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "11");
+
+    StringBuilder valuesBuilder = new StringBuilder();
+    for (int i = 1; i <= 11; i++) {
+      if (i > 1) {
+        valuesBuilder.append(", ");
+      }
+      if (i == 1) {
+        // Only the first row has rare_field
+        valuesBuilder.append(
+            String.format(
+                "(%d, parse_json('{\"name\": \"User%d\", \"rare_field\": \"rare\"}'))", i, i));
+      } else {
+        valuesBuilder.append(String.format("(%d, parse_json('{\"name\": \"User%d\"}'))", i, i));
+      }
+    }
+    sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString());
+
+    // rare_field appears in 1/11 rows, should be pruned
+    // name appears in 11/11 rows and should be kept
+    GroupType name =
+        field(
+            "name",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(name));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testMixedTypeTieBreaking() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+    spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "10");
+
+    StringBuilder valuesBuilder = new StringBuilder();
+    for (int i = 1; i <= 10; i++) {
+      if (i > 1) {
+        valuesBuilder.append(", ");
+      }
+      if (i <= 5) {
+        valuesBuilder.append(String.format("(%d, parse_json('{\"val\": 
%d}'))", i, i));
+      } else {
+        valuesBuilder.append(String.format("(%d, parse_json('{\"val\": 
\"text%d\"}'))", i, i));
+      }
+    }
+    sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString());
+
+    // 5 ints + 5 strings is a tie so STRING wins (higher TIE_BREAK_PRIORITY)
+    GroupType val =
+        field(
+            "val",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, 
LogicalTypeAnnotation.stringType()));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, 
objectFields(val));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+
+    // Verify data round-trips correctly
+    List<Object[]> rows =
+        sql("SELECT id, variant_get(address, '$.val', 'string') FROM %s ORDER 
BY id", tableName);
+    assertThat(rows).hasSize(10);
+    assertThat(rows.get(0)[1]).isEqualTo("1");
+    assertThat(rows.get(5)[1]).isEqualTo("text6");
+  }
+
+  /**
+   * Writes seven rows with a buffer size of three where {@code score} first appears in row four;
+   * expects only {@code name} to be shredded while {@code score} is still readable.
+   */
+  @TestTemplate
+  public void testFieldOnlyAfterBuffer() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+    spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3");
+
+    String values =
+        """
+            (1, parse_json('{"name": "Alice"}')),\
+             (2, parse_json('{"name": "Bob"}')),\
+             (3, parse_json('{"name": "Charlie"}')),\
+             (4, parse_json('{"name": "David", "score": 95}')),\
+             (5, parse_json('{"name": "Eve", "score": 88}')),\
+             (6, parse_json('{"name": "Frank", "score": 72}')),\
+             (7, parse_json('{"name": "Grace", "score": 91}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    // Schema is determined from buffer (rows 1-3) which only has "name".
+    // "score" is not shredded
+    GroupType name =
+        field(
+            "name",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(name));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+
+    // Verify all data round-trips despite "score" not being shredded
+    List<Object[]> rows =
+        sql(
+            "SELECT id, variant_get(address, '$.name', 'string'),"
+                + " variant_get(address, '$.score', 'int')"
+                + " FROM %s ORDER BY id",
+            tableName);
+    assertThat(rows).hasSize(7);
+    assertThat(rows.get(0)[1]).isEqualTo("Alice");
+    assertThat(rows.get(0)[2]).isNull();
+    assertThat(rows.get(3)[1]).isEqualTo("David");
+    assertThat(rows.get(3)[2]).isEqualTo(95);
+    assertThat(rows.get(6)[1]).isEqualTo("Grace");
+    assertThat(rows.get(6)[2]).isEqualTo(91);
+  }
+
+  /**
+   * Runs two inserts where {@code score} is an integer in the first and a string in the second,
+   * checks the schema after the first insert, and reads every score back as a string.
+   */
+  @TestTemplate
+  public void testCrossFileDifferentShreddedType() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+    spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3");
+
+    // File 1: "score" is always integer → shredded as INT8
+    String batch1 =
+        """
+            (1, parse_json('{"score": 95}')),\
+             (2, parse_json('{"score": 88}')),\
+             (3, parse_json('{"score": 72}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, batch1);
+
+    // Verify file 1 schema: score shredded as INT8
+    // (checked before the second insert, while the scan only plans the first batch)
+    Table table = validationCatalog.loadTable(tableIdent);
+    GroupType scoreInt =
+        field(
+            "score",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true)));
+    MessageType expectedSchema1 =
+        parquetSchema(variant("address", 2, Type.Repetition.REQUIRED, objectFields(scoreInt)));
+    verifyParquetSchema(table, expectedSchema1);
+
+    // File 2: "score" is always string → shredded as STRING
+    String batch2 =
+        """
+            (4, parse_json('{"score": "high"}')),\
+             (5, parse_json('{"score": "medium"}')),\
+             (6, parse_json('{"score": "low"}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, batch2);
+
+    // Query across both files, reader must handle different shredded types
+    List<Object[]> rows =
+        sql("SELECT id, variant_get(address, '$.score', 'string') FROM %s ORDER BY id", tableName);
+    assertThat(rows).hasSize(6);
+    assertThat(rows.get(0)[1]).isEqualTo("95");
+    assertThat(rows.get(1)[1]).isEqualTo("88");
+    assertThat(rows.get(3)[1]).isEqualTo("high");
+    assertThat(rows.get(5)[1]).isEqualTo("low");
+  }
+
+  /**
+   * Writes three rows whose variant is SQL NULL and expects an OPTIONAL variant column without a
+   * {@code typed_value}, with every value read back as null.
+   */
+  @TestTemplate
+  public void testAllNullVariantColumn() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    sql("INSERT INTO %s VALUES (1, null), (2, null), (3, null)", tableName);
+
+    // All variant values are SQL NULL, so no shredding should occur
+    Table table = validationCatalog.loadTable(tableIdent);
+    MessageType expectedSchema = parquetSchema(variant("address", 2, Type.Repetition.OPTIONAL));
+    verifyParquetSchema(table, expectedSchema);
+
+    List<Object[]> rows = sql("SELECT id, address FROM %s ORDER BY id", tableName);
+    assertThat(rows).hasSize(3);
+    assertThat(rows.get(0)[1]).isNull();
+    assertThat(rows.get(1)[1]).isNull();
+    assertThat(rows.get(2)[1]).isNull();
+  }
+
+  /**
+   * Writes three rows with the inference buffer size set to one and expects both {@code name} and
+   * {@code age} to be shredded and {@code name} to be readable.
+   */
+  @TestTemplate
+  public void testBufferSizeOne() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+    spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "1");
+
+    sql(
+        """
+            INSERT INTO %s VALUES
+            (1, parse_json('{"name": "Alice", "age": 30}')),
+            (2, parse_json('{"name": "Bob", "age": 25}')),
+            (3, parse_json('{"name": "Charlie", "age": 35}'))
+            """,
+        tableName);
+
+    // Schema inferred from first row only, should still shred name and age
+    GroupType age =
+        field(
+            "age",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true)));
+    GroupType name =
+        field(
+            "name",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+
+    List<Object[]> rows =
+        sql("SELECT id, variant_get(address, '$.name', 'string') FROM %s ORDER BY id", tableName);
+    assertThat(rows).hasSize(3);
+    assertThat(rows.get(0)[1]).isEqualTo("Alice");
+    assertThat(rows.get(2)[1]).isEqualTo("Charlie");
+  }
+
+  /**
+   * Writes six decimal values with a buffer size of three, where later rows differ in precision
+   * or scale from the buffered ones, and checks that the values read back as {@code
+   * decimal(10,4)} match the inputs.
+   */
+  @TestTemplate
+  public void testDecimalFallbackAfterBuffer() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+    spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3");
+
+    // Buffer: scale=2, 3 integer digits -> DECIMAL(5,2)
+    // Row 4: precision overflow -> fallback to value field
+    // Row 5: scale overflow -> fallback to value field
+    // Row 6: fits typed column, scale widened from 1 to 2 via setScale
+    String values =
+        """
+            (1, parse_json('{"val": 123.45}')),\
+             (2, parse_json('{"val": 678.90}')),\
+             (3, parse_json('{"val": 999.99}')),\
+             (4, parse_json('{"val": 123456.78}')),\
+             (5, parse_json('{"val": 1.2345}')),\
+             (6, parse_json('{"val": 12.3}'))\
+            """;
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    List<Object[]> rows =
+        sql(
+            "SELECT id, variant_get(address, '$.val', 'decimal(10,4)') FROM %s ORDER BY id",
+            tableName);
+    assertThat(rows).hasSize(6);
+    // BigDecimal.equals compares scale too, so the expected values carry 4 fraction digits
+    assertThat(rows.get(0)[1]).isEqualTo(new BigDecimal("123.4500"));
+    assertThat(rows.get(3)[1]).isEqualTo(new BigDecimal("123456.7800"));
+    assertThat(rows.get(4)[1]).isEqualTo(new BigDecimal("1.2345"));
+    assertThat(rows.get(5)[1]).isEqualTo(new BigDecimal("12.3000"));
+  }
+
+  /**
+   * Asserts that every data file planned by a scan of {@code table} has exactly the expected
+   * Parquet schema. Fails if the scan plans no files.
+   *
+   * @param table table whose planned data files are opened
+   * @param expectedSchema schema each file's Parquet footer must equal
+   * @throws IOException if planning cannot be closed or a file cannot be opened
+   */
+  private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException {
+    // One Hadoop configuration is enough for all files; no need to build one per file
+    Configuration conf = new Configuration();
+
+    try (CloseableIterable<FileScanTask> tasks = table.newScan().planFiles()) {
+      assertThat(tasks).isNotEmpty();
+
+      for (FileScanTask task : tasks) {
+        String path = task.file().location();
+
+        HadoopInputFile inputFile = HadoopInputFile.fromPath(new Path(path), conf);
+
+        try (ParquetFileReader reader = ParquetFileReader.open(inputFile)) {
+          MessageType actualSchema = reader.getFileMetaData().getSchema();
+          assertThat(actualSchema).isEqualTo(expectedSchema);
+        }
+      }
+    }
+  }
+
+  /**
+   * Builds the expected message type named {@code table}: a required INT32 {@code id} column
+   * with field id 1, followed by the given variant columns in order.
+   */
+  private static MessageType parquetSchema(Type... variantTypes) {
+    org.apache.parquet.schema.Types.MessageTypeBuilder message =
+        org.apache.parquet.schema.Types.buildMessage();
+
+    // named(...) on the nested primitive builder registers the column on the message builder
+    message.required(PrimitiveType.PrimitiveTypeName.INT32).id(1).named("id");
+    message.addFields(variantTypes);
+
+    return message.named("table");
+  }
+
+  /**
+   * Builds an unshredded variant group: required binary {@code metadata} and required binary
+   * {@code value}, annotated as a variant.
+   *
+   * @param name column name of the group
+   * @param fieldId field id assigned to the group
+   * @param repetition repetition of the group itself
+   */
+  private static GroupType variant(String name, int fieldId, Type.Repetition repetition) {
+    return org.apache.parquet.schema.Types.buildGroup(repetition)
+        .id(fieldId)
+        .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION))
+        .required(PrimitiveType.PrimitiveTypeName.BINARY)
+        .named("metadata")
+        .required(PrimitiveType.PrimitiveTypeName.BINARY)
+        .named("value")
+        .named(name);
+  }
+
+  /**
+   * Builds a shredded variant group: required binary {@code metadata}, optional binary {@code
+   * value}, and the given {@code typed_value}, annotated as a variant.
+   *
+   * @param name column name of the group
+   * @param fieldId field id assigned to the group
+   * @param repetition repetition of the group itself
+   * @param shreddedType the {@code typed_value} type; must pass {@code checkShreddedType}
+   */
+  private static GroupType variant(
+      String name, int fieldId, Type.Repetition repetition, Type shreddedType) {
+    checkShreddedType(shreddedType);
+    return org.apache.parquet.schema.Types.buildGroup(repetition)
+        .id(fieldId)
+        .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION))
+        .required(PrimitiveType.PrimitiveTypeName.BINARY)
+        .named("metadata")
+        .optional(PrimitiveType.PrimitiveTypeName.BINARY)
+        .named("value")
+        .addField(shreddedType)
+        .named(name);
+  }
+
+  /** Returns an optional primitive named {@code typed_value} with no logical type annotation. */
+  private static Type shreddedPrimitive(PrimitiveType.PrimitiveTypeName primitive) {
+    return optional(primitive).named("typed_value");
+  }
+
+  /** Returns an optional primitive named {@code typed_value} carrying the given annotation. */
+  private static Type shreddedPrimitive(
+      PrimitiveType.PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) {
+    return optional(primitive).as(annotation).named("typed_value");
+  }
+
+  /**
+   * Groups the given field groups into an optional group named {@code typed_value}, after
+   * checking each one with {@code checkField}.
+   */
+  private static GroupType objectFields(GroupType... fields) {
+    for (int i = 0; i < fields.length; i++) {
+      checkField(fields[i]);
+    }
+
+    org.apache.parquet.schema.Types.GroupBuilder<GroupType> typedValue =
+        org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL);
+    return typedValue.addFields(fields).named("typed_value");
+  }
+
+  /**
+   * Builds a required group with the given name holding an optional binary {@code value} followed
+   * by the given {@code typed_value}, which is validated first.
+   */
+  private static GroupType field(String name, Type shreddedType) {
+    checkShreddedType(shreddedType);
+
+    Type untypedValue = optional(PrimitiveType.PrimitiveTypeName.BINARY).named("value");
+    return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.REQUIRED)
+        .addField(untypedValue)
+        .addField(shreddedType)
+        .named(name);
+  }
+
+  /** Wraps the shredded type in a field group named {@code element}. */
+  private static GroupType element(Type shreddedType) {
+    String elementName = "element";
+    return field(elementName, shreddedType);
+  }
+
+  /** Builds an optional Parquet list named {@code typed_value} with the given element group. */
+  private static GroupType list(GroupType elementType) {
+    return org.apache.parquet.schema.Types.optionalList().element(elementType).named("typed_value");
+  }
+
+  /** Rejects a shredded type unless it is named {@code typed_value} and is OPTIONAL. */
+  private static void checkShreddedType(Type shreddedType) {
+    String typeName = shreddedType.getName();
+    boolean namedTypedValue = typeName.equals("typed_value");
+    Preconditions.checkArgument(
+        namedTypedValue, "Invalid shredded type name: %s should be typed_value", typeName);
+
+    boolean isOptional = shreddedType.isRepetition(Type.Repetition.OPTIONAL);
+    Preconditions.checkArgument(
+        isOptional,
+        "Invalid shredded type repetition: %s should be OPTIONAL",
+        shreddedType.getRepetition());
+  }
+
+  /** Rejects a field group unless its repetition is REQUIRED. */
+  private static void checkField(GroupType fieldType) {
+    boolean isRequired = fieldType.isRepetition(Type.Repetition.REQUIRED);
+    Type.Repetition actual = fieldType.getRepetition();
+    Preconditions.checkArgument(
+        isRequired, "Invalid field type repetition: %s should be REQUIRED", actual);
+  }
+}

Reply via email to