This is an automated email from the ASF dual-hosted git repository.
pvary pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/main by this push:
new 9ec1b933ed Spark: Backport support writing shredded variant in Iceberg-Spark (#16241)
9ec1b933ed is described below
commit 9ec1b933ed1d9253c6019773a624ba9b3d4d3c0a
Author: pvary <[email protected]>
AuthorDate: Thu May 7 16:48:58 2026 +0200
Spark: Backport support writing shredded variant in Iceberg-Spark (#16241)
backports #14297
---
.../apache/iceberg/spark/SparkSQLProperties.java | 8 +
.../org/apache/iceberg/spark/SparkWriteConf.java | 30 +
.../apache/iceberg/spark/SparkWriteOptions.java | 6 +
.../iceberg/spark/source/SparkFormatModels.java | 4 +-
.../source/SparkVariantShreddingAnalyzer.java | 69 ++
.../apache/iceberg/spark/TestSparkWriteConf.java | 85 ++
.../spark/variant/TestVariantShredding.java | 1101 ++++++++++++++++++++
7 files changed, 1302 insertions(+), 1 deletion(-)
diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
index b5b8602145..336aadd73c 100644
--- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
+++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java
@@ -111,4 +111,12 @@ public class SparkSQLProperties {
// Prefix for custom snapshot properties
public static final String SNAPSHOT_PROPERTY_PREFIX =
"spark.sql.iceberg.snapshot-property.";
+
+ // Controls whether to shred variant columns during write operations
+ public static final String SHRED_VARIANTS =
"spark.sql.iceberg.shred-variants";
+
+ // Controls the buffer size for variant schema inference during writes
+ // This determines how many rows are buffered before inferring shredded schema
+ public static final String VARIANT_INFERENCE_BUFFER_SIZE =
+ "spark.sql.iceberg.variant-inference-buffer-size";
}
diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
index aba7e4dda0..add12e6040 100644
--- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
+++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
@@ -33,6 +33,8 @@ import static
org.apache.iceberg.TableProperties.ORC_COMPRESSION;
import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY;
import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION;
import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL;
+import static org.apache.iceberg.TableProperties.PARQUET_SHRED_VARIANTS;
+import static org.apache.iceberg.TableProperties.PARQUET_VARIANT_BUFFER_SIZE;
import static
org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE;
import java.util.Locale;
@@ -529,6 +531,14 @@ public class SparkWriteConf {
if (parquetCompressionLevel != null) {
writeProperties.put(PARQUET_COMPRESSION_LEVEL,
parquetCompressionLevel);
}
+ boolean shouldShredVariants = shredVariants();
+ writeProperties.put(PARQUET_SHRED_VARIANTS,
String.valueOf(shouldShredVariants));
+
+ // Add variant shredding configuration properties
+ if (shouldShredVariants) {
+ writeProperties.put(
+ PARQUET_VARIANT_BUFFER_SIZE,
String.valueOf(variantInferenceBufferSize()));
+ }
break;
case AVRO:
@@ -749,4 +759,24 @@ public class SparkWriteConf {
.defaultValue(DeleteGranularity.FILE)
.parse();
}
+
+ public boolean shredVariants() {
+ return confParser
+ .booleanConf()
+ .option(SparkWriteOptions.SHRED_VARIANTS)
+ .sessionConf(SparkSQLProperties.SHRED_VARIANTS)
+ .tableProperty(TableProperties.PARQUET_SHRED_VARIANTS)
+ .defaultValue(TableProperties.PARQUET_SHRED_VARIANTS_DEFAULT)
+ .parse();
+ }
+
+ public int variantInferenceBufferSize() {
+ return confParser
+ .intConf()
+ .option(SparkWriteOptions.VARIANT_INFERENCE_BUFFER_SIZE)
+ .sessionConf(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE)
+ .tableProperty(TableProperties.PARQUET_VARIANT_BUFFER_SIZE)
+ .defaultValue(TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT)
+ .parse();
+ }
}
diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
index 1be02feaf0..6c76b5c873 100644
--- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
+++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
@@ -86,4 +86,10 @@ public class SparkWriteOptions {
// Overrides the delete granularity
public static final String DELETE_GRANULARITY = "delete-granularity";
+
+ // Controls whether to shred variant columns during write operations
+ public static final String SHRED_VARIANTS = "shred-variants";
+
+ // Controls the buffer size for variant schema inference during writes
+ public static final String VARIANT_INFERENCE_BUFFER_SIZE =
"variant-inference-buffer-size";
}
diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java
index 23fbe54a4b..5b7862116a 100644
--- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java
+++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFormatModels.java
@@ -51,7 +51,9 @@ public class SparkFormatModels {
StructType.class,
SparkParquetWriters::buildWriter,
(icebergSchema, fileSchema, engineSchema, idToConstant) ->
- SparkParquetReaders.buildReader(icebergSchema, fileSchema,
idToConstant)));
+ SparkParquetReaders.buildReader(icebergSchema, fileSchema,
idToConstant),
+ new SparkVariantShreddingAnalyzer(),
+ InternalRow::copy));
FormatModelRegistry.register(
ParquetFormatModel.create(
diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java
new file mode 100644
index 0000000000..2c08c662c9
--- /dev/null
+++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkVariantShreddingAnalyzer.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.List;
+import org.apache.iceberg.parquet.VariantShreddingAnalyzer;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.variants.VariantMetadata;
+import org.apache.iceberg.variants.VariantValue;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.types.VariantVal;
+
+/**
+ * Spark-specific implementation that extracts variant values from {@link InternalRow} instances.
+ */
+class SparkVariantShreddingAnalyzer extends
VariantShreddingAnalyzer<InternalRow, StructType> {
+
+ SparkVariantShreddingAnalyzer() {}
+
+ @Override
+ protected int resolveColumnIndex(StructType sparkSchema, String columnName) {
+ try {
+ return sparkSchema.fieldIndex(columnName);
+ } catch (IllegalArgumentException e) {
+ return -1;
+ }
+ }
+
+ @Override
+ protected List<VariantValue> extractVariantValues(
+ List<InternalRow> bufferedRows, int variantFieldIndex) {
+ List<VariantValue> values = Lists.newArrayList();
+
+ for (InternalRow row : bufferedRows) {
+ if (!row.isNullAt(variantFieldIndex)) {
+ VariantVal variantVal = row.getVariant(variantFieldIndex);
+ if (variantVal != null) {
+ VariantValue variantValue =
+ VariantValue.from(
+ VariantMetadata.from(
+
ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)),
+
ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN));
+ values.add(variantValue);
+ }
+ }
+ }
+
+ return values;
+ }
+}
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java
index c83b1b6e26..c5cfbe62b1 100644
--- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java
@@ -34,6 +34,7 @@ import static
org.apache.iceberg.TableProperties.ORC_COMPRESSION;
import static org.apache.iceberg.TableProperties.ORC_COMPRESSION_STRATEGY;
import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION;
import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION_LEVEL;
+import static org.apache.iceberg.TableProperties.PARQUET_SHRED_VARIANTS;
import static org.apache.iceberg.TableProperties.UPDATE_DISTRIBUTION_MODE;
import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE;
import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_HASH;
@@ -61,6 +62,7 @@ import org.apache.iceberg.deletes.DeleteGranularity;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.spark.sql.internal.SQLConf;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.TestTemplate;
@@ -340,6 +342,8 @@ public class TestSparkWriteConf extends TestBaseWithCatalog
{
TableProperties.DELETE_PARQUET_COMPRESSION,
"snappy"),
ImmutableMap.of(
+ PARQUET_SHRED_VARIANTS,
+ "false",
DELETE_PARQUET_COMPRESSION,
"zstd",
PARQUET_COMPRESSION,
@@ -461,6 +465,8 @@ public class TestSparkWriteConf extends TestBaseWithCatalog
{
PARQUET_COMPRESSION_LEVEL,
"5"),
ImmutableMap.of(
+ PARQUET_SHRED_VARIANTS,
+ "false",
DELETE_PARQUET_COMPRESSION,
"zstd",
PARQUET_COMPRESSION,
@@ -532,6 +538,8 @@ public class TestSparkWriteConf extends TestBaseWithCatalog
{
DELETE_PARQUET_COMPRESSION_LEVEL,
"6"),
ImmutableMap.of(
+ PARQUET_SHRED_VARIANTS,
+ "false",
DELETE_PARQUET_COMPRESSION,
"zstd",
PARQUET_COMPRESSION,
@@ -686,4 +694,81 @@ public class TestSparkWriteConf extends
TestBaseWithCatalog {
assertThat(writeConf.copyOnWriteDistributionMode(MERGE)).isEqualTo(expectedMode);
assertThat(writeConf.positionDeltaDistributionMode(MERGE)).isEqualTo(expectedMode);
}
+
+ @TestTemplate
+ public void testShredVariantsDefault() {
+ Table table = validationCatalog.loadTable(tableIdent);
+ SparkWriteConf writeConf = new SparkWriteConf(spark, table,
ImmutableMap.of());
+ assertThat(writeConf.shredVariants()).isFalse();
+ }
+
+ @TestTemplate
+ public void testVariantInferenceBufferSizeDefault() {
+ Table table = validationCatalog.loadTable(tableIdent);
+ SparkWriteConf writeConf = new SparkWriteConf(spark, table,
ImmutableMap.of());
+ assertThat(writeConf.variantInferenceBufferSize())
+ .isEqualTo(TableProperties.PARQUET_VARIANT_BUFFER_SIZE_DEFAULT);
+ }
+
+ @TestTemplate
+ public void testVariantInferenceBufferSizeTableProperty() {
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ table.updateProperties().set(TableProperties.PARQUET_VARIANT_BUFFER_SIZE,
"500").commit();
+
+ SparkWriteConf writeConf = new SparkWriteConf(spark, table,
ImmutableMap.of());
+ assertThat(writeConf.variantInferenceBufferSize()).isEqualTo(500);
+ }
+
+ @TestTemplate
+ public void testShredVariantsSessionOverridesTableProperty() {
+ Table table = validationCatalog.loadTable(tableIdent);
+ table.updateProperties().set(TableProperties.PARQUET_SHRED_VARIANTS,
"false").commit();
+
+ withSQLConf(
+ ImmutableMap.of(SparkSQLProperties.SHRED_VARIANTS, "true"),
+ () -> {
+ SparkWriteConf writeConf = new SparkWriteConf(spark, table,
ImmutableMap.of());
+ assertThat(writeConf.shredVariants()).isTrue();
+ });
+ }
+
+ @TestTemplate
+ public void testShredVariantsWriteOptionOverridesSessionConf() {
+ withSQLConf(
+ ImmutableMap.of(SparkSQLProperties.SHRED_VARIANTS, "false"),
+ () -> {
+ Table table = validationCatalog.loadTable(tableIdent);
+ SparkWriteConf writeConf =
+ new SparkWriteConf(
+ spark,
+ table,
+ new CaseInsensitiveStringMap(
+ ImmutableMap.of(SparkWriteOptions.SHRED_VARIANTS,
"true")));
+ assertThat(writeConf.shredVariants()).isTrue();
+ });
+ }
+
+ @TestTemplate
+ public void testVariantInferenceBufferSizeSessionConf() {
+ withSQLConf(
+ ImmutableMap.of(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE,
"250"),
+ () -> {
+ Table table = validationCatalog.loadTable(tableIdent);
+ SparkWriteConf writeConf = new SparkWriteConf(spark, table,
ImmutableMap.of());
+ assertThat(writeConf.variantInferenceBufferSize()).isEqualTo(250);
+ });
+ }
+
+ @TestTemplate
+ public void testWritePropertiesIncludeVariantShredding() {
+ Table table = validationCatalog.loadTable(tableIdent);
+ table.updateProperties().set(TableProperties.PARQUET_SHRED_VARIANTS,
"true").commit();
+ table.updateProperties().set(TableProperties.PARQUET_VARIANT_BUFFER_SIZE,
"200").commit();
+
+ SparkWriteConf writeConf = new SparkWriteConf(spark, table,
ImmutableMap.of());
+ Map<String, String> writeProperties = writeConf.writeProperties();
+ assertThat(writeProperties).containsEntry(PARQUET_SHRED_VARIANTS, "true");
+
assertThat(writeProperties).containsEntry(TableProperties.PARQUET_VARIANT_BUFFER_SIZE,
"200");
+ }
}
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
new file mode 100644
index 0000000000..8cdcf22e58
--- /dev/null
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
@@ -0,0 +1,1101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.variant;
+
+import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS;
+import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES;
+import static org.apache.parquet.schema.Types.optional;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.net.InetAddress;
+import java.util.List;
+import java.util.Map;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.Parameters;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.SparkSQLProperties;
+import org.apache.iceberg.types.Types;
+import org.apache.iceberg.variants.Variant;
+import org.apache.parquet.hadoop.ParquetFileReader;
+import org.apache.parquet.hadoop.util.HadoopInputFile;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
+import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.internal.SQLConf;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+
+public class TestVariantShredding extends CatalogTestBase {
+
+ private static final Schema SCHEMA =
+ new Schema(
+ Types.NestedField.required(1, "id", Types.IntegerType.get()),
+ Types.NestedField.optional(2, "address", Types.VariantType.get()));
+
+ private static final Schema SCHEMA2 =
+ new Schema(
+ Types.NestedField.required(1, "id", Types.IntegerType.get()),
+ Types.NestedField.optional(2, "address", Types.VariantType.get()),
+ Types.NestedField.optional(3, "metadata", Types.VariantType.get()));
+
+ @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+ protected static Object[][] parameters() {
+ return new Object[][] {
+ {
+ SparkCatalogConfig.HADOOP.catalogName(),
+ SparkCatalogConfig.HADOOP.implementation(),
+ SparkCatalogConfig.HADOOP.properties()
+ },
+ };
+ }
+
+ @BeforeAll
+ public static void startMetastoreAndSpark() {
+ // First call parent to initialize metastore and spark with local[2]
+ CatalogTestBase.startMetastoreAndSpark();
+
+ // Now stop and recreate spark with local[1] to write all rows to a single file
+ if (spark != null) {
+ spark.stop();
+ }
+
+ spark =
+ SparkSession.builder()
+ .master("local[1]") // Use one thread to write the rows to a single parquet file
+ .config("spark.driver.host",
InetAddress.getLoopbackAddress().getHostAddress())
+ .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic")
+ .config("spark.hadoop." + METASTOREURIS.varname,
hiveConf.get(METASTOREURIS.varname))
+
.config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true")
+ .config(DISABLE_UI)
+ .enableHiveSupport()
+ .getOrCreate();
+
+ sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
+ }
+
+ @BeforeEach
+ public void before() {
+ super.before();
+ validationCatalog.createTable(
+ tableIdent, SCHEMA, null, Map.of(TableProperties.FORMAT_VERSION, "3"));
+ }
+
+ @AfterEach
+ public void after() {
+ spark.conf().unset(SparkSQLProperties.SHRED_VARIANTS);
+ spark.conf().unset(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE);
+ validationCatalog.dropTable(tableIdent, true);
+ }
+
+ @TestTemplate
+ public void testVariantShreddingDisabled() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false");
+
+ String values = "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}')),
(2, null)";
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType address = variant("address", 2, Type.Repetition.OPTIONAL);
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testExcludingNullValue() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ String values =
+ """
+ (1, parse_json('{"name": "Alice", "age": 30, "dummy": null}')),\
+ (2, parse_json('{"name": "Bob", "age": 25}')),\
+ (3, parse_json('{"name": "Charlie", "age": 35}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType name =
+ field(
+ "name",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType age =
+ field(
+ "age",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.intType(8, true)));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(age, name));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testInconsistentType() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ String values =
+ """
+ (1, parse_json('{"age": "25"}')),\
+ (2, parse_json('{"age": 30}')),\
+ (3, parse_json('{"age": "35"}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType age =
+ field(
+ "age",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(age));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+
+ List<Object[]> rows =
+ sql("SELECT variant_get(address, '$.age', 'int') FROM %s WHERE id =
2", tableName);
+ assertThat(rows).hasSize(1);
+ assertThat(rows.get(0)[0]).isEqualTo(30);
+ }
+
+ @TestTemplate
+ public void testPrimitiveType() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ String values = "(1, parse_json('123')), (2, parse_json('456')), (3,
parse_json('789'))";
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType address =
+ variant(
+ "address",
+ 2,
+ Type.Repetition.REQUIRED,
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.intType(16, true)));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testPrimitiveDecimalType() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ String values =
+ "(1, parse_json('123.56')), (2, parse_json('\"abc\"')), (3,
parse_json('12.56'))";
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType address =
+ variant(
+ "address",
+ 2,
+ Type.Repetition.REQUIRED,
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.decimalType(2, 5)));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testBooleanType() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ String values =
+ """
+ (1, parse_json('{"active": true}')),\
+ (2, parse_json('{"active": false}')),\
+ (3, parse_json('{"active": true}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType active = field("active",
shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(active));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testDecimalTypeWithInconsistentScales() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ String values =
+ """
+ (1, parse_json('{"price": 123.456789}')),\
+ (2, parse_json('{"price": 678.90}')),\
+ (3, parse_json('{"price": 999.99}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType price =
+ field(
+ "price",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.decimalType(6, 9)));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(price));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testDecimalTypeWithConsistentScales() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ String values =
+ """
+ (1, parse_json('{"price": 123.45}')),\
+ (2, parse_json('{"price": 678.90}')),\
+ (3, parse_json('{"price": 999.99}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType price =
+ field(
+ "price",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.decimalType(2, 5)));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(price));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testArrayType() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ String values =
+ """
+ (1, parse_json('["java", "scala", "python"]')),\
+ (2, parse_json('["rust", "go"]')),\
+ (3, parse_json('["javascript"]'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType arr =
+ list(
+ element(
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType())));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED, arr);
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testNestedArrayType() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ String values =
+ """
+ (1, parse_json('{"tags": ["java", "scala", "python"]}')),\
+ (2, parse_json('{"tags": ["rust", "go"]}')),\
+ (3, parse_json('{"tags": ["javascript"]}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType tags =
+ field(
+ "tags",
+ list(
+ element(
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
+ LogicalTypeAnnotation.stringType()))));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(tags));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testNestedObjectType() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ String values =
+ """
+ (1, parse_json('{"location": {"city": "Seattle", "zip": 98101},
"tags": ["java", "scala", "python"]}')),\
+ (2, parse_json('{"location": {"city": "Portland", "zip":
97201}}')),\
+ (3, parse_json('{"location": {"city": "NYC", "zip": 10001}}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType city =
+ field(
+ "city",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType zip =
+ field(
+ "zip",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.intType(32, true)));
+ GroupType location = field("location", objectFields(city, zip));
+ GroupType tags =
+ field(
+ "tags",
+ list(
+ element(
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
+ LogicalTypeAnnotation.stringType()))));
+
+ GroupType address =
+ variant("address", 2, Type.Repetition.REQUIRED, objectFields(location,
tags));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testLazyInitializationWithBufferedRows() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+ spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "5");
+
+ String values =
+ """
+ (1, parse_json('{"name": "Alice", "age": 30}')),\
+ (2, parse_json('{"name": "Bob", "age": 25}')),\
+ (3, parse_json('{"name": "Charlie", "age": 35}')),\
+ (4, parse_json('{"name": "David", "age": 28}')),\
+ (5, parse_json('{"name": "Eve", "age": 32}')),\
+ (6, parse_json('{"name": "Frank", "age": 40}')),\
+ (7, parse_json('{"name": "Grace", "age": 27}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ GroupType name =
+ field(
+ "name",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType age =
+ field(
+ "age",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.intType(8, true)));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(age, name));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+
+ long rowCount = spark.read().format("iceberg").load(tableName).count();
+ assertThat(rowCount).isEqualTo(7);
+ }
+
+ @TestTemplate
+ public void testMultipleRowGroups() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+ spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3");
+
+ int numRows = 1000;
+ StringBuilder valuesBuilder = new StringBuilder();
+ for (int i = 1; i <= numRows; i++) {
+ if (i > 1) {
+ valuesBuilder.append(", ");
+ }
+ valuesBuilder.append(
+ String.format("(%d, parse_json('{\"name\": \"User%d\", \"age\":
%d}'))", i, i, 20 + i));
+ }
+ sql(
+ "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')",
+ tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 1024);
+ sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString());
+
+ GroupType name =
+ field(
+ "name",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType age =
+ field(
+ "age",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.intType(8, true)));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(age, name));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+
+ long rowCount = spark.read().format("iceberg").load(tableName).count();
+ assertThat(rowCount).isEqualTo(numRows);
+ }
+
+ @TestTemplate
+ public void testColumnIndexTruncateLength() throws IOException {
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+ spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3");
+
+ int customTruncateLength = 10;
+ sql(
+ "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')",
+ tableName, "parquet.columnindex.truncate.length",
customTruncateLength);
+
+ StringBuilder valuesBuilder = new StringBuilder();
+ for (int i = 1; i <= 10; i++) {
+ if (i > 1) {
+ valuesBuilder.append(", ");
+ }
+ String longValue = "A".repeat(20);
+ valuesBuilder.append(
+ String.format(
+ "(%d, parse_json('{\"description\": \"%s\", \"id\": %d}'))", i,
longValue, i));
+ }
+ sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString());
+
+ GroupType description =
+ field(
+ "description",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType id =
+ field(
+ "id",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.intType(8, true)));
+ GroupType address =
+ variant("address", 2, Type.Repetition.REQUIRED,
objectFields(description, id));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+
+ long rowCount = spark.read().format("iceberg").load(tableName).count();
+ assertThat(rowCount).isEqualTo(10);
+ }
+
+  @TestTemplate
+  public void testIntegerFamilyPromotion() throws IOException {
+    // One field holding integers of growing width; the expected shredded column is INT64.
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    // Mix of INT8, INT16, INT32, INT64 - should promote to INT64
+    String insertedRows =
+        """
+        (1, parse_json('{"value": 10}')),\
+        (2, parse_json('{"value": 1000}')),\
+        (3, parse_json('{"value": 100000}')),\
+        (4, parse_json('{"value": 10000000000}'))\
+        """;
+    sql("INSERT INTO %s VALUES %s", tableName, insertedRows);
+
+    Type promotedInt =
+        shreddedPrimitive(
+            PrimitiveType.PrimitiveTypeName.INT64, LogicalTypeAnnotation.intType(64, true));
+    GroupType shreddedAddress =
+        variant(
+            "address", 2, Type.Repetition.REQUIRED, objectFields(field("value", promotedInt)));
+    MessageType expectedSchema = parquetSchema(shreddedAddress);
+
+    verifyParquetSchema(validationCatalog.loadTable(tableIdent), expectedSchema);
+  }
+
+  @TestTemplate
+  public void testDecimalFamilyPromotion() throws IOException {
+    // Decimals of different precision and scale in one field should share one decimal column.
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    // Test that they get promoted to the most capable decimal type observed
+    String insertedRows =
+        """
+        (1, parse_json('{"value": 1.5}')),\
+        (2, parse_json('{"value": 123.456789}')),\
+        (3, parse_json('{"value": 123456789123456.789}'))\
+        """;
+    sql("INSERT INTO %s VALUES %s", tableName, insertedRows);
+
+    // Parquet's decimalType takes (scale, precision), so this is scale 6 and precision 21
+    Type promotedDecimal =
+        optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY)
+            .length(16)
+            .as(LogicalTypeAnnotation.decimalType(6, 21))
+            .named("typed_value");
+    GroupType shreddedAddress =
+        variant(
+            "address",
+            2,
+            Type.Repetition.REQUIRED,
+            objectFields(field("value", promotedDecimal)));
+    MessageType expectedSchema = parquetSchema(shreddedAddress);
+
+    verifyParquetSchema(validationCatalog.loadTable(tableIdent), expectedSchema);
+  }
+
+  @TestTemplate
+  public void testDataRoundTripWithShredding() throws IOException {
+    // Shreds "name" and "age", then reads both fields back through variant_get.
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String insertedRows =
+        """
+        (1, parse_json('{"name": "Alice", "age": 30}')),\
+        (2, parse_json('{"name": "Bob", "age": 25}')),\
+        (3, parse_json('{"name": "Charlie", "age": 35}'))\
+        """;
+    sql("INSERT INTO %s VALUES %s", tableName, insertedRows);
+
+    Type shreddedString =
+        shreddedPrimitive(
+            PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType());
+    Type shreddedInt8 =
+        shreddedPrimitive(
+            PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true));
+    GroupType nameField = field("name", shreddedString);
+    GroupType ageField = field("age", shreddedInt8);
+    GroupType shreddedAddress =
+        variant("address", 2, Type.Repetition.REQUIRED, objectFields(ageField, nameField));
+    MessageType expectedSchema = parquetSchema(shreddedAddress);
+
+    verifyParquetSchema(validationCatalog.loadTable(tableIdent), expectedSchema);
+
+    // Verify that we can read the data back correctly
+    List<Object[]> rows =
+        sql(
+            "SELECT id, variant_get(address, '$.name', 'string'),"
+                + " variant_get(address, '$.age', 'int')"
+                + " FROM %s ORDER BY id",
+            tableName);
+    assertThat(rows).hasSize(3);
+    // Each row is (id, name, age) in id order
+    assertThat(rows.get(0)).containsExactly(1, "Alice", 30);
+    assertThat(rows.get(1)).containsExactly(2, "Bob", 25);
+    assertThat(rows.get(2)).containsExactly(3, "Charlie", 35);
+  }
+
+ @TestTemplate
+ public void testMultipleVariantsWithShredding() throws IOException {
+ // Two variant columns in one table: each is expected to get its own shredded string field.
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+ // Recreate table with SCHEMA2 (address + metadata variant columns)
+ validationCatalog.dropTable(tableIdent, true);
+ validationCatalog.createTable(
+ tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION,
"3"));
+
+ // Each row carries one object for "address" and one for "metadata"
+ String values =
+ """
+ (1, parse_json('{"city": "NYC"}'), parse_json('{"source":
"web"}')),\
+ (2, parse_json('{"city": "LA"}'), parse_json('{"source":
"app"}')),\
+ (3, parse_json('{"city": "SF"}'), parse_json('{"source":
"api"}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ // Expected "address" column (field id 2) with a shredded string "city"
+ GroupType city =
+ field(
+ "city",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(city));
+
+ // Expected "metadata" column (field id 3) with a shredded string "source"
+ GroupType source =
+ field(
+ "source",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType metadata = variant("metadata", 3, Type.Repetition.REQUIRED,
objectFields(source));
+ MessageType expectedSchema = parquetSchema(address, metadata);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+  @TestTemplate
+  public void testVariantWithNullValues() throws IOException {
+    // Every row is parse_json('null'); the expected column has metadata and value only,
+    // with no typed_value.
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String insertedRows =
+        """
+        (1, parse_json('null')),\
+        (2, parse_json('null')),\
+        (3, parse_json('null'))\
+        """;
+    sql("INSERT INTO %s VALUES %s", tableName, insertedRows);
+
+    GroupType unshreddedAddress = variant("address", 2, Type.Repetition.REQUIRED);
+    MessageType expectedSchema = parquetSchema(unshreddedAddress);
+
+    verifyParquetSchema(validationCatalog.loadTable(tableIdent), expectedSchema);
+  }
+
+  @TestTemplate
+  public void testArrayOfNullElementsWithShredding() throws IOException {
+    // Arrays that contain nothing but null elements are expected to stay unshredded.
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    sql(
+        "INSERT INTO %s VALUES (1, parse_json('[null, null, null]')), "
+            + "(2, parse_json('[null]'))",
+        tableName);
+
+    // Array elements are all null, element type is null, falls back to unshredded
+    GroupType unshreddedAddress = variant("address", 2, Type.Repetition.REQUIRED);
+    MessageType expectedSchema = parquetSchema(unshreddedAddress);
+
+    verifyParquetSchema(validationCatalog.loadTable(tableIdent), expectedSchema);
+  }
+
+  @TestTemplate
+  public void testMixedNullAndNonNullVariantValues() throws IOException {
+    // Row 2 is a SQL NULL between two objects; the expected variant group is OPTIONAL and
+    // still shreds "name" and "age".
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String insertedRows =
+        """
+        (1, parse_json('{"name": "Alice", "age": 30}')),\
+        (2, null),\
+        (3, parse_json('{"name": "Charlie", "age": 35}'))\
+        """;
+    sql("INSERT INTO %s VALUES %s", tableName, insertedRows);
+
+    Type shreddedString =
+        shreddedPrimitive(
+            PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType());
+    Type shreddedInt8 =
+        shreddedPrimitive(
+            PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true));
+    GroupType nameField = field("name", shreddedString);
+    GroupType ageField = field("age", shreddedInt8);
+    GroupType shreddedAddress =
+        variant("address", 2, Type.Repetition.OPTIONAL, objectFields(ageField, nameField));
+    MessageType expectedSchema = parquetSchema(shreddedAddress);
+
+    verifyParquetSchema(validationCatalog.loadTable(tableIdent), expectedSchema);
+
+    // The NULL row must still be counted
+    long writtenRows = spark.read().format("iceberg").load(tableName).count();
+    assertThat(writtenRows).isEqualTo(3);
+  }
+
+ @TestTemplate
+ public void testWriteOptionOverridesSessionConfig() throws IOException,
NoSuchTableException {
+ // Shredding is off in the session but on in the write option; the file is expected to be
+ // shredded, so the per-write option must win.
+ // Disable shredding at session level
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false");
+
+ // Enable shredding via per-write option
+ String query =
+ "SELECT 1 as id, parse_json('{\"name\": \"Alice\", \"age\": 30}') as
address"
+ + " UNION ALL SELECT 2, parse_json('{\"name\": \"Bob\", \"age\":
25}')"
+ + " UNION ALL SELECT 3, parse_json('{\"name\": \"Charlie\",
\"age\": 35}')";
+ // NOTE(review): the option key is a literal here; presumably it matches the constant added
+ // to SparkWriteOptions in this commit - confirm.
+ spark.sql(query).writeTo(tableName).option("shred-variants",
"true").append();
+
+ // Expected schema is the same shredded layout the session-level setting produces elsewhere
+ GroupType name =
+ field(
+ "name",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType age =
+ field(
+ "age",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.intType(8, true)));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(age, name));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testInfrequentFieldPruning() throws IOException {
+ // A field present in only one of eleven buffered rows is expected to be left out of the
+ // shredded schema.
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+ // Buffer size equals the number of inserted rows
+ spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "11");
+
+ // Build the VALUES list: "(1, ...), (2, ...), ..."
+ StringBuilder valuesBuilder = new StringBuilder();
+ for (int i = 1; i <= 11; i++) {
+ if (i > 1) {
+ valuesBuilder.append(", ");
+ }
+ if (i == 1) {
+ // Only the first row has rare_field
+ valuesBuilder.append(
+ String.format(
+ "(%d, parse_json('{\"name\": \"User%d\", \"rare_field\":
\"rare\"}'))", i, i));
+ } else {
+ valuesBuilder.append(String.format("(%d, parse_json('{\"name\":
\"User%d\"}'))", i, i));
+ }
+ }
+ sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString());
+
+ // rare_field appears in 1/11 rows, should be pruned
+ // name appears in 11/11 rows and should be kept
+ GroupType name =
+ field(
+ "name",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(name));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+ }
+
+ @TestTemplate
+ public void testMixedTypeTieBreaking() throws IOException {
+ // Field "val" is an integer in rows 1-5 and a string in rows 6-10; the expected shredded
+ // type is string, and all ten values must still read back as strings.
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+ // Buffer size equals the number of inserted rows
+ spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "10");
+
+ StringBuilder valuesBuilder = new StringBuilder();
+ for (int i = 1; i <= 10; i++) {
+ if (i > 1) {
+ valuesBuilder.append(", ");
+ }
+ if (i <= 5) {
+ valuesBuilder.append(String.format("(%d, parse_json('{\"val\":
%d}'))", i, i));
+ } else {
+ valuesBuilder.append(String.format("(%d, parse_json('{\"val\":
\"text%d\"}'))", i, i));
+ }
+ }
+ sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString());
+
+ // 5 ints + 5 strings is a tie so STRING wins (higher TIE_BREAK_PRIORITY)
+ GroupType val =
+ field(
+ "val",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(val));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+
+ // Verify data round-trips correctly
+ List<Object[]> rows =
+ sql("SELECT id, variant_get(address, '$.val', 'string') FROM %s ORDER
BY id", tableName);
+ assertThat(rows).hasSize(10);
+ // Row index 0 is an integer value read as a string; row index 5 is the first string value
+ assertThat(rows.get(0)[1]).isEqualTo("1");
+ assertThat(rows.get(5)[1]).isEqualTo("text6");
+ }
+
+  @TestTemplate
+  public void testFieldOnlyAfterBuffer() throws IOException {
+    // "score" first appears in row 4, after the three rows that fill the inference buffer.
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+    spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3");
+
+    String insertedRows =
+        """
+        (1, parse_json('{"name": "Alice"}')),\
+        (2, parse_json('{"name": "Bob"}')),\
+        (3, parse_json('{"name": "Charlie"}')),\
+        (4, parse_json('{"name": "David", "score": 95}')),\
+        (5, parse_json('{"name": "Eve", "score": 88}')),\
+        (6, parse_json('{"name": "Frank", "score": 72}')),\
+        (7, parse_json('{"name": "Grace", "score": 91}'))\
+        """;
+    sql("INSERT INTO %s VALUES %s", tableName, insertedRows);
+
+    // Schema is determined from buffer (rows 1-3) which only has "name".
+    // "score" is not shredded
+    Type shreddedString =
+        shreddedPrimitive(
+            PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType());
+    GroupType shreddedAddress =
+        variant(
+            "address", 2, Type.Repetition.REQUIRED, objectFields(field("name", shreddedString)));
+    MessageType expectedSchema = parquetSchema(shreddedAddress);
+
+    verifyParquetSchema(validationCatalog.loadTable(tableIdent), expectedSchema);
+
+    // Verify all data round-trips despite "score" not being shredded
+    List<Object[]> rows =
+        sql(
+            "SELECT id, variant_get(address, '$.name', 'string'),"
+                + " variant_get(address, '$.score', 'int')"
+                + " FROM %s ORDER BY id",
+            tableName);
+    assertThat(rows).hasSize(7);
+
+    // A buffered row without "score"
+    Object[] firstRow = rows.get(0);
+    assertThat(firstRow[1]).isEqualTo("Alice");
+    assertThat(firstRow[2]).isNull();
+
+    // The first row written after the buffer
+    Object[] fourthRow = rows.get(3);
+    assertThat(fourthRow[1]).isEqualTo("David");
+    assertThat(fourthRow[2]).isEqualTo(95);
+
+    // The last row
+    Object[] seventhRow = rows.get(6);
+    assertThat(seventhRow[1]).isEqualTo("Grace");
+    assertThat(seventhRow[2]).isEqualTo(91);
+  }
+
+ @TestTemplate
+ public void testCrossFileDifferentShreddedType() throws IOException {
+ // Two inserts write "score" first as integers and then as strings; a single query over
+ // both must return every value as a string.
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+ spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3");
+
+ // File 1: "score" is always integer → shredded as INT8
+ String batch1 =
+ """
+ (1, parse_json('{"score": 95}')),\
+ (2, parse_json('{"score": 88}')),\
+ (3, parse_json('{"score": 72}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, batch1);
+
+ // Verify file 1 schema: score shredded as INT8
+ // The schema is checked before the second insert, while the table holds only batch 1
+ Table table = validationCatalog.loadTable(tableIdent);
+ GroupType scoreInt =
+ field(
+ "score",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.intType(8, true)));
+ MessageType expectedSchema1 =
+ parquetSchema(variant("address", 2, Type.Repetition.REQUIRED,
objectFields(scoreInt)));
+ verifyParquetSchema(table, expectedSchema1);
+
+ // File 2: "score" is always string → shredded as STRING
+ String batch2 =
+ """
+ (4, parse_json('{"score": "high"}')),\
+ (5, parse_json('{"score": "medium"}')),\
+ (6, parse_json('{"score": "low"}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, batch2);
+
+ // Query across both files, reader must handle different shredded types
+ // NOTE(review): the Parquet schema of the second file is not asserted here.
+ List<Object[]> rows =
+ sql("SELECT id, variant_get(address, '$.score', 'string') FROM %s
ORDER BY id", tableName);
+ assertThat(rows).hasSize(6);
+ // Integer scores from batch 1 come back as their string form
+ assertThat(rows.get(0)[1]).isEqualTo("95");
+ assertThat(rows.get(1)[1]).isEqualTo("88");
+ assertThat(rows.get(3)[1]).isEqualTo("high");
+ assertThat(rows.get(5)[1]).isEqualTo("low");
+ }
+
+  @TestTemplate
+  public void testAllNullVariantColumn() throws IOException {
+    // Only SQL NULLs are written to the variant column.
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    sql("INSERT INTO %s VALUES (1, null), (2, null), (3, null)", tableName);
+
+    // All variant values are SQL NULL, so no shredding should occur
+    Table table = validationCatalog.loadTable(tableIdent);
+    GroupType unshreddedAddress = variant("address", 2, Type.Repetition.OPTIONAL);
+    verifyParquetSchema(table, parquetSchema(unshreddedAddress));
+
+    List<Object[]> rows = sql("SELECT id, address FROM %s ORDER BY id", tableName);
+    assertThat(rows).hasSize(3);
+    // Every one of the three rows must read back a NULL address
+    for (Object[] row : rows) {
+      assertThat(row[1]).isNull();
+    }
+  }
+
+ @TestTemplate
+ public void testBufferSizeOne() throws IOException {
+ // Smallest buffer used in this class: the schema is inferred from a single row.
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+ spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "1");
+
+ // This text block has no line continuations, so the SQL statement keeps its line breaks
+ sql(
+ """
+ INSERT INTO %s VALUES
+ (1, parse_json('{"name": "Alice", "age": 30}')),
+ (2, parse_json('{"name": "Bob", "age": 25}')),
+ (3, parse_json('{"name": "Charlie", "age": 35}'))
+ """,
+ tableName);
+
+ // Schema inferred from first row only, should still shred name and age
+ GroupType age =
+ field(
+ "age",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.INT32,
LogicalTypeAnnotation.intType(8, true)));
+ GroupType name =
+ field(
+ "name",
+ shreddedPrimitive(
+ PrimitiveType.PrimitiveTypeName.BINARY,
LogicalTypeAnnotation.stringType()));
+ GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(age, name));
+ MessageType expectedSchema = parquetSchema(address);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+ verifyParquetSchema(table, expectedSchema);
+
+ // Spot-check the first and last rows; only "name" is read back
+ List<Object[]> rows =
+ sql("SELECT id, variant_get(address, '$.name', 'string') FROM %s ORDER
BY id", tableName);
+ assertThat(rows).hasSize(3);
+ assertThat(rows.get(0)[1]).isEqualTo("Alice");
+ assertThat(rows.get(2)[1]).isEqualTo("Charlie");
+ }
+
+ @TestTemplate
+ public void testDecimalFallbackAfterBuffer() throws IOException {
+ // Writes decimals after the buffer that do not match the buffered precision and scale,
+ // then checks the values read back. The Parquet schema is not asserted in this test.
+ spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+ spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3");
+
+ // Buffer: scale=2, 3 integer digits -> DECIMAL(5,2)
+ // Row 4: precision overflow -> fallback to value field
+ // Row 5: scale overflow -> fallback to value field
+ // Row 6: fits typed column, scale widened from 1 to 2 via setScale
+ String values =
+ """
+ (1, parse_json('{"val": 123.45}')),\
+ (2, parse_json('{"val": 678.90}')),\
+ (3, parse_json('{"val": 999.99}')),\
+ (4, parse_json('{"val": 123456.78}')),\
+ (5, parse_json('{"val": 1.2345}')),\
+ (6, parse_json('{"val": 12.3}'))\
+ """;
+ sql("INSERT INTO %s VALUES %s", tableName, values);
+
+ // Values are read as decimal(10,4), so the expected values below have four fractional digits
+ List<Object[]> rows =
+ sql(
+ "SELECT id, variant_get(address, '$.val', 'decimal(10,4)') FROM %s
ORDER BY id",
+ tableName);
+ assertThat(rows).hasSize(6);
+ // Index 0 is a buffered row; indexes 3, 4 and 5 are rows 4, 5 and 6 described above
+ assertThat(rows.get(0)[1]).isEqualTo(new BigDecimal("123.4500"));
+ assertThat(rows.get(3)[1]).isEqualTo(new BigDecimal("123456.7800"));
+ assertThat(rows.get(4)[1]).isEqualTo(new BigDecimal("1.2345"));
+ assertThat(rows.get(5)[1]).isEqualTo(new BigDecimal("12.3000"));
+ }
+
+  /**
+   * Asserts that the table has at least one data file and that every data file planned by a
+   * table scan has exactly the expected Parquet schema.
+   *
+   * @param table table whose current data files are checked
+   * @param expectedSchema Parquet schema each file footer must equal
+   * @throws IOException if planning the scan or opening a Parquet file fails
+   */
+  private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException {
+    // The Hadoop configuration does not depend on the file, so build it once for all files
+    Configuration conf = new Configuration();
+
+    try (CloseableIterable<FileScanTask> tasks = table.newScan().planFiles()) {
+      assertThat(tasks).isNotEmpty();
+
+      for (FileScanTask task : tasks) {
+        String path = task.file().location();
+        HadoopInputFile inputFile = HadoopInputFile.fromPath(new Path(path), conf);
+
+        // Only the footer metadata is needed; the reader is closed right after the check
+        try (ParquetFileReader reader = ParquetFileReader.open(inputFile)) {
+          MessageType actualSchema = reader.getFileMetaData().getSchema();
+          assertThat(actualSchema).isEqualTo(expectedSchema);
+        }
+      }
+    }
+  }
+
+  /**
+   * Builds the expected Parquet message named "table": a required INT32 "id" column with field
+   * id 1, followed by the given variant columns in order.
+   */
+  private static MessageType parquetSchema(Type... variantTypes) {
+    Type idColumn =
+        org.apache.parquet.schema.Types.required(PrimitiveType.PrimitiveTypeName.INT32)
+            .id(1)
+            .named("id");
+
+    return org.apache.parquet.schema.Types.buildMessage()
+        .addField(idColumn)
+        .addFields(variantTypes)
+        .named("table");
+  }
+
+  /**
+   * Builds an unshredded variant group: required binary "metadata" and required binary "value",
+   * annotated as a variant with the given field id and repetition.
+   */
+  private static GroupType variant(String name, int fieldId, Type.Repetition repetition) {
+    Type metadataColumn =
+        org.apache.parquet.schema.Types.required(PrimitiveType.PrimitiveTypeName.BINARY)
+            .named("metadata");
+    Type valueColumn =
+        org.apache.parquet.schema.Types.required(PrimitiveType.PrimitiveTypeName.BINARY)
+            .named("value");
+
+    return org.apache.parquet.schema.Types.buildGroup(repetition)
+        .id(fieldId)
+        .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION))
+        .addField(metadataColumn)
+        .addField(valueColumn)
+        .named(name);
+  }
+
+  /**
+   * Builds a shredded variant group: required binary "metadata", optional binary "value" and the
+   * given typed_value, annotated as a variant with the given field id and repetition.
+   *
+   * <p>The shredded type must be an OPTIONAL type named "typed_value".
+   */
+  private static GroupType variant(
+      String name, int fieldId, Type.Repetition repetition, Type shreddedType) {
+    checkShreddedType(shreddedType);
+
+    Type metadataColumn =
+        org.apache.parquet.schema.Types.required(PrimitiveType.PrimitiveTypeName.BINARY)
+            .named("metadata");
+    // "value" is optional here, unlike in the unshredded layout
+    Type valueColumn =
+        org.apache.parquet.schema.Types.optional(PrimitiveType.PrimitiveTypeName.BINARY)
+            .named("value");
+
+    return org.apache.parquet.schema.Types.buildGroup(repetition)
+        .id(fieldId)
+        .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION))
+        .addField(metadataColumn)
+        .addField(valueColumn)
+        .addField(shreddedType)
+        .named(name);
+  }
+
+  /** Builds an optional "typed_value" column of the given primitive type, with no annotation. */
+  private static Type shreddedPrimitive(PrimitiveType.PrimitiveTypeName primitive) {
+    Type typedValue = optional(primitive).named("typed_value");
+    return typedValue;
+  }
+
+  /** Builds an optional "typed_value" column of the given primitive type and annotation. */
+  private static Type shreddedPrimitive(
+      PrimitiveType.PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) {
+    Type typedValue = optional(primitive).as(annotation).named("typed_value");
+    return typedValue;
+  }
+
+  /**
+   * Builds the optional "typed_value" group of a shredded object from its field groups, kept in
+   * the order given. Every field group must be REQUIRED.
+   */
+  private static GroupType objectFields(GroupType... fields) {
+    // Validate all fields before building anything
+    for (int i = 0; i < fields.length; i++) {
+      checkField(fields[i]);
+    }
+
+    GroupType typedValue =
+        org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL)
+            .addFields(fields)
+            .named("typed_value");
+    return typedValue;
+  }
+
+  /**
+   * Builds the required group for one shredded object field: an optional binary "value" followed
+   * by the given typed_value. The shredded type must be an OPTIONAL type named "typed_value".
+   */
+  private static GroupType field(String name, Type shreddedType) {
+    checkShreddedType(shreddedType);
+
+    Type valueColumn =
+        org.apache.parquet.schema.Types.optional(PrimitiveType.PrimitiveTypeName.BINARY)
+            .named("value");
+
+    return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.REQUIRED)
+        .addField(valueColumn)
+        .addField(shreddedType)
+        .named(name);
+  }
+
+  /** Builds a shredded field group named "element" around the given typed_value. */
+  private static GroupType element(Type shreddedType) {
+    GroupType elementGroup = field("element", shreddedType);
+    return elementGroup;
+  }
+
+  /** Builds an optional list named "typed_value" whose element is the given group. */
+  private static GroupType list(GroupType elementType) {
+    GroupType typedValue =
+        org.apache.parquet.schema.Types.optionalList().element(elementType).named("typed_value");
+    return typedValue;
+  }
+
+  /**
+   * Checks that a shredded type is named "typed_value" and is OPTIONAL.
+   *
+   * <p>The name is checked first, so a type that breaks both rules reports the name.
+   */
+  private static void checkShreddedType(Type shreddedType) {
+    String typeName = shreddedType.getName();
+    boolean hasExpectedName = typeName.equals("typed_value");
+    Preconditions.checkArgument(
+        hasExpectedName, "Invalid shredded type name: %s should be typed_value", typeName);
+
+    boolean isOptional = shreddedType.isRepetition(Type.Repetition.OPTIONAL);
+    Preconditions.checkArgument(
+        isOptional,
+        "Invalid shredded type repetition: %s should be OPTIONAL",
+        shreddedType.getRepetition());
+  }
+
+  /** Checks that a shredded object field group is REQUIRED. */
+  private static void checkField(GroupType fieldType) {
+    boolean isRequired = fieldType.isRepetition(Type.Repetition.REQUIRED);
+    Preconditions.checkArgument(
+        isRequired,
+        "Invalid field type repetition: %s should be REQUIRED",
+        fieldType.getRepetition());
+  }
+}