This is an automated email from the ASF dual-hosted git repository. etudenhoefner pushed a commit to branch spark-uuid-read-write-support-3.4 in repository https://gitbox.apache.org/repos/asf/iceberg.git
commit 07ddc3c777b7dccda24098a9918e3746a27bc500 Author: Eduard Tudenhoefner <[email protected]> AuthorDate: Fri Apr 21 09:18:32 2023 +0200 Spark: Add read/write support for UUIDs --- .../apache/iceberg/spark/data/SparkAvroWriter.java | 2 +- .../apache/iceberg/spark/data/SparkOrcReader.java | 3 ++ .../iceberg/spark/data/SparkOrcValueReaders.java | 32 +++++++++++++++++++ .../iceberg/spark/data/SparkOrcValueWriters.java | 17 ++++++++++ .../apache/iceberg/spark/data/SparkOrcWriter.java | 11 ++++++- .../iceberg/spark/data/SparkParquetReaders.java | 17 ++++++++++ .../iceberg/spark/data/SparkParquetWriters.java | 36 ++++++++++++++++++++++ .../data/vectorized/VectorizedSparkOrcReaders.java | 5 ++- .../apache/iceberg/spark/data/AvroDataTest.java | 2 +- .../org/apache/iceberg/spark/data/RandomData.java | 2 ++ .../iceberg/spark/data/TestSparkParquetWriter.java | 2 +- 11 files changed, 124 insertions(+), 5 deletions(-) diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java index 15465568c2..04dfd46a18 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java @@ -126,7 +126,7 @@ public class SparkAvroWriter implements MetricsAwareDatumWriter<InternalRow> { return SparkValueWriters.decimal(decimal.getPrecision(), decimal.getScale()); case "uuid": - return ValueWriters.uuids(); + return SparkValueWriters.uuids(); default: throw new IllegalArgumentException("Unsupported logical type: " + logicalType); diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java index 78db137054..c20be44f67 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java @@ -123,6 +123,9 @@ public class SparkOrcReader implements OrcRowReader<InternalRow> { case STRING: return SparkOrcValueReaders.utf8String(); case BINARY: + if (Type.TypeID.UUID == iPrimitive.typeId()) { + return SparkOrcValueReaders.uuids(); + } return OrcValueReaders.bytes(); default: throw new IllegalArgumentException("Unhandled type " + primitive); diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java index 9e9b3e53bb..2bc5ef96a3 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java @@ -19,6 +19,8 @@ package org.apache.iceberg.spark.data; import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.List; import java.util.Map; import org.apache.iceberg.orc.OrcValueReader; @@ -26,6 +28,7 @@ import org.apache.iceberg.orc.OrcValueReaders; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.UUIDUtil; import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; import org.apache.orc.storage.ql.exec.vector.ColumnVector; import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; @@ -49,6 +52,10 @@ public class SparkOrcValueReaders { return StringReader.INSTANCE; } + public static OrcValueReader<UTF8String> uuids() { + return UUIDReader.INSTANCE; + } + public static OrcValueReader<Long> timestampTzs() { return TimestampTzReader.INSTANCE; } @@ -170,6 +177,31 @@ public class SparkOrcValueReaders { } } + private static class UUIDReader implements OrcValueReader<UTF8String> { + private static final ThreadLocal<ByteBuffer> BUFFER = + ThreadLocal.withInitial( + () -> { + ByteBuffer buffer = ByteBuffer.allocate(16); + buffer.order(ByteOrder.BIG_ENDIAN); + return buffer; + }); + + private static final UUIDReader INSTANCE = new UUIDReader(); + + private UUIDReader() {} + + @Override + public UTF8String nonNullRead(ColumnVector vector, int row) { + BytesColumnVector bytesVector = (BytesColumnVector) vector; + ByteBuffer buffer = BUFFER.get(); + buffer.rewind(); + buffer.put(bytesVector.vector[row], bytesVector.start[row], bytesVector.length[row]); + buffer.rewind(); + + return UTF8String.fromString(UUIDUtil.convert(buffer).toString()); + } + } + private static class TimestampTzReader implements OrcValueReader<Long> { private static final TimestampTzReader INSTANCE = new TimestampTzReader(); diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java index 780090f991..9a4f1b5b48 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java @@ -18,10 +18,13 @@ */ package org.apache.iceberg.spark.data; +import java.nio.ByteBuffer; import java.util.List; +import java.util.UUID; import java.util.stream.Stream; import org.apache.iceberg.FieldMetrics; import org.apache.iceberg.orc.OrcValueWriter; +import org.apache.iceberg.util.UUIDUtil; import org.apache.orc.TypeDescription; import org.apache.orc.storage.common.type.HiveDecimal; import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; @@ -42,6 +45,10 @@ class SparkOrcValueWriters { return StringWriter.INSTANCE; } + static OrcValueWriter<?> uuids() { + return UUIDWriter.INSTANCE; + } + static OrcValueWriter<?> timestampTz() { return TimestampTzWriter.INSTANCE; } @@ -73,6 +80,16 @@ class SparkOrcValueWriters { } } + private static class UUIDWriter implements OrcValueWriter<UTF8String> { + private static final UUIDWriter INSTANCE = new UUIDWriter(); + + @Override + public void nonNullWrite(int rowId, UTF8String data, ColumnVector output) { + ByteBuffer buffer = UUIDUtil.convertToByteBuffer(UUID.fromString(data.toString())); + ((BytesColumnVector) output).setRef(rowId, buffer.array(), 0, buffer.array().length); + } + } + private static class TimestampTzWriter implements OrcValueWriter<Long> { private static final TimestampTzWriter INSTANCE = new TimestampTzWriter(); diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java index 60868b8700..c5477fac08 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java @@ -111,6 +111,9 @@ public class SparkOrcWriter implements OrcRowWriter<InternalRow> { case DOUBLE: return GenericOrcWriters.doubles(ORCSchemaUtil.fieldId(primitive)); case BINARY: + if (Type.TypeID.UUID == iPrimitive.typeId()) { + return SparkOrcValueWriters.uuids(); + } return GenericOrcWriters.byteArrays(); case STRING: case CHAR: @@ -173,7 +176,13 @@ public class SparkOrcWriter implements OrcRowWriter<InternalRow> { fieldGetter = SpecializedGetters::getDouble; break; case BINARY: - fieldGetter = SpecializedGetters::getBinary; + if (ORCSchemaUtil.BinaryType.UUID + .toString() + .equals(fieldType.getAttributeValue(ORCSchemaUtil.ICEBERG_BINARY_TYPE_ATTRIBUTE))) { + fieldGetter = SpecializedGetters::getUTF8String; + } else { + fieldGetter = SpecializedGetters::getBinary; + } // getBinary always makes a copy, so we don't need to worry about it // being changed behind our back. break; diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java index 59f81de6ae..af16d9bbc2 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java @@ -46,6 +46,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.types.Type.TypeID; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.UUIDUtil; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.GroupType; @@ -232,6 +233,7 @@ public class SparkParquetReaders { } @Override + @SuppressWarnings("checkstyle:CyclomaticComplexity") public ParquetValueReader<?> primitive( org.apache.iceberg.types.Type.PrimitiveType expected, PrimitiveType primitive) { ColumnDescriptor desc = type.getColumnDescription(currentPath()); @@ -282,6 +284,9 @@ public class SparkParquetReaders { switch (primitive.getPrimitiveTypeName()) { case FIXED_LEN_BYTE_ARRAY: case BINARY: + if (expected != null && expected.typeId() == TypeID.UUID) { + return new UUIDReader(desc); + } return new ParquetValueReaders.ByteArrayReader(desc); case INT32: if (expected != null && expected.typeId() == TypeID.LONG) { @@ -413,6 +418,18 @@ public class SparkParquetReaders { } } + private static class UUIDReader extends PrimitiveReader<UTF8String> { + UUIDReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + @SuppressWarnings("ByteBufferBackingArray") + public UTF8String read(UTF8String ignored) { + return UTF8String.fromString(UUIDUtil.convert(column.nextBinary().toByteBuffer()).toString()); + } + } + private static class ArrayReader<E> extends RepeatedReader<ArrayData, ReusableArrayData, E> { private int readPos = 0; private int writePos = 0; diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java index 3637fa4a26..c1abec96cd 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java @@ -18,10 +18,13 @@ */ package org.apache.iceberg.spark.data; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; +import java.util.UUID; import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry; import org.apache.iceberg.parquet.ParquetValueWriter; import org.apache.iceberg.parquet.ParquetValueWriters; @@ -35,6 +38,7 @@ import org.apache.iceberg.util.DecimalUtil; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; @@ -176,6 +180,9 @@ public class SparkParquetWriters { switch (primitive.getPrimitiveTypeName()) { case FIXED_LEN_BYTE_ARRAY: case BINARY: + if (LogicalTypeAnnotation.uuidType().equals(primitive.getLogicalTypeAnnotation())) { + return uuids(desc); + } return byteArrays(desc); case BOOLEAN: return ParquetValueWriters.booleans(desc); @@ -316,6 +323,35 @@ public class SparkParquetWriters { } } + private static PrimitiveWriter<UTF8String> uuids(ColumnDescriptor desc) { + return new UUIDWriter(desc); + } + + private static class UUIDWriter extends PrimitiveWriter<UTF8String> { + private static final ThreadLocal<ByteBuffer> BUFFER = + ThreadLocal.withInitial( + () -> { + ByteBuffer buffer = ByteBuffer.allocate(16); + buffer.order(ByteOrder.BIG_ENDIAN); + return buffer; + }); + + private UUIDWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, UTF8String string) { + UUID uuid = UUID.fromString(string.toString()); + ByteBuffer buffer = BUFFER.get(); + buffer.rewind(); + buffer.putLong(uuid.getMostSignificantBits()); + buffer.putLong(uuid.getLeastSignificantBits()); + buffer.rewind(); + column.writeBinary(repetitionLevel, Binary.fromReusedByteBuffer(buffer)); + } + } + private static class ByteArrayWriter extends PrimitiveWriter<byte[]> { private ByteArrayWriter(ColumnDescriptor desc) { super(desc); diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java index b2d8bd14be..c030311232 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java @@ -155,7 +155,10 @@ public class VectorizedSparkOrcReaders { primitiveValueReader = SparkOrcValueReaders.utf8String(); break; case BINARY: - primitiveValueReader = OrcValueReaders.bytes(); + primitiveValueReader = + Type.TypeID.UUID == iPrimitive.typeId() + ? SparkOrcValueReaders.uuids() + : OrcValueReaders.bytes(); break; default: throw new IllegalArgumentException("Unhandled type " + primitive); diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java index 5fd137c536..db0d7336f1 100644 --- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java @@ -56,7 +56,7 @@ public abstract class AvroDataTest { optional(107, "date", Types.DateType.get()), required(108, "ts", Types.TimestampType.withZone()), required(110, "s", Types.StringType.get()), - // required(111, "uuid", Types.UUIDType.get()), + required(111, "uuid", Types.UUIDType.get()), required(112, "fixed", Types.FixedType.ofLength(7)), optional(113, "bytes", Types.BinaryType.get()), required(114, "dec_9_0", Types.DecimalType.of(9, 0)), // int encoded diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java index 1c95df8ced..478afcf09a 100644 --- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java @@ -329,6 +329,8 @@ public class RandomData { return UTF8String.fromString((String) obj); case DECIMAL: return Decimal.apply((BigDecimal) obj); + case UUID: + return UTF8String.fromString(UUID.nameUUIDFromBytes((byte[]) obj).toString()); default: return obj; } diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java index 261fb8838a..467d8a27a2 100644 --- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java @@ -79,7 +79,7 @@ public class TestSparkParquetWriter { Types.StringType.get(), Types.StructType.of( optional(22, "jumpy", Types.DoubleType.get()), - required(23, "koala", Types.IntegerType.get()), + required(23, "koala", Types.UUIDType.get()), required(24, "couch rope", Types.IntegerType.get())))), optional(2, "slide", Types.StringType.get()));
