This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/parquet-java.git
The following commit(s) were added to refs/heads/master by this push:
new 4aa2ea918 GH-3223: Implement Variant parquet writer (#3221)
4aa2ea918 is described below
commit 4aa2ea91863274aebb1eded243ce275912c16010
Author: David Cashman <[email protected]>
AuthorDate: Fri Jun 27 15:38:57 2025 -0400
GH-3223: Implement Variant parquet writer (#3221)
---
.../org/apache/parquet/avro/AvroWriteSupport.java | 79 ++-
.../java/org/apache/parquet/avro/AvroTestUtil.java | 33 ++
.../org/apache/parquet/avro/TestReadVariant.java | 30 +-
.../org/apache/parquet/avro/TestWriteVariant.java | 612 +++++++++++++++++++++
.../org/apache/parquet/variant/VariantBuilder.java | 2 +-
.../org/apache/parquet/variant/VariantUtil.java | 70 +++
.../apache/parquet/variant/VariantValueWriter.java | 375 +++++++++++++
7 files changed, 1168 insertions(+), 33 deletions(-)
diff --git
a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroWriteSupport.java
b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroWriteSupport.java
index 53fc3d59c..5e62e9821 100644
--- a/parquet-avro/src/main/java/org/apache/parquet/avro/AvroWriteSupport.java
+++ b/parquet-avro/src/main/java/org/apache/parquet/avro/AvroWriteSupport.java
@@ -42,9 +42,12 @@ import org.apache.parquet.hadoop.util.ConfigurationUtil;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
import
org.apache.parquet.schema.LogicalTypeAnnotation.UUIDLogicalTypeAnnotation;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type;
+import org.apache.parquet.variant.Variant;
+import org.apache.parquet.variant.VariantValueWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -181,9 +184,79 @@ public class AvroWriteSupport<T> extends WriteSupport<T> {
}
private void writeRecord(GroupType schema, Schema avroSchema, Object record)
{
- recordConsumer.startGroup();
- writeRecordFields(schema, avroSchema, record);
- recordConsumer.endGroup();
+ if (schema.getLogicalTypeAnnotation() instanceof
LogicalTypeAnnotation.VariantLogicalTypeAnnotation) {
+ writeVariantFields(schema, avroSchema, record);
+ } else {
+ recordConsumer.startGroup();
+ writeRecordFields(schema, avroSchema, record);
+ recordConsumer.endGroup();
+ }
+ }
+
+ // Return true if schema and avroSchema have the same field names, in the
same order.
+ private static boolean schemaMatches(GroupType schema, Schema avroSchema) {
+ List<Schema.Field> avroFields = avroSchema.getFields();
+ if (schema.getFieldCount() != avroFields.size()) {
+ return false;
+ }
+
+ for (int i = 0; i < avroFields.size(); i += 1) {
+ if (!avroFields.get(i).name().equals(schema.getFieldName(i))) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ private void writeVariantFields(GroupType schema, Schema avroSchema, Object
record) {
+ List<Type> fields = schema.getFields();
+ List<Schema.Field> avroFields = avroSchema.getFields();
+
+ if (schemaMatches(schema, avroSchema)) {
+ // If the Avro schema matches the Parquet schema, the shredding matches
and writeRecordFields can be used.
+ // writeRecordFields will validate that the field types match.
+ recordConsumer.startGroup();
+ writeRecordFields(schema, avroSchema, record);
+ recordConsumer.endGroup();
+ return;
+ }
+
+ boolean binarySchema = true;
+ ByteBuffer metadata = null;
+ ByteBuffer value = null;
+ // Extract the value and metadata binary.
+ for (int index = 0; index < avroFields.size(); index++) {
+ Schema.Field avroField = avroFields.get(index);
+ Schema fieldSchema = AvroSchemaConverter.getNonNull(avroField.schema());
+ if (!fieldSchema.getType().equals(Schema.Type.BYTES)) {
+ binarySchema = false;
+ break;
+ }
+ Type fieldType = fields.get(index);
+ if (fieldType.getName().equals("value")) {
+ Object valueObj = model.getField(record, avroField.name(), index);
+ Preconditions.checkArgument(
+ valueObj instanceof ByteBuffer,
+ "Expected ByteBuffer for value, but got " + valueObj.getClass());
+ value = (ByteBuffer) valueObj;
+ } else if (fieldType.getName().equals("metadata")) {
+ Object metadataObj = model.getField(record, avroField.name(), index);
+ Preconditions.checkArgument(
+ metadataObj instanceof ByteBuffer,
+ "Expected metadata to be a ByteBuffer, but got " +
metadataObj.getClass());
+ metadata = (ByteBuffer) metadataObj;
+ } else {
+ binarySchema = false;
+ break;
+ }
+ }
+
+ if (binarySchema) {
+ VariantValueWriter.write(recordConsumer, schema, new Variant(value,
metadata));
+ } else {
+ throw new RuntimeException("Invalid Avro schema for Variant logical
type: " + schema.getName());
+ }
}
private void writeRecordFields(GroupType schema, Schema avroSchema, Object
record) {
diff --git
a/parquet-avro/src/test/java/org/apache/parquet/avro/AvroTestUtil.java
b/parquet-avro/src/test/java/org/apache/parquet/avro/AvroTestUtil.java
index c71b6a774..514fb4748 100644
--- a/parquet-avro/src/test/java/org/apache/parquet/avro/AvroTestUtil.java
+++ b/parquet-avro/src/test/java/org/apache/parquet/avro/AvroTestUtil.java
@@ -33,6 +33,7 @@ import org.apache.hadoop.fs.Path;
import org.apache.parquet.hadoop.ParquetReader;
import org.apache.parquet.hadoop.ParquetWriter;
import org.apache.parquet.hadoop.util.HadoopInputFile;
+import org.apache.parquet.variant.Variant;
import org.junit.Assert;
import org.junit.rules.TemporaryFolder;
@@ -129,4 +130,36 @@ public class AvroTestUtil {
conf.setBoolean(name, value);
return conf;
}
+
+ /**
+ * Assert that two Variant values are logically equivalent.
+ * E.g. values in an object may be ordered differently in the binary.
+ */
+ static void assertEquivalent(Variant expected, Variant actual) {
+ Assert.assertEquals(expected.getType(), actual.getType());
+ switch (expected.getType()) {
+ case STRING:
+ // Short strings may use the compact or extended representation.
+ Assert.assertEquals(expected.getString(), actual.getString());
+ break;
+ case ARRAY:
+ Assert.assertEquals(expected.numArrayElements(),
actual.numArrayElements());
+ for (int i = 0; i < expected.numArrayElements(); ++i) {
+ assertEquivalent(expected.getElementAtIndex(i),
actual.getElementAtIndex(i));
+ }
+ break;
+ case OBJECT:
+ Assert.assertEquals(expected.numObjectElements(),
actual.numObjectElements());
+ for (int i = 0; i < expected.numObjectElements(); ++i) {
+ Variant.ObjectField expectedField = expected.getFieldAtIndex(i);
+ Variant.ObjectField actualField = actual.getFieldAtIndex(i);
+ Assert.assertEquals(expectedField.key, actualField.key);
+ assertEquivalent(expectedField.value, actualField.value);
+ }
+ break;
+ default:
+ // All other types have a single representation, and must be
bit-for-bit identical.
+ Assert.assertEquals(expected.getValueBuffer(),
actual.getValueBuffer());
+ }
+ }
}
diff --git
a/parquet-avro/src/test/java/org/apache/parquet/avro/TestReadVariant.java
b/parquet-avro/src/test/java/org/apache/parquet/avro/TestReadVariant.java
index 824b678f4..a90d4bbfc 100644
--- a/parquet-avro/src/test/java/org/apache/parquet/avro/TestReadVariant.java
+++ b/parquet-avro/src/test/java/org/apache/parquet/avro/TestReadVariant.java
@@ -2111,36 +2111,8 @@ public class TestReadVariant extends DirectWriterTest {
void assertEquivalent(ByteBuffer expectedMetadata, ByteBuffer expectedValue,
GenericRecord actual) {
assertEquals(expectedMetadata, (ByteBuffer) actual.get("metadata"));
assertEquals(expectedMetadata, (ByteBuffer) actual.get("metadata"));
- assertEquivalent(
+ AvroTestUtil.assertEquivalent(
new Variant(expectedValue, expectedMetadata),
new Variant(((ByteBuffer) actual.get("value")), expectedMetadata));
}
-
- void assertEquivalent(Variant expected, Variant actual) {
- assertEquals(expected.getType(), actual.getType());
- switch (expected.getType()) {
- case STRING:
- // Short strings may use the compact or extended representation.
- assertEquals(expected.getString(), actual.getString());
- break;
- case ARRAY:
- assertEquals(expected.numArrayElements(), actual.numArrayElements());
- for (int i = 0; i < expected.numArrayElements(); ++i) {
- assertEquivalent(expected.getElementAtIndex(i),
actual.getElementAtIndex(i));
- }
- break;
- case OBJECT:
- assertEquals(expected.numObjectElements(), actual.numObjectElements());
- for (int i = 0; i < expected.numObjectElements(); ++i) {
- Variant.ObjectField expectedField = expected.getFieldAtIndex(i);
- Variant.ObjectField actualField = actual.getFieldAtIndex(i);
- assertEquals(expectedField.key, actualField.key);
- assertEquivalent(expectedField.value, actualField.value);
- }
- break;
- default:
- // All other types have a single representation, and must be
bit-for-bit identical.
- assertEquals(expected.getValueBuffer(), actual.getValueBuffer());
- }
- }
}
diff --git
a/parquet-avro/src/test/java/org/apache/parquet/avro/TestWriteVariant.java
b/parquet-avro/src/test/java/org/apache/parquet/avro/TestWriteVariant.java
new file mode 100644
index 000000000..4853c46bb
--- /dev/null
+++ b/parquet-avro/src/test/java/org/apache/parquet/avro/TestWriteVariant.java
@@ -0,0 +1,612 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.parquet.avro;
+
+import static org.junit.Assert.assertEquals;
+
+import java.io.File;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.UUID;
+import java.util.function.Consumer;
+import org.apache.avro.Schema;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.parquet.DirectWriterTest;
+import org.apache.parquet.hadoop.ParquetWriter;
+import org.apache.parquet.hadoop.api.WriteSupport;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit;
+import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
+import org.apache.parquet.schema.Type;
+import org.apache.parquet.schema.Types;
+import org.apache.parquet.variant.ImmutableMetadata;
+import org.apache.parquet.variant.Variant;
+import org.apache.parquet.variant.VariantArrayBuilder;
+import org.apache.parquet.variant.VariantBuilder;
+import org.apache.parquet.variant.VariantObjectBuilder;
+import org.junit.Test;
+
+public class TestWriteVariant extends DirectWriterTest {
+
+ private static Variant fullVariant(Consumer<VariantBuilder> appendValue) {
+ VariantBuilder builder = new VariantBuilder();
+ appendValue.accept(builder);
+ return builder.build();
+ }
+
+ // Return only the value ByteBuffer, which is usually all we want.
+ private static ByteBuffer variant(Consumer<VariantBuilder> appendValue) {
+ return fullVariant(appendValue).getValueBuffer();
+ }
+
+ // Returns a value based on building with fixed metadata.
+ private static ByteBuffer variant(ByteBuffer metadata,
Consumer<VariantBuilder> appendValue) {
+ VariantBuilder builder = new VariantBuilder(new
ImmutableMetadata(metadata));
+ appendValue.accept(builder);
+ return builder.build().getValueBuffer();
+ }
+
+ private static ByteBuffer variant(int val) {
+ return variant(b -> b.appendInt(val));
+ }
+
+ private static ByteBuffer variant(long val) {
+ return variant(b -> b.appendLong(val));
+ }
+
+ private static ByteBuffer variant(String s) {
+ return variant(b -> b.appendString(s));
+ }
+
+ private static final GroupType UNSHREDDED_GROUP =
Types.buildGroup(Type.Repetition.REQUIRED)
+ .as(LogicalTypeAnnotation.variantType((byte) 1))
+ .required(PrimitiveTypeName.BINARY)
+ .named("metadata")
+ .required(PrimitiveTypeName.BINARY)
+ .named("value")
+ .named("var");
+
+ private static MessageType parquetSchema(GroupType variantGroup) {
+ return Types.buildMessage()
+ .required(PrimitiveTypeName.INT32)
+ .named("id")
+ .addField(variantGroup)
+ .named("table");
+ }
+
+ private static final MessageType READ_SCHEMA =
parquetSchema(UNSHREDDED_GROUP);
+
+ private static final Schema VARIANT_SCHEMA = new
AvroSchemaConverter().convert(UNSHREDDED_GROUP);
+ private static final Schema SCHEMA = new
AvroSchemaConverter().convert(READ_SCHEMA);
+
+ private ByteBuffer TEST_METADATA;
+ private ByteBuffer TEST_OBJECT;
+ private ByteBuffer SIMILAR_OBJECT;
+ private ByteBuffer EMPTY_ARRAY;
+ private ByteBuffer STRING_ARRAY;
+ private ByteBuffer MIXED_ARRAY;
+ private ByteBuffer NESTED_ARRAY;
+ private ByteBuffer MIXED_NESTED_ARRAY;
+ private ByteBuffer OBJECT_IN_ARRAY;
+ private ByteBuffer MIXED_OBJECT_IN_ARRAY;
+ private ByteBuffer SIMILAR_OBJECT_IN_ARRAY;
+ private ByteBuffer EMPTY_OBJECT;
+ private ByteBuffer EMPTY_METADATA = fullVariant(b ->
b.appendNull()).getMetadataBuffer();
+ private Variant[] VARIANTS;
+
+ public TestWriteVariant() throws Exception {
+ TEST_METADATA = fullVariant(b -> {
+ VariantObjectBuilder ob = b.startObject();
+ ob.appendKey("a");
+ ob.appendNull();
+ ob.appendKey("b");
+ ob.appendNull();
+ ob.appendKey("c");
+ ob.appendNull();
+ ob.appendKey("d");
+ ob.appendNull();
+ ob.appendKey("e");
+ ob.appendNull();
+ b.endObject();
+ })
+ .getMetadataBuffer();
+
+ TEST_OBJECT = variant(TEST_METADATA, b -> {
+ VariantObjectBuilder ob = b.startObject();
+ ob.appendKey("a");
+ ob.appendNull();
+ ob.appendKey("d");
+ ob.appendString("iceberg");
+ b.endObject();
+ });
+
+ SIMILAR_OBJECT = variant(TEST_METADATA, b -> {
+ VariantObjectBuilder ob = b.startObject();
+ ob.appendKey("a");
+ ob.appendInt(123456789);
+ ob.appendKey("c");
+ ob.appendString("string");
+ b.endObject();
+ });
+
+ EMPTY_ARRAY = variant(b -> {
+ b.startArray();
+ b.endArray();
+ });
+
+ STRING_ARRAY = variant(b -> {
+ VariantArrayBuilder ab = b.startArray();
+ ab.appendString("parquet");
+ b.endArray();
+ });
+
+ MIXED_ARRAY = variant(b -> {
+ VariantArrayBuilder ab = b.startArray();
+ ab.appendString("parquet");
+ ab.appendString("string");
+ ab.appendInt(34);
+ b.endArray();
+ });
+
+ NESTED_ARRAY = variant(b -> {
+ VariantArrayBuilder ab = b.startArray();
+ VariantArrayBuilder inner1 = ab.startArray();
+ inner1.appendString("parquet");
+ inner1.appendString("string");
+ ab.endArray();
+ VariantArrayBuilder inner2 = ab.startArray();
+ inner2.appendString("parquet");
+ inner2.appendString("string");
+ ab.endArray();
+ b.endArray();
+ });
+
+ MIXED_NESTED_ARRAY = variant(b -> {
+ VariantArrayBuilder ab = b.startArray();
+ VariantArrayBuilder inner1 = ab.startArray();
+ inner1.appendString("parquet");
+ inner1.appendString("string");
+ inner1.appendInt(34);
+ ab.endArray();
+ VariantArrayBuilder inner2 = ab.startArray();
+ inner2.appendInt(34);
+ inner2.appendNull();
+ ab.endArray();
+ ab.startArray();
+ ab.endArray();
+ VariantArrayBuilder inner4 = ab.startArray();
+ inner4.appendString("parquet");
+ inner4.appendString("string");
+ inner4.appendInt(34);
+ ab.endArray();
+ b.endArray();
+ });
+
+ // The first array element defines the schema.
+ OBJECT_IN_ARRAY = variant(TEST_METADATA, b -> {
+ VariantArrayBuilder ab = b.startArray();
+ VariantObjectBuilder ob = ab.startObject();
+ ob.appendKey("a");
+ ob.appendNull();
+ ob.appendKey("d");
+ ob.appendString("iceberg");
+ ab.endObject();
+ ab.appendInt(123);
+ VariantObjectBuilder ob2 = ab.startObject();
+ ob2.appendKey("c");
+ ob2.appendString("hello");
+ ob2.appendKey("d");
+ ob2.appendDate(12345);
+ ab.endObject();
+ b.endArray();
+ });
+
+ MIXED_OBJECT_IN_ARRAY = variant(TEST_METADATA, b -> {
+ VariantArrayBuilder ab = b.startArray();
+ VariantObjectBuilder ob = ab.startObject();
+ ob.appendKey("a");
+ ob.appendNull();
+ ob.appendKey("d");
+ ob.appendString("iceberg");
+ ab.endObject();
+ ab.appendInt(123);
+ VariantObjectBuilder ob2 = ab.startObject();
+ ob2.appendKey("c");
+ ob2.appendString("hello");
+ ob2.appendKey("d");
+ ob2.appendDate(12345);
+ ab.endObject();
+ ab.appendString("parquet");
+ ab.appendInt(34);
+ b.endArray();
+ });
+
+ // Change one field name and one type in the first element to change the
schema.
+ SIMILAR_OBJECT_IN_ARRAY = variant(TEST_METADATA, b -> {
+ VariantArrayBuilder ab = b.startArray();
+ VariantObjectBuilder ob = ab.startObject();
+ ob.appendKey("c");
+ ob.appendString("iceberg");
+ ob.appendKey("a");
+ ob.appendString("parquet");
+ ab.endObject();
+ ab.appendInt(123);
+ VariantObjectBuilder ob2 = ab.startObject();
+ ob2.appendKey("c");
+ ob2.appendString("hello");
+ ob2.appendKey("d");
+ ob2.appendDate(12345);
+ ab.endObject();
+ b.endArray();
+ });
+
+ EMPTY_OBJECT = variant(TEST_METADATA, b -> {
+ b.startObject();
+ b.endObject();
+ });
+
+ VARIANTS = new Variant[] {
+ fullVariant(b -> b.appendNull()),
+ fullVariant(b -> b.appendBoolean(true)),
+ fullVariant(b -> b.appendBoolean(false)),
+ fullVariant(b -> b.appendByte((byte) 34)),
+ fullVariant(b -> b.appendByte((byte) -34)),
+ fullVariant(b -> b.appendShort((byte) 1234)),
+ fullVariant(b -> b.appendShort((byte) -1234)),
+ fullVariant(b -> b.appendInt(12345)),
+ fullVariant(b -> b.appendInt(-12345)),
+ fullVariant(b -> b.appendLong(9876543210L)),
+ fullVariant(b -> b.appendLong(-9876543210L)),
+ fullVariant(b -> b.appendFloat(10.11F)),
+ fullVariant(b -> b.appendFloat(-10.11F)),
+ fullVariant(b -> b.appendDouble(14.3D)),
+ fullVariant(b -> b.appendDouble(-14.3D)),
+ new Variant(EMPTY_OBJECT, EMPTY_METADATA),
+ new Variant(TEST_OBJECT, TEST_METADATA),
+ new Variant(SIMILAR_OBJECT, TEST_METADATA),
+ new Variant(EMPTY_ARRAY, EMPTY_METADATA),
+ new Variant(STRING_ARRAY, EMPTY_METADATA),
+ new Variant(MIXED_ARRAY, EMPTY_METADATA),
+ new Variant(NESTED_ARRAY, EMPTY_METADATA),
+ new Variant(MIXED_NESTED_ARRAY, EMPTY_METADATA),
+ new Variant(OBJECT_IN_ARRAY, TEST_METADATA),
+ new Variant(MIXED_OBJECT_IN_ARRAY, TEST_METADATA),
+ new Variant(SIMILAR_OBJECT_IN_ARRAY, TEST_METADATA),
+ fullVariant(b -> b.appendDate(12345)),
+ fullVariant(b -> b.appendDate(-12345)),
+ fullVariant(b -> b.appendTimestampTz(1234567890L)),
+ fullVariant(b -> b.appendTimestampTz(-1234567890L)),
+ fullVariant(b -> b.appendTimestampNtz(1234567890L)),
+ fullVariant(b -> b.appendTimestampNtz(-1234567890L)),
+ fullVariant(b -> b.appendDecimal(new BigDecimal("123456.789"))), //
decimal4
+ fullVariant(b -> b.appendDecimal(new BigDecimal("-123456.789"))), //
decimal4
+ fullVariant(b -> b.appendDecimal(new BigDecimal("123456.7"))), //
decimal4
+ fullVariant(b -> b.appendDecimal(new BigDecimal("-123456.7"))), //
decimal4
+ fullVariant(b -> b.appendDecimal(new
BigDecimal("123456789.987654321"))), // decimal8
+ fullVariant(b -> b.appendDecimal(new
BigDecimal("-123456789.987654321"))), // decimal8
+ fullVariant(b -> b.appendDecimal(new BigDecimal("123456789.9876543"))),
// decimal8
+ fullVariant(b -> b.appendDecimal(new BigDecimal("-123456789.9876543"))),
// decimal8
+ fullVariant(b -> b.appendDecimal(new
BigDecimal("9876543210.123456789"))), // decimal16
+ fullVariant(b -> b.appendDecimal(new
BigDecimal("-9876543210.123456789"))), // decimal16
+ fullVariant(b -> b.appendDecimal(new
BigDecimal("9876543210.12345678912345"))), // decimal16
+ fullVariant(b -> b.appendDecimal(new
BigDecimal("-9876543210.12345678912345"))), // decimal16
+ fullVariant(b -> b.appendBinary(ByteBuffer.wrap(new byte[] {0x0a, 0x0b,
0x0c, 0x0d}))),
+ fullVariant(b -> b.appendString("iceberg")),
+ fullVariant(b -> b.appendTime(1234567890)),
+ fullVariant(b -> b.appendTimestampNanosTz(1234567890L)),
+ fullVariant(b -> b.appendTimestampNanosTz(-1234567890L)),
+ fullVariant(b -> b.appendTimestampNanosNtz(1234567890L)),
+ fullVariant(b -> b.appendTimestampNanosNtz(-1234567890L)),
+ fullVariant(b ->
b.appendUUID(UUID.fromString("f24f9b64-81fa-49d1-b74e-8c09a6e31c56")))
+ };
+ }
+
+ /**
+ * Create a record containing a Variant value using the standard schema.
+ */
+ GenericRecord createRecord(int i, Variant v) {
+ GenericRecord vRecord = new GenericData.Record(VARIANT_SCHEMA);
+ vRecord.put(0, v.getMetadataBuffer());
+ vRecord.put(1, v.getValueBuffer());
+ GenericRecord record = new GenericData.Record(SCHEMA);
+ record.put(0, i);
+ record.put(1, vRecord);
+ return record;
+ }
+
+ // Tests in this file are based on Iceberg's TestVariantWriters suite.
+ @Test
+ public void testUnshreddedValues() throws IOException {
+ for (Variant v : VARIANTS) {
+ GenericRecord record = createRecord(1, v);
+ TestSchema testSchema = new TestSchema(READ_SCHEMA, READ_SCHEMA);
+
+ GenericRecord actual = writeAndRead(testSchema, record);
+
+ assertEquals(record.get(0), actual.get(0));
+ assertEquals(((GenericRecord) record.get(1)).get(0), ((GenericRecord)
actual.get(1)).get(0));
+ assertEquals(((GenericRecord) record.get(1)).get(1), ((GenericRecord)
actual.get(1)).get(1));
+ }
+ }
+
+ @Test
+ public void testShreddedValues() throws IOException {
+ for (Variant v : VARIANTS) {
+ GenericRecord record = createRecord(1, v);
+ MessageType writeSchema = shreddingSchema(v);
+ TestSchema testSchema = new TestSchema(writeSchema, READ_SCHEMA);
+
+ GenericRecord actual = writeAndRead(testSchema, record);
+ assertEquals(record.get(0), actual.get(0));
+ Variant actualV = new Variant((ByteBuffer) ((GenericRecord)
actual.get(1)).get(1), (ByteBuffer)
+ ((GenericRecord) actual.get(1)).get(0));
+ AvroTestUtil.assertEquivalent(v, actualV);
+ }
+ }
+
+ @Test
+ public void testMixedShredding() throws IOException {
+ List<GenericRecord> expected = new ArrayList<>();
+ for (int i = 0; i < VARIANTS.length; i++) {
+ expected.add(createRecord(i, VARIANTS[i]));
+ }
+
+ for (Variant valueForSchema : VARIANTS) {
+ MessageType writeSchema = shreddingSchema(valueForSchema);
+ TestSchema testSchema = new TestSchema(writeSchema, READ_SCHEMA);
+
+ List<GenericRecord> actual = writeAndRead(testSchema, expected);
+ assertEquals(actual.size(), expected.size());
+ for (int i = 0; i < expected.size(); i++) {
+ Variant actualV =
+ new Variant((ByteBuffer) ((GenericRecord)
actual.get(i).get(1)).get(1), (ByteBuffer)
+ ((GenericRecord) actual.get(i).get(1)).get(0));
+ AvroTestUtil.assertEquivalent(VARIANTS[i], actualV);
+ }
+ }
+ }
+
+ // Write schema contains the full shredding schema. Read schema should just
be a value/metadata pair.
+ private static class TestSchema {
+ MessageType writeSchema;
+ MessageType readSchema;
+
+ TestSchema(MessageType writeSchema, MessageType readSchema) {
+ this.writeSchema = writeSchema;
+ this.readSchema = readSchema;
+ }
+ }
+
+ /**
+ * This is a custom Parquet writer builder that injects a specific Parquet
schema and then uses
+ * the Avro object model. This ensures that the Parquet file's schema is
exactly what was passed.
+ */
+ private static class TestWriterBuilder extends
ParquetWriter.Builder<GenericRecord, TestWriterBuilder> {
+ private MessageType schema = null;
+
+ protected TestWriterBuilder(Path path) {
+ super(path);
+ }
+
+ TestWriterBuilder withFileType(MessageType schema) {
+ this.schema = schema;
+ return self();
+ }
+
+ @Override
+ protected TestWriterBuilder self() {
+ return this;
+ }
+
+ @Override
+ protected WriteSupport<GenericRecord> getWriteSupport(Configuration conf) {
+ return new AvroWriteSupport<>(schema, new
AvroSchemaConverter().convert(schema), GenericData.get());
+ }
+ }
+
+ GenericRecord writeAndRead(TestSchema testSchema, GenericRecord record)
throws IOException {
+ List<GenericRecord> result = writeAndRead(testSchema,
Arrays.asList(record));
+ assertEquals(result.size(), 1);
+ return result.get(0);
+ }
+
+ private List<GenericRecord> writeAndRead(TestSchema testSchema,
List<GenericRecord> records) throws IOException {
+ File tmp = File.createTempFile(getClass().getSimpleName(), ".tmp");
+ tmp.deleteOnExit();
+ tmp.delete();
+ Path path = new Path(tmp.getPath());
+
+ try (ParquetWriter<GenericRecord> writer =
+ new
TestWriterBuilder(path).withFileType(testSchema.writeSchema).build()) {
+ for (GenericRecord record : records) {
+ writer.write(record);
+ }
+ }
+
+ Configuration conf = new Configuration();
+ AvroReadSupport.setAvroReadSchema(conf, new
AvroSchemaConverter().convert(testSchema.readSchema));
+ AvroParquetReader<GenericRecord> reader = new AvroParquetReader(conf,
path);
+
+ ArrayList<GenericRecord> result = new ArrayList<>();
+ GenericRecord next = reader.read();
+ while (next != null) {
+ result.add(next);
+ next = reader.read();
+ }
+ return result;
+ }
+
+ /**
+ * Build a shredding schema that will perfectly shred the provided value.
+ */
+ private static MessageType shreddingSchema(Variant v) {
+ Type shreddedType = shreddedType(v);
+ Types.GroupBuilder<GroupType> partialType =
Types.buildGroup(Type.Repetition.OPTIONAL)
+ .as(LogicalTypeAnnotation.variantType((byte) 1))
+ .required(PrimitiveTypeName.BINARY)
+ .named("metadata")
+ .optional(PrimitiveTypeName.BINARY)
+ .named("value");
+ Type variantType;
+ if (shreddedType == null) {
+ variantType = partialType.named("var");
+ } else {
+ variantType = partialType.addField(shreddedType).named("var");
+ }
+ return Types.buildMessage()
+ .required(PrimitiveTypeName.INT32)
+ .named("id")
+ .addField(variantType)
+ .named("table");
+ }
+
+ private static GroupType shreddedGroup(Variant v, String name) {
+ Type shreddedType = shreddedType(v);
+ if (shreddedType == null) {
+ return Types.buildGroup(Type.Repetition.OPTIONAL)
+ .optional(PrimitiveTypeName.BINARY)
+ .named("value")
+ .named(name);
+ } else {
+ return Types.buildGroup(Type.Repetition.OPTIONAL)
+ .optional(PrimitiveTypeName.BINARY)
+ .named("value")
+ .addField(shreddedType)
+ .named(name);
+ }
+ }
+
+ /**
+ * @return A shredded type, or null if there is no valid shredded type.
+ */
+ private static Type shreddedType(Variant v) {
+ switch (v.getType()) {
+ case NULL:
+ return null;
+ case BOOLEAN:
+ return Types.optional(PrimitiveTypeName.BOOLEAN).named("typed_value");
+ case BYTE:
+ return Types.optional(PrimitiveTypeName.INT32)
+ .as(LogicalTypeAnnotation.intType(8))
+ .named("typed_value");
+ case SHORT:
+ return Types.optional(PrimitiveTypeName.INT32)
+ .as(LogicalTypeAnnotation.intType(16))
+ .named("typed_value");
+ case INT:
+ return Types.optional(PrimitiveTypeName.INT32).named("typed_value");
+ case LONG:
+ return Types.optional(PrimitiveTypeName.INT64).named("typed_value");
+ case FLOAT:
+ return Types.optional(PrimitiveTypeName.FLOAT).named("typed_value");
+ case DOUBLE:
+ return Types.optional(PrimitiveTypeName.DOUBLE).named("typed_value");
+ case DECIMAL4:
+ return Types.optional(PrimitiveTypeName.INT32)
+ .as(LogicalTypeAnnotation.decimalType(v.getDecimal().scale(), 9))
+ .named("typed_value");
+ case DECIMAL8:
+ return Types.optional(PrimitiveTypeName.INT64)
+ .as(LogicalTypeAnnotation.decimalType(v.getDecimal().scale(), 18))
+ .named("typed_value");
+ case DECIMAL16:
+ return Types.optional(PrimitiveTypeName.BINARY)
+ .as(LogicalTypeAnnotation.decimalType(v.getDecimal().scale(), 38))
+ .named("typed_value");
+ case DATE:
+ return Types.optional(PrimitiveTypeName.INT32)
+ .as(LogicalTypeAnnotation.dateType())
+ .named("typed_value");
+ case TIMESTAMP_TZ:
+ return Types.optional(PrimitiveTypeName.INT64)
+ .as(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS))
+ .named("typed_value");
+ case TIMESTAMP_NTZ:
+ return Types.optional(PrimitiveTypeName.INT64)
+ .as(LogicalTypeAnnotation.timestampType(false, TimeUnit.MICROS))
+ .named("typed_value");
+ case BINARY:
+ return Types.optional(PrimitiveTypeName.BINARY).named("typed_value");
+ case STRING:
+ return Types.optional(PrimitiveTypeName.BINARY)
+ .as(LogicalTypeAnnotation.stringType())
+ .named("typed_value");
+ case TIME:
+ return Types.optional(PrimitiveTypeName.INT64)
+ .as(LogicalTypeAnnotation.timeType(false, TimeUnit.MICROS))
+ .named("typed_value");
+ case TIMESTAMP_NANOS_TZ:
+ return Types.optional(PrimitiveTypeName.INT64)
+ .as(LogicalTypeAnnotation.timestampType(true, TimeUnit.NANOS))
+ .named("typed_value");
+ case TIMESTAMP_NANOS_NTZ:
+ return Types.optional(PrimitiveTypeName.INT64)
+ .as(LogicalTypeAnnotation.timestampType(false, TimeUnit.NANOS))
+ .named("typed_value");
+ case UUID:
+ return Types.optional(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY)
+ .as(LogicalTypeAnnotation.uuidType())
+ .named("typed_value");
+ case OBJECT:
+ return shreddedObjectType(v);
+ case ARRAY:
+ return shreddedArrayType(v);
+ default:
+ throw new UnsupportedOperationException("Unsupported shredding type: "
+ v.getType());
+ }
+ }
+
+ private static Type shreddedObjectType(Variant v) {
+ if (v.numObjectElements() == 0) {
+ // Parquet can't represent empty groups.
+ return null;
+ }
+ Types.GroupBuilder<GroupType> builder = Types.optionalGroup();
+ for (int i = 0; i < v.numObjectElements(); i++) {
+ Variant.ObjectField field = v.getFieldAtIndex(i);
+ Types.GroupBuilder<GroupType> fieldBuilder = Types.optionalGroup();
+ Type fieldType = shreddedGroup(field.value, field.key);
+ builder.addField(fieldType);
+ }
+ return builder.named("typed_value");
+ }
+
+ private static Type shreddedArrayType(Variant v) {
+ // Use the first element to determine the array element type
+ Variant firstElement;
+ if (v.numArrayElements() > 0) {
+ firstElement = v.getElementAtIndex(0);
+ } else {
+ // Use null as a dummy value, which will omit typed_value from the
schema.
+ firstElement = fullVariant(b -> b.appendNull());
+ }
+
+ Type elementType = shreddedGroup(firstElement, "element");
+ return
Types.optionalList().setElementType(elementType).named("typed_value");
+ }
+}
diff --git
a/parquet-variant/src/main/java/org/apache/parquet/variant/VariantBuilder.java
b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantBuilder.java
index 7a0c1e016..1cf3aa3a3 100644
---
a/parquet-variant/src/main/java/org/apache/parquet/variant/VariantBuilder.java
+++
b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantBuilder.java
@@ -96,7 +96,7 @@ public class VariantBuilder {
*/
public void appendEncodedValue(ByteBuffer value) {
onAppend();
- int size = value.remaining();
+ int size = VariantUtil.valueSize(value);
checkCapacity(size);
value.duplicate().get(writeBuffer, writePos, size);
writePos += size;
diff --git
a/parquet-variant/src/main/java/org/apache/parquet/variant/VariantUtil.java
b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantUtil.java
index 67ca77d13..149d457ad 100644
--- a/parquet-variant/src/main/java/org/apache/parquet/variant/VariantUtil.java
+++ b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantUtil.java
@@ -848,4 +848,74 @@ class VariantUtil {
}
return result;
}
+
+ /**
+ * Computes the actual size (in bytes) of the Variant value.
+ * @param value The Variant value binary
+ * @return The size (in bytes) of the Variant value, including the header
byte
+ */
+ public static int valueSize(ByteBuffer value) {
+ int pos = value.position();
+ int basicType = value.get(pos) & BASIC_TYPE_MASK;
+ switch (basicType) {
+ case SHORT_STR:
+ int stringSize = (value.get(pos) >> BASIC_TYPE_BITS) &
PRIMITIVE_TYPE_MASK;
+ return 1 + stringSize;
+ case OBJECT: {
+ VariantUtil.ObjectInfo info = VariantUtil.getObjectInfo(slice(value,
pos));
+ return info.dataStartOffset
+ + readUnsigned(
+ value,
+ pos + info.offsetStartOffset + info.numElements *
info.offsetSize,
+ info.offsetSize);
+ }
+ case ARRAY: {
+ VariantUtil.ArrayInfo info = VariantUtil.getArrayInfo(slice(value,
pos));
+ return info.dataStartOffset
+ + readUnsigned(
+ value,
+ pos + info.offsetStartOffset + info.numElements *
info.offsetSize,
+ info.offsetSize);
+ }
+ default: {
+ int typeInfo = (value.get(pos) >> BASIC_TYPE_BITS) &
PRIMITIVE_TYPE_MASK;
+ switch (typeInfo) {
+ case NULL:
+ case TRUE:
+ case FALSE:
+ return 1;
+ case INT8:
+ return 2;
+ case INT16:
+ return 3;
+ case INT32:
+ case DATE:
+ case FLOAT:
+ return 5;
+ case INT64:
+ case DOUBLE:
+ case TIMESTAMP_TZ:
+ case TIMESTAMP_NTZ:
+ case TIME:
+ case TIMESTAMP_NANOS_TZ:
+ case TIMESTAMP_NANOS_NTZ:
+ return 9;
+ case DECIMAL4:
+ return 6;
+ case DECIMAL8:
+ return 10;
+ case DECIMAL16:
+ return 18;
+ case BINARY:
+ case LONG_STR:
+ return 1 + U32_SIZE + readUnsigned(value, pos + 1, U32_SIZE);
+ case UUID:
+ return 1 + UUID_SIZE;
+ default:
+ throw new UnsupportedOperationException(
+ String.format("Unknown type in Variant. primitive type: %d",
typeInfo));
+ }
+ }
+ }
+ }
}
diff --git
a/parquet-variant/src/main/java/org/apache/parquet/variant/VariantValueWriter.java
b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantValueWriter.java
new file mode 100644
index 000000000..34396e246
--- /dev/null
+++
b/parquet-variant/src/main/java/org/apache/parquet/variant/VariantValueWriter.java
@@ -0,0 +1,375 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.parquet.variant;
+
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
+import org.apache.parquet.Preconditions;
+import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.io.api.RecordConsumer;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
+
+/**
+ * Class to write Variant values to a shredded schema.
+ */
+public class VariantValueWriter {
+ private static final String LIST_REPEATED_NAME = "list";
+ private static final String LIST_ELEMENT_NAME = "element";
+
+ private final ByteBuffer metadataBuffer;
+ private final RecordConsumer recordConsumer;
+
+ // We defer initializing the ImmutableMetadata until it's needed. There is a
construction cost to deserialize the
+ // metadata binary into a Map, and if all object fields are shredded into
typed_value, it will never be used.
+ private ImmutableMetadata metadata = null;
+
+ VariantValueWriter(RecordConsumer recordConsumer, ByteBuffer metadata) {
+ this.recordConsumer = recordConsumer;
+ this.metadataBuffer = metadata;
+ }
+
+ Metadata getMetadata() {
+ if (metadata == null) {
+ metadata = new ImmutableMetadata(metadataBuffer);
+ }
+ return metadata;
+ }
+
+ /**
+ * Write a Variant value to a shredded schema.
+ */
+ public static void write(RecordConsumer recordConsumer, GroupType schema,
Variant value) {
+ recordConsumer.startGroup();
+ int metadataIndex = schema.getFieldIndex("metadata");
+ recordConsumer.startField("metadata", metadataIndex);
+
recordConsumer.addBinary(Binary.fromConstantByteBuffer(value.getMetadataBuffer()));
+ recordConsumer.endField("metadata", metadataIndex);
+ VariantValueWriter writer = new VariantValueWriter(recordConsumer,
value.getMetadataBuffer());
+ writer.write(schema, value);
+ recordConsumer.endGroup();
+ }
+
+ /**
+ * Write a Variant value to a shredded schema. The caller is responsible for
calling startGroup()
+ * and endGroup(), and writing metadata if this is the top level of the
Variant group.
+ */
+ void write(GroupType schema, Variant value) {
+ Type typedValueField = null;
+ if (schema.containsField("typed_value")) {
+ typedValueField = schema.getType("typed_value");
+ }
+
+ Variant.Type variantType = value.getType();
+
+ // Handle typed_value if present
+ if (isTypeCompatible(variantType, typedValueField, value)) {
+ int typedValueIdx = schema.getFieldIndex("typed_value");
+ recordConsumer.startField("typed_value", typedValueIdx);
+ ByteBuffer residual = null;
+ if (typedValueField.isPrimitive()) {
+ writeScalarValue(value);
+ } else if (typedValueField.getLogicalTypeAnnotation()
+ instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) {
+ writeArrayValue(value, typedValueField.asGroupType());
+ } else {
+ residual = writeObjectValue(value, typedValueField.asGroupType());
+ }
+ recordConsumer.endField("typed_value", typedValueIdx);
+
+ if (residual != null) {
+ int valueIdx = schema.getFieldIndex("value");
+ recordConsumer.startField("value", valueIdx);
+ recordConsumer.addBinary(Binary.fromConstantByteBuffer(residual));
+ recordConsumer.endField("value", valueIdx);
+ }
+ } else {
+ int valueIdx = schema.getFieldIndex("value");
+ recordConsumer.startField("value", valueIdx);
+
recordConsumer.addBinary(Binary.fromReusedByteBuffer(value.getValueBuffer()));
+ recordConsumer.endField("value", valueIdx);
+ }
+ }
+
+ // Return true if the logical type is a decimal with the same scale as the
provided value, with enough
+ // precision to hold the value. The provided value must be a decimal.
+ private static boolean compatibleDecimalType(Variant value,
LogicalTypeAnnotation logicalType) {
+ if (!(logicalType instanceof
LogicalTypeAnnotation.DecimalLogicalTypeAnnotation)) {
+ return false;
+ }
+ LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalType =
+ (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) logicalType;
+
+ BigDecimal decimal = value.getDecimal();
+ return decimal.scale() == decimalType.getScale() && decimal.precision() <=
decimalType.getPrecision();
+ }
+
+ private static boolean isTypeCompatible(Variant.Type variantType, Type
typedValueField, Variant value) {
+ if (typedValueField == null) {
+ return false;
+ }
+ if (typedValueField.isPrimitive()) {
+ PrimitiveType primitiveType = typedValueField.asPrimitiveType();
+ LogicalTypeAnnotation logicalType =
primitiveType.getLogicalTypeAnnotation();
+ PrimitiveType.PrimitiveTypeName primitiveTypeName =
primitiveType.getPrimitiveTypeName();
+
+ switch (variantType) {
+ case BOOLEAN:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.BOOLEAN;
+ case BYTE:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT32
+ && logicalType instanceof
LogicalTypeAnnotation.IntLogicalTypeAnnotation
+ && ((LogicalTypeAnnotation.IntLogicalTypeAnnotation)
logicalType).isSigned()
+ && ((LogicalTypeAnnotation.IntLogicalTypeAnnotation)
logicalType).getBitWidth() == 8;
+ case SHORT:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT32
+ && logicalType instanceof
LogicalTypeAnnotation.IntLogicalTypeAnnotation
+ && ((LogicalTypeAnnotation.IntLogicalTypeAnnotation)
logicalType).isSigned()
+ && ((LogicalTypeAnnotation.IntLogicalTypeAnnotation)
logicalType).getBitWidth() == 16;
+ case INT:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT32
+ && (logicalType == null
+ || (logicalType instanceof
LogicalTypeAnnotation.IntLogicalTypeAnnotation
+ && ((LogicalTypeAnnotation.IntLogicalTypeAnnotation)
logicalType).isSigned()
+ && ((LogicalTypeAnnotation.IntLogicalTypeAnnotation)
logicalType)
+ .getBitWidth()
+ == 32));
+ case LONG:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT64
+ && (logicalType == null
+ || (logicalType instanceof
LogicalTypeAnnotation.IntLogicalTypeAnnotation
+ && ((LogicalTypeAnnotation.IntLogicalTypeAnnotation)
logicalType).isSigned()
+ && ((LogicalTypeAnnotation.IntLogicalTypeAnnotation)
logicalType)
+ .getBitWidth()
+ == 64));
+ case FLOAT:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.FLOAT;
+ case DOUBLE:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.DOUBLE;
+ case DECIMAL4:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT32
+ && compatibleDecimalType(value, logicalType);
+ case DECIMAL8:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT64
+ && compatibleDecimalType(value, logicalType);
+ case DECIMAL16:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.BINARY
+ && compatibleDecimalType(value, logicalType);
+ case DATE:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT32
+ && logicalType instanceof
LogicalTypeAnnotation.DateLogicalTypeAnnotation;
+ case TIME:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT64
+ && logicalType instanceof
LogicalTypeAnnotation.TimeLogicalTypeAnnotation;
+ case TIMESTAMP_NTZ:
+ case TIMESTAMP_NANOS_NTZ:
+ case TIMESTAMP_TZ:
+ case TIMESTAMP_NANOS_TZ:
+ if (primitiveTypeName == PrimitiveType.PrimitiveTypeName.INT64
+ && logicalType instanceof
LogicalTypeAnnotation.TimestampLogicalTypeAnnotation) {
+ LogicalTypeAnnotation.TimestampLogicalTypeAnnotation annotation =
+ (LogicalTypeAnnotation.TimestampLogicalTypeAnnotation)
logicalType;
+ boolean micros = annotation.getUnit() ==
LogicalTypeAnnotation.TimeUnit.MICROS;
+ boolean nanos = annotation.getUnit() ==
LogicalTypeAnnotation.TimeUnit.NANOS;
+ boolean adjustedToUTC = annotation.isAdjustedToUTC();
+ return (variantType == Variant.Type.TIMESTAMP_TZ && micros &&
adjustedToUTC)
+ || (variantType == Variant.Type.TIMESTAMP_NTZ && micros &&
!adjustedToUTC)
+ || (variantType == Variant.Type.TIMESTAMP_NANOS_TZ && nanos &&
adjustedToUTC)
+ || (variantType == Variant.Type.TIMESTAMP_NANOS_NTZ && nanos
&& !adjustedToUTC);
+ } else {
+ return false;
+ }
+ case STRING:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.BINARY
+ && logicalType instanceof
LogicalTypeAnnotation.StringLogicalTypeAnnotation;
+ case BINARY:
+ return primitiveTypeName == PrimitiveType.PrimitiveTypeName.BINARY
&& logicalType == null;
+ default:
+ return false;
+ }
+ } else if (typedValueField.getLogicalTypeAnnotation()
+ instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) {
+ return variantType == Variant.Type.ARRAY;
+ } else {
+ return variantType == Variant.Type.OBJECT;
+ }
+ }
+
+ private void writeScalarValue(Variant variant) {
+ switch (variant.getType()) {
+ case BOOLEAN:
+ recordConsumer.addBoolean(variant.getBoolean());
+ break;
+ case BYTE:
+ recordConsumer.addInteger(variant.getByte());
+ break;
+ case SHORT:
+ recordConsumer.addInteger(variant.getShort());
+ break;
+ case INT:
+ recordConsumer.addInteger(variant.getInt());
+ break;
+ case LONG:
+ recordConsumer.addLong(variant.getLong());
+ break;
+ case FLOAT:
+ recordConsumer.addFloat(variant.getFloat());
+ break;
+ case DOUBLE:
+ recordConsumer.addDouble(variant.getDouble());
+ break;
+ case DECIMAL4:
+
recordConsumer.addInteger(variant.getDecimal().unscaledValue().intValue());
+ break;
+ case DECIMAL8:
+
recordConsumer.addLong(variant.getDecimal().unscaledValue().longValue());
+ break;
+ case DECIMAL16:
+ recordConsumer.addBinary(Binary.fromConstantByteArray(
+ variant.getDecimal().unscaledValue().toByteArray()));
+ break;
+ case DATE:
+ recordConsumer.addInteger(variant.getInt());
+ break;
+ case TIME:
+ recordConsumer.addLong(variant.getLong());
+ break;
+ case TIMESTAMP_TZ:
+ recordConsumer.addLong(variant.getLong());
+ break;
+ case TIMESTAMP_NTZ:
+ recordConsumer.addLong(variant.getLong());
+ break;
+ case TIMESTAMP_NANOS_TZ:
+ recordConsumer.addLong(variant.getLong());
+ break;
+ case TIMESTAMP_NANOS_NTZ:
+ recordConsumer.addLong(variant.getLong());
+ break;
+ case STRING:
+ recordConsumer.addBinary(Binary.fromString(variant.getString()));
+ break;
+ case BINARY:
+
recordConsumer.addBinary(Binary.fromReusedByteBuffer(variant.getBinary()));
+ break;
+ default:
+ throw new IllegalArgumentException("Unsupported scalar type: " +
variant.getType());
+ }
+ }
+
+ private void writeArrayValue(Variant variant, GroupType arrayType) {
+ Preconditions.checkArgument(
+ variant.getType() == Variant.Type.ARRAY,
+ "Cannot write variant type " + variant.getType() + " as array");
+
+ // Validate that it's a 3-level array.
+ if (arrayType.getFieldCount() != 1
+ || arrayType.getRepetition() == Type.Repetition.REPEATED
+ || arrayType.getType(0).isPrimitive()
+ || !arrayType.getFieldName(0).equals(LIST_REPEATED_NAME)) {
+ throw new IllegalArgumentException("Variant list must be a three-level
list structure: " + arrayType);
+ }
+
+ // Get the element type from the array schema
+ GroupType repeatedType = arrayType.getType(0).asGroupType();
+
+ if (repeatedType.getFieldCount() != 1
+ || repeatedType.getRepetition() != Type.Repetition.REPEATED
+ || repeatedType.getType(0).isPrimitive()
+ || !repeatedType.getFieldName(0).equals(LIST_ELEMENT_NAME)) {
+ throw new IllegalArgumentException("Variant list must be a three-level
list structure: " + arrayType);
+ }
+
+ GroupType elementType = repeatedType.getType(0).asGroupType();
+
+ // List field, annotated as LIST
+ recordConsumer.startGroup();
+ int numElements = variant.numArrayElements();
+ // Can only call startField if there is at least one element.
+ if (numElements > 0) {
+ recordConsumer.startField(LIST_REPEATED_NAME, 0);
+ // Write each array element
+ for (int i = 0; i < numElements; i++) {
+ // Repeated group.
+ recordConsumer.startGroup();
+ recordConsumer.startField(LIST_ELEMENT_NAME, 0);
+
+ // Element group. Can never be null for shredded Variant.
+ recordConsumer.startGroup();
+ write(elementType, variant.getElementAtIndex(i));
+ recordConsumer.endGroup();
+
+ recordConsumer.endField(LIST_ELEMENT_NAME, 0);
+ recordConsumer.endGroup();
+ }
+ recordConsumer.endField(LIST_REPEATED_NAME, 0);
+ }
+ recordConsumer.endGroup();
+ }
+
+ /**
+ * Write an object to typed_value
+ *
+ * @return the residual value that must be written to the value column, or
null if all values were written
+ * to typed_value.
+ */
+ private ByteBuffer writeObjectValue(Variant variant, GroupType objectType) {
+ Preconditions.checkArgument(
+ variant.getType() == Variant.Type.OBJECT,
+ "Cannot write variant type " + variant.getType() + " as object");
+
+ VariantBuilder residualBuilder = null;
+ // The residualBuilder, if created, is always a single object. This is
that object's builder.
+ VariantObjectBuilder objectBuilder = null;
+
+ // Write each object field.
+ recordConsumer.startGroup();
+ for (int i = 0; i < variant.numObjectElements(); i++) {
+ Variant.ObjectField field = variant.getFieldAtIndex(i);
+
+ if (objectType.containsField(field.key)) {
+ int fieldIndex = objectType.getFieldIndex(field.key);
+ Type fieldType = objectType.getType(fieldIndex);
+
+ recordConsumer.startField(field.key, fieldIndex);
+ recordConsumer.startGroup();
+ write(fieldType.asGroupType(), field.value);
+ recordConsumer.endGroup();
+ recordConsumer.endField(field.key,
objectType.getFieldIndex(field.key));
+ } else {
+ if (residualBuilder == null) {
+ residualBuilder = new VariantBuilder(getMetadata());
+ objectBuilder = residualBuilder.startObject();
+ }
+ objectBuilder.appendKey(field.key);
+ objectBuilder.appendEncodedValue(field.value.getValueBuffer());
+ }
+ }
+ recordConsumer.endGroup();
+
+ if (residualBuilder != null) {
+ residualBuilder.endObject();
+ return residualBuilder.build().getValueBuffer();
+ } else {
+ return null;
+ }
+ }
+}