This is an automated email from the ASF dual-hosted git repository.
shangxinli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/parquet-mr.git
The following commit(s) were added to refs/heads/master by this push:
new c8487c762 PARQUET-2042: Add support for unwrapping common Protobuf
wrappers and logical Timestamps, Date, TimeOfDay (#900)
c8487c762 is described below
commit c8487c762c08cfdb3c713538784e868cd47a7ac8
Author: Michael Wong <[email protected]>
AuthorDate: Mon Nov 27 01:39:11 2023 +0000
PARQUET-2042: Add support for unwrapping common Protobuf wrappers and
logical Timestamps, Date, TimeOfDay (#900)
Co-authored-by: Michael Wong <[email protected]>
---
parquet-protobuf/pom.xml | 12 +
.../parquet/proto/ProtoMessageConverter.java | 291 ++++++++++++++++++++-
.../apache/parquet/proto/ProtoSchemaConverter.java | 93 ++++++-
.../apache/parquet/proto/ProtoWriteSupport.java | 146 +++++++++++
.../parquet/proto/ProtoInputOutputFormatTest.java | 51 +++-
.../parquet/proto/ProtoSchemaConverterTest.java | 98 +++++--
.../parquet/proto/ProtoWriteSupportTest.java | 203 +++++++++++++-
.../src/test/resources/TestProto3.proto | 23 ++
8 files changed, 876 insertions(+), 41 deletions(-)
diff --git a/parquet-protobuf/pom.xml b/parquet-protobuf/pom.xml
index ddf634a77..9f604372b 100644
--- a/parquet-protobuf/pom.xml
+++ b/parquet-protobuf/pom.xml
@@ -32,6 +32,7 @@
<properties>
<elephant-bird.version>4.4</elephant-bird.version>
<protobuf.version>3.25.1</protobuf.version>
+ <common-protos.version>2.28.0</common-protos.version> <!-- make sure it's
compatible with protobuf.version -->
<truth-proto-extension.version>1.1.5</truth-proto-extension.version>
</properties>
@@ -67,6 +68,16 @@
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
</dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java-util</artifactId>
+ <version>${protobuf.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.api.grpc</groupId>
+ <artifactId>proto-google-common-protos</artifactId>
+ <version>${common-protos.version}</version>
+ </dependency>
<dependency>
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-common</artifactId>
@@ -191,6 +202,7 @@
</goals>
<configuration>
<protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
+ <includeMavenTypes>direct</includeMavenTypes>
<addSources>test</addSources>
<addProtoSources>all</addProtoSources>
<includeMavenTypes>direct</includeMavenTypes>
diff --git
a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java
b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java
index da51788f2..5c17af6fe 100644
---
a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java
+++
b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java
@@ -18,10 +18,21 @@
*/
package org.apache.parquet.proto;
+import com.google.protobuf.BoolValue;
import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
import com.google.protobuf.DescriptorProtos;
import com.google.protobuf.Descriptors;
+import com.google.protobuf.Descriptors.Descriptor;
+import com.google.protobuf.DoubleValue;
+import com.google.protobuf.FloatValue;
+import com.google.protobuf.Int32Value;
+import com.google.protobuf.Int64Value;
import com.google.protobuf.Message;
+import com.google.protobuf.StringValue;
+import com.google.protobuf.UInt32Value;
+import com.google.protobuf.UInt64Value;
+import com.google.protobuf.util.Timestamps;
import com.twitter.elephantbird.util.Protobufs;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.column.Dictionary;
@@ -42,6 +53,8 @@ import org.apache.parquet.schema.PrimitiveType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.time.LocalDate;
+import java.time.LocalTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -233,6 +246,21 @@ class ProtoMessageConverter extends GroupConverter {
public Optional<Converter>
visit(LogicalTypeAnnotation.MapLogicalTypeAnnotation mapLogicalType) {
return of(new MapConverter(parentBuilder, fieldDescriptor,
parquetType));
}
+
+ @Override
+ public Optional<Converter>
visit(LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
timestampLogicalType) {
+ return of(new ProtoTimestampConverter(parent, timestampLogicalType));
+ }
+
+ @Override
+ public Optional<Converter>
visit(LogicalTypeAnnotation.DateLogicalTypeAnnotation dateLogicalType) {
+ return of(new ProtoDateConverter(parent));
+ }
+
+ @Override
+ public Optional<Converter>
visit(LogicalTypeAnnotation.TimeLogicalTypeAnnotation timeLogicalType) {
+ return of(new ProtoTimeConverter(parent, timeLogicalType));
+ }
}).orElseGet(() -> newScalarConverter(parent, parentBuilder,
fieldDescriptor, parquetType));
}
@@ -250,6 +278,37 @@ class ProtoMessageConverter extends GroupConverter {
case INT: return new ProtoIntConverter(pvc);
case LONG: return new ProtoLongConverter(pvc);
case MESSAGE: {
+ if (parquetType.isPrimitive()) {
+ // if source is a Primitive type yet target is MESSAGE, it's probably a wrapped message
+ Descriptor messageType = fieldDescriptor.getMessageType();
+ if (messageType.equals(DoubleValue.getDescriptor())) {
+ return new ProtoDoubleValueConverter(pvc);
+ }
+ if (messageType.equals(FloatValue.getDescriptor())) {
+ return new ProtoFloatValueConverter(pvc);
+ }
+ if (messageType.equals(Int64Value.getDescriptor())) {
+ return new ProtoInt64ValueConverter(pvc);
+ }
+ if (messageType.equals(UInt64Value.getDescriptor())) {
+ return new ProtoUInt64ValueConverter(pvc);
+ }
+ if (messageType.equals(Int32Value.getDescriptor())) {
+ return new ProtoInt32ValueConverter(pvc);
+ }
+ if (messageType.equals(UInt32Value.getDescriptor())) {
+ return new ProtoUInt32ValueConverter(pvc);
+ }
+ if (messageType.equals(BoolValue.getDescriptor())) {
+ return new ProtoBoolValueConverter(pvc);
+ }
+ if (messageType.equals(StringValue.getDescriptor())) {
+ return new ProtoStringValueConverter(pvc);
+ }
+ if (messageType.equals(BytesValue.getDescriptor())) {
+ return new ProtoBytesValueConverter(pvc);
+ }
+ }
Message.Builder subBuilder =
parentBuilder.newBuilderForField(fieldDescriptor);
return new ProtoMessageConverter(conf, pvc, subBuilder,
parquetType.asGroupType(), extraMetadata);
}
@@ -295,7 +354,7 @@ class ProtoMessageConverter extends GroupConverter {
* Fills lookup structure for translating between parquet enum values and Protocol buffer enum values.
* */
private Map<Binary, Descriptors.EnumValueDescriptor>
makeLookupStructure(Descriptors.EnumDescriptor enumType) {
- Map<Binary, Descriptors.EnumValueDescriptor> lookupStructure = new
HashMap<Binary, Descriptors.EnumValueDescriptor>();
+ Map<Binary, Descriptors.EnumValueDescriptor> lookupStructure = new
HashMap<>();
if (extraMetadata.containsKey(METADATA_ENUM_PREFIX +
enumType.getFullName())) {
String enumNameNumberPairs = extraMetadata.get(METADATA_ENUM_PREFIX +
enumType.getFullName());
@@ -366,7 +425,7 @@ class ProtoMessageConverter extends GroupConverter {
}
@Override
- final public void addBinary(Binary binaryValue) {
+ public void addBinary(Binary binaryValue) {
Descriptors.EnumValueDescriptor protoValue =
translateEnumValue(binaryValue);
parent.add(protoValue);
}
@@ -392,7 +451,7 @@ class ProtoMessageConverter extends GroupConverter {
}
- final class ProtoBinaryConverter extends PrimitiveConverter {
+ static final class ProtoBinaryConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -408,7 +467,7 @@ class ProtoMessageConverter extends GroupConverter {
}
- final class ProtoBooleanConverter extends PrimitiveConverter {
+ static final class ProtoBooleanConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -417,13 +476,13 @@ class ProtoMessageConverter extends GroupConverter {
}
@Override
- final public void addBoolean(boolean value) {
+ public void addBoolean(boolean value) {
parent.add(value);
}
}
- final class ProtoDoubleConverter extends PrimitiveConverter {
+ static final class ProtoDoubleConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -437,7 +496,7 @@ class ProtoMessageConverter extends GroupConverter {
}
}
- final class ProtoFloatConverter extends PrimitiveConverter {
+ static final class ProtoFloatConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -451,7 +510,7 @@ class ProtoMessageConverter extends GroupConverter {
}
}
- final class ProtoIntConverter extends PrimitiveConverter {
+ static final class ProtoIntConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -465,7 +524,7 @@ class ProtoMessageConverter extends GroupConverter {
}
}
- final class ProtoLongConverter extends PrimitiveConverter {
+ static final class ProtoLongConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -479,7 +538,7 @@ class ProtoMessageConverter extends GroupConverter {
}
}
- final class ProtoStringConverter extends PrimitiveConverter {
+ static final class ProtoStringConverter extends PrimitiveConverter {
final ParentValueContainer parent;
@@ -495,6 +554,218 @@ class ProtoMessageConverter extends GroupConverter {
}
+ static final class ProtoTimestampConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+ final LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
logicalTypeAnnotation;
+
+ public ProtoTimestampConverter(ParentValueContainer parent,
LogicalTypeAnnotation.TimestampLogicalTypeAnnotation logicalTypeAnnotation) {
+ this.parent = parent;
+ this.logicalTypeAnnotation = logicalTypeAnnotation;
+ }
+
+ @Override
+ public void addLong(long value) {
+ switch (logicalTypeAnnotation.getUnit()) {
+ case MICROS:
+ parent.add(Timestamps.fromMicros(value));
+ break;
+ case MILLIS:
+ parent.add(Timestamps.fromMillis(value));
+ break;
+ case NANOS:
+ parent.add(Timestamps.fromNanos(value));
+ break;
+ }
+ }
+ }
+
+ static final class ProtoDateConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoDateConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addInt(int value) {
+ LocalDate localDate = LocalDate.ofEpochDay(value);
+ com.google.type.Date date = com.google.type.Date.newBuilder()
+ .setYear(localDate.getYear())
+ .setMonth(localDate.getMonthValue())
+ .setDay(localDate.getDayOfMonth())
+ .build();
+ parent.add(date);
+ }
+ }
+
+ static final class ProtoTimeConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+ final LogicalTypeAnnotation.TimeLogicalTypeAnnotation
logicalTypeAnnotation;
+
+ public ProtoTimeConverter(ParentValueContainer parent,
LogicalTypeAnnotation.TimeLogicalTypeAnnotation logicalTypeAnnotation) {
+ this.parent = parent;
+ this.logicalTypeAnnotation = logicalTypeAnnotation;
+ }
+
+ @Override
+ public void addLong(long value) {
+ LocalTime localTime;
+ switch (logicalTypeAnnotation.getUnit()) {
+ case MILLIS:
+ localTime = LocalTime.ofNanoOfDay(value * 1_000_000);
+ break;
+ case MICROS:
+ localTime = LocalTime.ofNanoOfDay(value * 1_000);
+ break;
+ case NANOS:
+ localTime = LocalTime.ofNanoOfDay(value);
+ break;
+ default:
+ throw new IllegalArgumentException("Unrecognized TimeUnit: " +
logicalTypeAnnotation.getUnit());
+ }
+ com.google.type.TimeOfDay timeOfDay =
com.google.type.TimeOfDay.newBuilder()
+ .setHours(localTime.getHour())
+ .setMinutes(localTime.getMinute())
+ .setSeconds(localTime.getSecond())
+ .setNanos(localTime.getNano())
+ .build();
+ parent.add(timeOfDay);
+ }
+ }
+
+ static final class ProtoDoubleValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoDoubleValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addDouble(double value) {
+ parent.add(DoubleValue.of(value));
+ }
+ }
+
+ static final class ProtoFloatValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoFloatValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addFloat(float value) {
+ parent.add(FloatValue.of(value));
+ }
+ }
+
+ static final class ProtoInt64ValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoInt64ValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addLong(long value) {
+ parent.add(Int64Value.of(value));
+ }
+ }
+
+ static final class ProtoUInt64ValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoUInt64ValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addLong(long value) {
+ parent.add(UInt64Value.of(value));
+ }
+ }
+
+ static final class ProtoInt32ValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoInt32ValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addInt(int value) {
+ parent.add(Int32Value.of(value));
+ }
+ }
+
+ static final class ProtoUInt32ValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoUInt32ValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addLong(long value) {
+ parent.add(UInt32Value.of(Math.toIntExact(value)));
+ }
+ }
+
+ static final class ProtoBoolValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoBoolValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addBoolean(boolean value) {
+ parent.add(BoolValue.of(value));
+ }
+
+ }
+
+ static final class ProtoStringValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoStringValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addBinary(Binary binary) {
+ String str = binary.toStringUsingUTF8();
+ parent.add(StringValue.of(str));
+ }
+
+ }
+
+ static final class ProtoBytesValueConverter extends PrimitiveConverter {
+
+ final ParentValueContainer parent;
+
+ public ProtoBytesValueConverter(ParentValueContainer parent) {
+ this.parent = parent;
+ }
+
+ @Override
+ public void addBinary(Binary binary) {
+ ByteString byteString = ByteString.copyFrom(binary.toByteBuffer());
+ parent.add(BytesValue.of(byteString));
+ }
+ }
+
/**
* This class unwraps the additional LIST wrapper and makes it possible to read the underlying data and then convert
* it to protobuf.
diff --git
a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java
b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java
index a6a779d07..83f3970c2 100644
---
a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java
+++
b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java
@@ -18,11 +18,24 @@
*/
package org.apache.parquet.proto;
+import com.google.protobuf.BoolValue;
+import com.google.protobuf.BytesValue;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.protobuf.Descriptors;
+import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType;
+import com.google.protobuf.DoubleValue;
+import com.google.protobuf.FloatValue;
+import com.google.protobuf.Int32Value;
+import com.google.protobuf.Int64Value;
import com.google.protobuf.Message;
+import com.google.protobuf.StringValue;
+import com.google.protobuf.Timestamp;
+import com.google.protobuf.UInt32Value;
+import com.google.protobuf.UInt64Value;
+import com.google.type.Date;
+import com.google.type.TimeOfDay;
import com.twitter.elephantbird.util.Protobufs;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.conf.HadoopParquetConfiguration;
@@ -31,6 +44,7 @@ import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
+import org.apache.parquet.schema.Type.Repetition;
import org.apache.parquet.schema.Types;
import org.apache.parquet.schema.Types.Builder;
import org.apache.parquet.schema.Types.GroupBuilder;
@@ -40,11 +54,20 @@ import org.slf4j.LoggerFactory;
import java.util.List;
import javax.annotation.Nullable;
+import static org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit;
+import static org.apache.parquet.schema.LogicalTypeAnnotation.dateType;
import static org.apache.parquet.schema.LogicalTypeAnnotation.enumType;
import static org.apache.parquet.schema.LogicalTypeAnnotation.listType;
import static org.apache.parquet.schema.LogicalTypeAnnotation.mapType;
import static org.apache.parquet.schema.LogicalTypeAnnotation.stringType;
-import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.*;
+import static org.apache.parquet.schema.LogicalTypeAnnotation.timeType;
+import static org.apache.parquet.schema.LogicalTypeAnnotation.timestampType;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;
+import static
org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FLOAT;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32;
+import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
/**
* Converts a Protocol Buffer Descriptor into a Parquet schema.
@@ -55,6 +78,7 @@ public class ProtoSchemaConverter {
public static final String PB_MAX_RECURSION = "parquet.proto.maxRecursion";
private final boolean parquetSpecsCompliant;
+ private final boolean unwrapProtoWrappers;
// TODO: use proto custom options to override per field.
private final int maxRecursion;
@@ -76,7 +100,7 @@ public class ProtoSchemaConverter {
* by the parquet specifications. If set to
true, specs compliant schemas are used.
*/
public ProtoSchemaConverter(boolean parquetSpecsCompliant) {
- this(parquetSpecsCompliant, 5);
+ this(parquetSpecsCompliant, 5, false);
}
/**
@@ -98,7 +122,9 @@ public class ProtoSchemaConverter {
public ProtoSchemaConverter(ParquetConfiguration config) {
this(
config.getBoolean(ProtoWriteSupport.PB_SPECS_COMPLIANT_WRITE, false),
- config.getInt(PB_MAX_RECURSION, 5));
+ config.getInt(PB_MAX_RECURSION, 5),
+ config.getBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, false)
+ );
}
/**
@@ -112,8 +138,25 @@ public class ProtoSchemaConverter {
* bytes instead of their actual schema.
*/
public ProtoSchemaConverter(boolean parquetSpecsCompliant, int maxRecursion)
{
+ this(parquetSpecsCompliant, maxRecursion, false);
+ }
+
+ /**
+ * Instantiate a schema converter to get the parquet schema corresponding to
protobuf classes.
+ *
+ * @param parquetSpecsCompliant If set to false, the parquet schema
generated will be using the old
+ * schema style (prior to PARQUET-968) to
provide backward-compatibility
+ * but which does not use LIST and MAP
wrappers around collections as required
+ * by the parquet specifications. If set to
true, specs compliant schemas are used.
+ * @param maxRecursion The maximum recursion depth messages are
allowed to go before terminating as
+ * bytes instead of their actual schema.
+ * @param unwrapProtoWrappers If set to true, unwrap common Proto
wrappers like Timestamp and DoubleValue
+ * with corresponding OPTIONAL logical
annotations. Primitive types become REQUIRED.
+ */
+ public ProtoSchemaConverter(boolean parquetSpecsCompliant, int maxRecursion,
boolean unwrapProtoWrappers) {
this.parquetSpecsCompliant = parquetSpecsCompliant;
this.maxRecursion = maxRecursion;
+ this.unwrapProtoWrappers = unwrapProtoWrappers;
}
/**
@@ -178,6 +221,46 @@ public class ProtoSchemaConverter {
private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>>
addField(FieldDescriptor descriptor, final GroupBuilder<T> builder,
ImmutableSetMultimap<String, Integer> seen, int depth) {
if (descriptor.getJavaType() == JavaType.MESSAGE) {
+ if (unwrapProtoWrappers) {
+ Descriptor messageType = descriptor.getMessageType();
+ if (messageType.equals(Timestamp.getDescriptor())) {
+ return builder.primitive(INT64,
getRepetition(descriptor)).as(timestampType(true, TimeUnit.NANOS));
+ }
+ if (messageType.equals(Date.getDescriptor())) {
+ return builder.primitive(INT32,
getRepetition(descriptor)).as(dateType());
+ }
+ if (messageType.equals(TimeOfDay.getDescriptor())) {
+ return builder.primitive(INT64,
getRepetition(descriptor)).as(timeType(true, TimeUnit.NANOS));
+ }
+ if (messageType.equals(DoubleValue.getDescriptor())) {
+ return builder.primitive(DOUBLE, getRepetition(descriptor));
+ }
+ if (messageType.equals(StringValue.getDescriptor())) {
+ return builder.primitive(BINARY,
getRepetition(descriptor)).as(stringType());
+ }
+ if (messageType.equals(BoolValue.getDescriptor())) {
+ return builder.primitive(BOOLEAN, getRepetition(descriptor));
+ }
+ if (messageType.equals(FloatValue.getDescriptor())) {
+ return builder.primitive(FLOAT, getRepetition(descriptor));
+ }
+ if (messageType.equals(Int64Value.getDescriptor())) {
+ return builder.primitive(INT64, getRepetition(descriptor));
+ }
+ if (messageType.equals(UInt64Value.getDescriptor())) {
+ return builder.primitive(INT64, getRepetition(descriptor));
+ }
+ if (messageType.equals(Int32Value.getDescriptor())) {
+ return builder.primitive(INT32, getRepetition(descriptor));
+ }
+ if (messageType.equals(UInt32Value.getDescriptor())) {
+ return builder.primitive(INT64, getRepetition(descriptor));
+ }
+ if (messageType.equals(BytesValue.getDescriptor())) {
+ return builder.primitive(BINARY, getRepetition(descriptor));
+ }
+ }
+
return addMessageField(descriptor, builder, seen, depth);
}
@@ -186,8 +269,8 @@ public class ProtoSchemaConverter {
// the old schema style did not include the LIST wrapper around repeated fields
return addRepeatedPrimitive(parquetType.primitiveType,
parquetType.logicalTypeAnnotation, builder);
}
-
- return builder.primitive(parquetType.primitiveType,
getRepetition(descriptor)).as(parquetType.logicalTypeAnnotation);
+ Repetition repetition = unwrapProtoWrappers ? Repetition.REQUIRED :
getRepetition(descriptor);
+ return builder.primitive(parquetType.primitiveType,
repetition).as(parquetType.logicalTypeAnnotation);
}
private static <T> Builder<? extends Builder<?, GroupBuilder<T>>,
GroupBuilder<T>> addRepeatedPrimitive(PrimitiveTypeName primitiveType,
diff --git
a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java
b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java
index b13acd2a5..c5081e759 100644
---
a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java
+++
b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java
@@ -21,6 +21,9 @@ package org.apache.parquet.proto;
import com.google.protobuf.*;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
+import com.google.protobuf.util.Timestamps;
+import com.google.type.Date;
+import com.google.type.TimeOfDay;
import com.twitter.elephantbird.util.Protobufs;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.conf.HadoopParquetConfiguration;
@@ -36,6 +39,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Array;
+import java.time.LocalDate;
+import java.time.LocalTime;
import java.util.*;
import static java.util.Optional.ofNullable;
@@ -57,7 +62,10 @@ public class ProtoWriteSupport<T extends MessageOrBuilder>
extends WriteSupport<
// but is set to false by default to keep backward compatibility.
public static final String PB_SPECS_COMPLIANT_WRITE =
"parquet.proto.writeSpecsCompliant";
+ public static final String PB_UNWRAP_PROTO_WRAPPERS =
"parquet.proto.unwrapProtoWrappers";
+
private boolean writeSpecsCompliant = false;
+ private boolean unwrapProtoWrappers = false;
private RecordConsumer recordConsumer;
private Class<? extends Message> protoMessage;
private Descriptor descriptor;
@@ -96,6 +104,10 @@ public class ProtoWriteSupport<T extends MessageOrBuilder>
extends WriteSupport<
configuration.setBoolean(PB_SPECS_COMPLIANT_WRITE, writeSpecsCompliant);
}
+ public static void setUnwrapProtoWrappers(Configuration configuration,
boolean unwrapProtoWrappers) {
+ configuration.setBoolean(PB_UNWRAP_PROTO_WRAPPERS, unwrapProtoWrappers);
+ }
+
/**
* Writes Protocol buffer to parquet file.
* @param record instance of Message.Builder or Message.
@@ -144,6 +156,7 @@ public class ProtoWriteSupport<T extends MessageOrBuilder>
extends WriteSupport<
extraMetaData.put(ProtoReadSupport.PB_CLASS, protoMessage.getName());
}
+ unwrapProtoWrappers = configuration.getBoolean(PB_UNWRAP_PROTO_WRAPPERS,
unwrapProtoWrappers);
writeSpecsCompliant = configuration.getBoolean(PB_SPECS_COMPLIANT_WRITE,
writeSpecsCompliant);
MessageType rootSchema = new
ProtoSchemaConverter(configuration).convert(descriptor);
validatedMapping(descriptor, rootSchema);
@@ -152,6 +165,7 @@ public class ProtoWriteSupport<T extends MessageOrBuilder>
extends WriteSupport<
extraMetaData.put(ProtoReadSupport.PB_DESCRIPTOR,
descriptor.toProto().toString());
extraMetaData.put(PB_SPECS_COMPLIANT_WRITE,
String.valueOf(writeSpecsCompliant));
+ extraMetaData.put(PB_UNWRAP_PROTO_WRAPPERS,
String.valueOf(unwrapProtoWrappers));
return new WriteContext(rootSchema, extraMetaData);
}
@@ -265,6 +279,46 @@ public class ProtoWriteSupport<T extends MessageOrBuilder>
extends WriteSupport<
return createMapWriter(fieldDescriptor, type);
}
+ if (unwrapProtoWrappers) {
+ Descriptor messageType = fieldDescriptor.getMessageType();
+ if (messageType.equals(Timestamp.getDescriptor())) {
+ return new TimestampWriter();
+ }
+ if (messageType.equals(Date.getDescriptor())) {
+ return new DateWriter();
+ }
+ if (messageType.equals(TimeOfDay.getDescriptor())) {
+ return new TimeWriter();
+ }
+ if (messageType.equals(DoubleValue.getDescriptor())) {
+ return new DoubleValueWriter();
+ }
+ if (messageType.equals(FloatValue.getDescriptor())) {
+ return new FloatValueWriter();
+ }
+ if (messageType.equals(Int64Value.getDescriptor())) {
+ return new Int64ValueWriter();
+ }
+ if (messageType.equals(UInt64Value.getDescriptor())) {
+ return new UInt64ValueWriter();
+ }
+ if (messageType.equals(Int32Value.getDescriptor())) {
+ return new Int32ValueWriter();
+ }
+ if (messageType.equals(UInt32Value.getDescriptor())) {
+ return new UInt32ValueWriter();
+ }
+ if (messageType.equals(BoolValue.getDescriptor())) {
+ return new BoolValueWriter();
+ }
+ if (messageType.equals(StringValue.getDescriptor())) {
+ return new StringValueWriter();
+ }
+ if (messageType.equals(BytesValue.getDescriptor())) {
+ return new BytesValueWriter();
+ }
+ }
+
// This can happen now that recursive schemas get truncated to bytes. Write the bytes.
if (type.isPrimitive() && type.asPrimitiveType().getPrimitiveTypeName()
== PrimitiveType.PrimitiveTypeName.BINARY) {
return new BinaryWriter();
@@ -584,6 +638,98 @@ public class ProtoWriteSupport<T extends MessageOrBuilder>
extends WriteSupport<
}
}
+ class TimestampWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ Timestamp timestamp = (Timestamp) value;
+ recordConsumer.addLong(Timestamps.toNanos(timestamp));
+ }
+ }
+
+ class DateWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ Date date = (Date) value;
+ LocalDate localDate = LocalDate.of(date.getYear(), date.getMonth(),
date.getDay());
+ recordConsumer.addInteger((int) localDate.toEpochDay());
+ }
+ }
+
+ class TimeWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ com.google.type.TimeOfDay timeOfDay = (com.google.type.TimeOfDay) value;
+ LocalTime localTime = LocalTime.of(timeOfDay.getHours(),
timeOfDay.getMinutes(), timeOfDay.getSeconds(), timeOfDay.getNanos());
+ recordConsumer.addLong(localTime.toNanoOfDay());
+ }
+ }
+
+ class DoubleValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addDouble(((DoubleValue) value).getValue());
+ }
+ }
+
+ class FloatValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addFloat(((FloatValue) value).getValue());
+ }
+ }
+
+ class Int64ValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addLong(((Int64Value) value).getValue());
+ }
+ }
+
+ class UInt64ValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addLong(((UInt64Value) value).getValue());
+ }
+ }
+
+ class Int32ValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addInteger(((Int32Value) value).getValue());
+ }
+ }
+
+ class UInt32ValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addLong(((UInt32Value) value).getValue());
+ }
+ }
+
+ class BoolValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ recordConsumer.addBoolean(((BoolValue) value).getValue());
+ }
+ }
+
+ class StringValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ Binary binaryString = Binary.fromString(((StringValue)
value).getValue());
+ recordConsumer.addBinary(binaryString);
+ }
+ }
+
+ class BytesValueWriter extends FieldWriter {
+ @Override
+ void writeRawValue(Object value) {
+ byte[] byteArray = ((BytesValue) value).getValue().toByteArray();
+ Binary binary = Binary.fromConstantByteArray(byteArray);
+ recordConsumer.addBinary(binary);
+ }
+ }
+
private FieldWriter unknownType(FieldDescriptor fieldDescriptor) {
String exceptionMsg = "Unknown type with descriptor \"" + fieldDescriptor
+ "\" and type \"" + fieldDescriptor.getJavaType() + "\".";
diff --git
a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java
b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java
index 4debe77f8..605c32266 100644
---
a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java
+++
b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java
@@ -18,8 +18,12 @@
*/
package org.apache.parquet.proto;
+import com.google.protobuf.BoolValue;
import com.google.protobuf.ByteString;
+import com.google.protobuf.DoubleValue;
import com.google.protobuf.Message;
+import com.google.protobuf.Timestamp;
+import com.google.protobuf.util.Timestamps;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.proto.test.TestProto3;
@@ -32,7 +36,9 @@ import org.junit.Test;
import java.util.List;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
public class ProtoInputOutputFormatTest {
@@ -607,6 +613,49 @@ public class ProtoInputOutputFormatTest {
assertTrue(third.getTwo().isEmpty());
}
+ /**
+  * Writes an empty and a timestamp-bearing DateTimeMessage through MR with
+  * wrapper unwrapping enabled on the write side, then expects to read both
+  * messages back unchanged.
+  */
+ @Test
+ public void testProto3TimestampMessageClass() throws Exception {
+   TestProto3.DateTimeMessage blank = TestProto3.DateTimeMessage.newBuilder().build();
+   TestProto3.DateTimeMessage stamped = TestProto3.DateTimeMessage.newBuilder()
+       .setTimestamp(Timestamps.parse("2021-05-02T15:04:03.748Z"))
+       .build();
+
+   Configuration writeConf = new Configuration();
+   writeConf.setBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, true);
+   Path written = new WriteUsingMR(writeConf).write(blank, stamped);
+
+   ReadUsingMR reader = new ReadUsingMR();
+   ProtoReadSupport.setProtobufClass(
+       reader.getConfiguration(), TestProto3.DateTimeMessage.class.getName());
+   List<Message> readBack = reader.read(written);
+
+   assertEquals(2, readBack.size());
+   assertEquals(blank, readBack.get(0));
+   assertEquals(stamped, readBack.get(1));
+ }
+
+ /**
+  * Writes an empty and a partially populated WrappedMessage through MR with
+  * wrapper unwrapping enabled on the write side, then expects to read both
+  * messages back unchanged.
+  */
+ @Test
+ public void testProto3WrappedMessageClass() throws Exception {
+   TestProto3.WrappedMessage blank = TestProto3.WrappedMessage.newBuilder().build();
+   TestProto3.WrappedMessage populated = TestProto3.WrappedMessage.newBuilder()
+       .setWrappedDouble(DoubleValue.of(0.577))
+       .setWrappedBool(BoolValue.of(true))
+       .build();
+
+   Configuration writeConf = new Configuration();
+   writeConf.setBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, true);
+   Path written = new WriteUsingMR(writeConf).write(blank, populated);
+
+   ReadUsingMR reader = new ReadUsingMR();
+   ProtoReadSupport.setProtobufClass(
+       reader.getConfiguration(), TestProto3.WrappedMessage.class.getName());
+   List<Message> readBack = reader.read(written);
+
+   assertEquals(2, readBack.size());
+   assertEquals(blank, readBack.get(0));
+   assertEquals(populated, readBack.get(1));
+ }
+
/**
* Runs job that writes input to file and then job reading data back.
*/
diff --git
a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java
b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java
index 8c64197c3..287159002 100644
---
a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java
+++
b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java
@@ -23,7 +23,6 @@ import com.google.protobuf.Message;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import org.junit.Test;
-import org.apache.parquet.proto.TestUtils;
import org.apache.parquet.proto.test.TestProto3;
import org.apache.parquet.proto.test.TestProtobuf;
import org.apache.parquet.proto.test.Trees;
@@ -40,12 +39,12 @@ public class ProtoSchemaConverterTest {
/**
* Converts given pbClass to parquet schema and compares it with expected
parquet schema.
*/
- private static void testConversion(Class<? extends Message> pbClass, String
parquetSchemaString, boolean parquetSpecsCompliant) {
- testConversion(pbClass, parquetSchemaString, new
ProtoSchemaConverter(parquetSpecsCompliant));
+ private static void testConversion(Class<? extends Message> pbClass, String
parquetSchemaString, boolean parquetSpecsCompliant, boolean unwrapWrappers) {
+ testConversion(pbClass, parquetSchemaString, new
ProtoSchemaConverter(parquetSpecsCompliant, 5, unwrapWrappers));
}
private static void testConversion(Class<? extends Message> pbClass, String
parquetSchemaString) {
- testConversion(pbClass, parquetSchemaString, true);
+ testConversion(pbClass, parquetSchemaString, true, false);
}
private static void testConversion(Class<? extends Message> pbClass, String
parquetSchemaString, ProtoSchemaConverter converter) {
@@ -54,6 +53,9 @@ public class ProtoSchemaConverterTest {
assertEquals(expectedMT.toString(), schema.toString());
}
+ /**
+  * Convenience overload that runs the conversion check with wrapper
+  * unwrapping turned off.
+  *
+  * @param pbClass protobuf message class to convert
+  * @param parquetSchemaString expected parquet schema, as text
+  * @param parquetSpecsCompliant forwarded to the schema converter
+  */
+ private static void testConversion(Class<? extends Message> pbClass,
+     String parquetSchemaString, boolean parquetSpecsCompliant) {
+   // Static and exception-free, like the sibling overloads it delegates to.
+   testConversion(pbClass, parquetSchemaString, parquetSpecsCompliant, false);
+ }
/**
* Tests that all protocol buffer datatypes are converted to correct parquet
datatypes.
@@ -206,7 +208,7 @@ public class ProtoSchemaConverterTest {
" repeated int32 repeatedInt = 1;",
"}");
- testConversion(TestProtobuf.RepeatedIntMessage.class, expectedSchema,
false);
+ testConversion(TestProtobuf.RepeatedIntMessage.class, expectedSchema,
false, false);
}
@Test
@@ -231,7 +233,7 @@ public class ProtoSchemaConverterTest {
" repeated int32 repeatedInt = 1;",
"}");
- testConversion(TestProto3.RepeatedIntMessage.class, expectedSchema, false);
+ testConversion(TestProto3.RepeatedIntMessage.class, expectedSchema, false,
false);
}
@Test
@@ -263,7 +265,7 @@ public class ProtoSchemaConverterTest {
" }",
"}");
- testConversion(TestProtobuf.RepeatedInnerMessage.class, expectedSchema,
false);
+ testConversion(TestProtobuf.RepeatedInnerMessage.class, expectedSchema,
false, false);
}
@Test
@@ -295,7 +297,7 @@ public class ProtoSchemaConverterTest {
" }",
"}");
- testConversion(TestProto3.RepeatedInnerMessage.class, expectedSchema,
false);
+ testConversion(TestProto3.RepeatedInnerMessage.class, expectedSchema,
false, false);
}
@Test
@@ -323,7 +325,7 @@ public class ProtoSchemaConverterTest {
" }",
"}");
- testConversion(TestProtobuf.MapIntMessage.class, expectedSchema, false);
+ testConversion(TestProtobuf.MapIntMessage.class, expectedSchema, false,
false);
}
@Test
@@ -351,7 +353,61 @@ public class ProtoSchemaConverterTest {
" }",
"}");
- testConversion(TestProto3.MapIntMessage.class, expectedSchema, false);
+ testConversion(TestProto3.MapIntMessage.class, expectedSchema, false,
false);
+ }
+
+ /**
+  * With unwrapping disabled, the timestamp, date and time fields are expected
+  * to convert to nested groups.
+  */
+ @Test
+ public void testProto3ConvertDateTimeMessageWrapped() throws Exception {
+   String groupedSchema = String.join("\n",
+       "message TestProto3.DateTimeMessage {",
+       " optional group timestamp = 1 {",
+       " optional int64 seconds = 1;",
+       " optional int32 nanos = 2;",
+       " }",
+       " optional group date = 2 {",
+       " optional int32 year = 1;",
+       " optional int32 month = 2;",
+       " optional int32 day = 3;",
+       " }",
+       " optional group time = 3 {",
+       " optional int32 hours = 1;",
+       " optional int32 minutes = 2;",
+       " optional int32 seconds = 3;",
+       " optional int32 nanos = 4;",
+       " }",
+       "}");
+
+   testConversion(TestProto3.DateTimeMessage.class, groupedSchema, false, false);
+ }
+
+ /**
+  * With unwrapping enabled, the timestamp, date and time fields are expected
+  * to convert to annotated primitive columns.
+  */
+ @Test
+ public void testProto3ConvertDateTimeMessageUnwrapped() throws Exception {
+   String flatSchema = String.join("\n",
+       "message TestProto3.DateTimeMessage {",
+       " optional int64 timestamp (TIMESTAMP(NANOS,true)) = 1;",
+       " optional int32 date (DATE) = 2;",
+       " optional int64 time (TIME(NANOS,true)) = 3;",
+       "}");
+
+   testConversion(TestProto3.DateTimeMessage.class, flatSchema, false, true);
+ }
+
+ /**
+  * With unwrapping enabled, each wrapper field is expected to convert to a
+  * single primitive column.
+  */
+ @Test
+ public void testProto3ConvertWrappedMessageUnwrapped() throws Exception {
+   String flatSchema = String.join("\n",
+       "message TestProto3.WrappedMessage {",
+       " optional double wrappedDouble = 1;",
+       " optional float wrappedFloat = 2;",
+       " optional int64 wrappedInt64 = 3;",
+       " optional int64 wrappedUInt64 = 4;",
+       " optional int32 wrappedInt32 = 5;",
+       " optional int64 wrappedUInt32 = 6;",
+       " optional boolean wrappedBool = 7;",
+       " optional binary wrappedString (UTF8) = 8;",
+       " optional binary wrappedBytes = 9;",
+       "}");
+
+   testConversion(TestProto3.WrappedMessage.class, flatSchema, false, true);
}
@Test
@@ -379,8 +435,8 @@ public class ProtoSchemaConverterTest {
" optional binary right = 3;",
" }",
"}");
- testConversion(Trees.BinaryTree.class, expectedSchema, new
ProtoSchemaConverter(true, 1));
- testConversion(Trees.BinaryTree.class,
TestUtils.readResource("BinaryTree.par"), new ProtoSchemaConverter(true,
PAR_RECURSION_DEPTH));
+ testConversion(Trees.BinaryTree.class, expectedSchema, new
ProtoSchemaConverter(true, 1, false));
+ testConversion(Trees.BinaryTree.class,
TestUtils.readResource("BinaryTree.par"), new ProtoSchemaConverter(true,
PAR_RECURSION_DEPTH, false));
}
@@ -404,8 +460,8 @@ public class ProtoSchemaConverterTest {
" }",
" }",
"}");
- testConversion(Trees.WideTree.class, expectedSchema, new
ProtoSchemaConverter(true, 1));
- testConversion(Trees.WideTree.class,
TestUtils.readResource("WideTree.par"), new ProtoSchemaConverter(true,
PAR_RECURSION_DEPTH));
+ testConversion(Trees.WideTree.class, expectedSchema, new
ProtoSchemaConverter(true, 1, false));
+ testConversion(Trees.WideTree.class,
TestUtils.readResource("WideTree.par"), new ProtoSchemaConverter(true,
PAR_RECURSION_DEPTH, false));
}
@@ -455,8 +511,8 @@ public class ProtoSchemaConverterTest {
" }",
" }",
"}");
- testConversion(Value.class, expectedSchema, new ProtoSchemaConverter(true,
1));
- testConversion(Value.class, TestUtils.readResource("Value.par"), new
ProtoSchemaConverter(true, PAR_RECURSION_DEPTH));
+ testConversion(Value.class, expectedSchema, new ProtoSchemaConverter(true,
1, false));
+ testConversion(Value.class, TestUtils.readResource("Value.par"), new
ProtoSchemaConverter(true, PAR_RECURSION_DEPTH, false));
}
@Test
@@ -510,8 +566,8 @@ public class ProtoSchemaConverterTest {
" }",
" }",
"}");
- testConversion(Struct.class, expectedSchema, new
ProtoSchemaConverter(true, 1));
- testConversion(Struct.class, TestUtils.readResource("Struct.par"), new
ProtoSchemaConverter(true, PAR_RECURSION_DEPTH));
+ testConversion(Struct.class, expectedSchema, new
ProtoSchemaConverter(true, 1, false));
+ testConversion(Struct.class, TestUtils.readResource("Struct.par"), new
ProtoSchemaConverter(true, PAR_RECURSION_DEPTH, false));
}
@Test
@@ -521,16 +577,16 @@ public class ProtoSchemaConverterTest {
long expectedBinaryTreeSize = 4;
long expectedStructSize = 7;
for (int i = 0; i < 10; ++i) {
- MessageType deepSchema = new ProtoSchemaConverter(true,
i).convert(Trees.WideTree.class);
+ MessageType deepSchema = new ProtoSchemaConverter(true, i,
false).convert(Trees.WideTree.class);
// 3, 5, 7, 9, 11, 13, 15, 17, 19, 21
assertEquals(2 * i + 3, deepSchema.getPaths().size());
- deepSchema = new ProtoSchemaConverter(true,
i).convert(Trees.BinaryTree.class);
+ deepSchema = new ProtoSchemaConverter(true, i,
false).convert(Trees.BinaryTree.class);
// 4, 10, 22, 46, 94, 190, 382, 766, 1534, 3070
assertEquals(expectedBinaryTreeSize, deepSchema.getPaths().size());
expectedBinaryTreeSize = 2 * expectedBinaryTreeSize + 2;
- deepSchema = new ProtoSchemaConverter(true, i).convert(Struct.class);
+ deepSchema = new ProtoSchemaConverter(true, i,
false).convert(Struct.class);
// 7, 18, 40, 84, 172, 348, 700, 1404, 2812, 5628
assertEquals(expectedStructSize, deepSchema.getPaths().size());
expectedStructSize = 2 * expectedStructSize + 4;
diff --git
a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java
b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java
index ef0356d01..c4c34c900 100644
---
a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java
+++
b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java
@@ -18,25 +18,44 @@
*/
package org.apache.parquet.proto;
+import com.google.protobuf.BoolValue;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import com.google.protobuf.DoubleValue;
+import com.google.protobuf.FloatValue;
+import com.google.protobuf.Int32Value;
+import com.google.protobuf.Int64Value;
import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.Message;
+import com.google.protobuf.MessageOrBuilder;
+import com.google.protobuf.StringValue;
+import com.google.protobuf.Timestamp;
+import com.google.protobuf.UInt32Value;
+import com.google.protobuf.UInt64Value;
+import com.google.protobuf.util.Timestamps;
import com.google.protobuf.Value;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
-import org.junit.Test;
-import static org.junit.Assert.*;
-import org.mockito.InOrder;
-import org.mockito.Mockito;
+import org.apache.parquet.hadoop.ParquetWriter;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.proto.test.TestProto3;
import org.apache.parquet.proto.test.TestProtobuf;
+import org.junit.Test;
+import org.mockito.InOrder;
+import org.mockito.Mockito;
import org.apache.parquet.proto.test.Trees;
import java.io.IOException;
+import java.time.LocalDate;
+import java.time.LocalTime;
import java.util.List;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
public class ProtoWriteSupportTest {
private <T extends Message> ProtoWriteSupport<T>
createReadConsumerInstance(Class<T> cls, RecordConsumer readConsumerMock) {
@@ -1201,4 +1220,180 @@ public class ProtoWriteSupportTest {
Mockito.verifyNoMoreInteractions(readConsumerMock);
}
+
+ /**
+  * With unwrapping enabled, a DateTimeMessage is expected to reach the record
+  * consumer as one long (timestamp nanos), one int (epoch day) and one long
+  * (nano of day), in field order.
+  */
+ @Test
+ public void testProto3DateTimeMessageUnwrapped() throws Exception {
+   Timestamp timestamp = Timestamps.parse("2021-05-02T15:04:03.748Z");
+   LocalDate date = LocalDate.of(2021, 5, 2);
+   LocalTime time = LocalTime.of(15, 4, 3, 748_000_000);
+
+   Configuration conf = new Configuration();
+   ProtoWriteSupport.setUnwrapProtoWrappers(conf, true);
+   RecordConsumer consumer = Mockito.mock(RecordConsumer.class);
+   ProtoWriteSupport<TestProto3.DateTimeMessage> writeSupport =
+       createReadConsumerInstance(TestProto3.DateTimeMessage.class, consumer, conf);
+
+   TestProto3.DateTimeMessage message = TestProto3.DateTimeMessage.newBuilder()
+       .setTimestamp(timestamp)
+       .setDate(com.google.type.Date.newBuilder()
+           .setYear(date.getYear())
+           .setMonth(date.getMonthValue())
+           .setDay(date.getDayOfMonth()))
+       .setTime(com.google.type.TimeOfDay.newBuilder()
+           .setHours(time.getHour())
+           .setMinutes(time.getMinute())
+           .setSeconds(time.getSecond())
+           .setNanos(time.getNano()))
+       .build();
+   writeSupport.write(message);
+
+   InOrder ordered = Mockito.inOrder(consumer);
+   ordered.verify(consumer).startMessage();
+   ordered.verify(consumer).startField("timestamp", 0);
+   ordered.verify(consumer).addLong(Timestamps.toNanos(timestamp));
+   ordered.verify(consumer).endField("timestamp", 0);
+   ordered.verify(consumer).startField("date", 1);
+   ordered.verify(consumer).addInteger((int) date.toEpochDay());
+   ordered.verify(consumer).endField("date", 1);
+   ordered.verify(consumer).startField("time", 2);
+   ordered.verify(consumer).addLong(time.toNanoOfDay());
+   ordered.verify(consumer).endField("time", 2);
+   ordered.verify(consumer).endMessage();
+   Mockito.verifyNoMoreInteractions(consumer);
+ }
+
+ /**
+  * Writes one DateTimeMessage to a temporary parquet file with wrapper
+  * unwrapping enabled and checks that the timestamp, date and time fields
+  * read back equal to what was written.
+  */
+ @Test
+ public void testProto3DateTimeMessageRoundTrip() throws Exception {
+   Timestamp timestamp = Timestamps.parse("2021-05-02T15:04:03.748Z");
+   LocalDate date = LocalDate.of(2021, 5, 2);
+   LocalTime time = LocalTime.of(15, 4, 3, 748_000_000);
+   com.google.type.Date protoDate = com.google.type.Date.newBuilder()
+       .setYear(date.getYear())
+       .setMonth(date.getMonthValue())
+       .setDay(date.getDayOfMonth())
+       .build();
+   com.google.type.TimeOfDay protoTime = com.google.type.TimeOfDay.newBuilder()
+       .setHours(time.getHour())
+       .setMinutes(time.getMinute())
+       .setSeconds(time.getSecond())
+       .setNanos(time.getNano())
+       .build();
+
+   TestProto3.DateTimeMessage msg = TestProto3.DateTimeMessage.newBuilder()
+       .setTimestamp(timestamp)
+       .setDate(protoDate)
+       .setTime(protoTime)
+       .build();
+
+   // Write them out and read them back. try-with-resources closes the writer
+   // even when write() throws, so a failing test does not leak the file handle.
+   Path tmpFilePath = TestUtils.someTemporaryFilePath();
+   try (ParquetWriter<MessageOrBuilder> writer =
+       ProtoParquetWriter.<MessageOrBuilder>builder(tmpFilePath)
+           .withMessage(TestProto3.DateTimeMessage.class)
+           .config(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, "true")
+           .build()) {
+     writer.write(msg);
+   }
+   List<TestProto3.DateTimeMessage> gotBack =
+       TestUtils.readMessages(tmpFilePath, TestProto3.DateTimeMessage.class);
+
+   // Exactly one record was written.
+   assertEquals(1, gotBack.size());
+   TestProto3.DateTimeMessage gotBackFirst = gotBack.get(0);
+   assertEquals(timestamp, gotBackFirst.getTimestamp());
+   assertEquals(protoDate, gotBackFirst.getDate());
+   assertEquals(protoTime, gotBackFirst.getTime());
+ }
+
+ /**
+  * With unwrapping enabled, a WrappedMessage carrying only wrappedDouble is
+  * expected to reach the record consumer as a single double field.
+  */
+ @Test
+ public void testProto3WrappedMessageUnwrapped() throws Exception {
+   Configuration conf = new Configuration();
+   ProtoWriteSupport.setUnwrapProtoWrappers(conf, true);
+   RecordConsumer consumer = Mockito.mock(RecordConsumer.class);
+   ProtoWriteSupport<TestProto3.WrappedMessage> writeSupport =
+       createReadConsumerInstance(TestProto3.WrappedMessage.class, consumer, conf);
+
+   TestProto3.WrappedMessage message = TestProto3.WrappedMessage.newBuilder()
+       .setWrappedDouble(DoubleValue.of(3.1415))
+       .build();
+   writeSupport.write(message);
+
+   InOrder ordered = Mockito.inOrder(consumer);
+   ordered.verify(consumer).startMessage();
+   ordered.verify(consumer).startField("wrappedDouble", 0);
+   ordered.verify(consumer).addDouble(3.1415);
+   ordered.verify(consumer).endField("wrappedDouble", 0);
+   ordered.verify(consumer).endMessage();
+   Mockito.verifyNoMoreInteractions(consumer);
+ }
+
+ /**
+  * Writes a WrappedMessage with every wrapper field set to a temporary
+  * parquet file (unwrapping enabled) and checks each value after reading the
+  * file back.
+  */
+ @Test
+ public void testProto3WrappedMessageUnwrappedRoundTrip() throws Exception {
+   TestProto3.WrappedMessage.Builder msg = TestProto3.WrappedMessage.newBuilder();
+   msg.setWrappedDouble(DoubleValue.of(0.577));
+   msg.setWrappedFloat(FloatValue.of(3.1415f));
+   msg.setWrappedInt64(Int64Value.of(1_000_000_000L * 4));
+   msg.setWrappedUInt64(UInt64Value.of(1_000_000_000L * 9));
+   msg.setWrappedInt32(Int32Value.of(1_000_000 * 3));
+   msg.setWrappedUInt32(UInt32Value.of(1_000_000 * 8));
+   msg.setWrappedBool(BoolValue.of(true));
+   msg.setWrappedString(StringValue.of("Good Will Hunting"));
+   // copyFromUtf8 avoids the string-named charset lookup of copyFrom(text, "UTF-8").
+   msg.setWrappedBytes(BytesValue.of(ByteString.copyFromUtf8("someText")));
+
+   // Write them out and read them back. try-with-resources closes the writer
+   // even when write() throws, so a failing test does not leak the file handle.
+   Path tmpFilePath = TestUtils.someTemporaryFilePath();
+   try (ParquetWriter<MessageOrBuilder> writer =
+       ProtoParquetWriter.<MessageOrBuilder>builder(tmpFilePath)
+           .withMessage(TestProto3.WrappedMessage.class)
+           .config(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, "true")
+           .build()) {
+     writer.write(msg);
+   }
+   List<TestProto3.WrappedMessage> gotBack =
+       TestUtils.readMessages(tmpFilePath, TestProto3.WrappedMessage.class);
+
+   // Exactly one record was written.
+   assertEquals(1, gotBack.size());
+   TestProto3.WrappedMessage gotBackFirst = gotBack.get(0);
+   assertEquals(0.577, gotBackFirst.getWrappedDouble().getValue(), 1e-5);
+   assertEquals(3.1415f, gotBackFirst.getWrappedFloat().getValue(), 1e-5f);
+   assertEquals(1_000_000_000L * 4, gotBackFirst.getWrappedInt64().getValue());
+   assertEquals(1_000_000_000L * 9, gotBackFirst.getWrappedUInt64().getValue());
+   assertEquals(1_000_000 * 3, gotBackFirst.getWrappedInt32().getValue());
+   assertEquals(1_000_000 * 8, gotBackFirst.getWrappedUInt32().getValue());
+   assertEquals(BoolValue.of(true), gotBackFirst.getWrappedBool());
+   assertEquals("Good Will Hunting", gotBackFirst.getWrappedString().getValue());
+   assertEquals(ByteString.copyFromUtf8("someText"), gotBackFirst.getWrappedBytes().getValue());
+ }
+
+ /**
+  * Writes a WrappedMessage with only some wrapper fields set (including an
+  * Int32Value of 0) and checks, after reading the file back, which fields
+  * report as present and which as absent.
+  */
+ @Test
+ public void testProto3WrappedMessageWithNullsRoundTrip() throws Exception {
+   TestProto3.WrappedMessage.Builder msg = TestProto3.WrappedMessage.newBuilder();
+   msg.setWrappedFloat(FloatValue.of(3.1415f));
+   msg.setWrappedString(StringValue.of("Good Will Hunting"));
+   msg.setWrappedInt32(Int32Value.of(0));
+
+   // Write them out and read them back. try-with-resources closes the writer
+   // even when write() throws, so a failing test does not leak the file handle.
+   Path tmpFilePath = TestUtils.someTemporaryFilePath();
+   try (ParquetWriter<MessageOrBuilder> writer =
+       ProtoParquetWriter.<MessageOrBuilder>builder(tmpFilePath)
+           .withMessage(TestProto3.WrappedMessage.class)
+           .config(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, "true")
+           .build()) {
+     writer.write(msg);
+   }
+   List<TestProto3.WrappedMessage> gotBack =
+       TestUtils.readMessages(tmpFilePath, TestProto3.WrappedMessage.class);
+
+   // Exactly one record was written.
+   assertEquals(1, gotBack.size());
+   TestProto3.WrappedMessage gotBackFirst = gotBack.get(0);
+   assertFalse(gotBackFirst.hasWrappedDouble());
+   assertEquals(3.1415f, gotBackFirst.getWrappedFloat().getValue(), 1e-5f);
+
+   // double-check that nulls are honored
+   assertTrue(gotBackFirst.hasWrappedFloat());
+   assertFalse(gotBackFirst.hasWrappedInt64());
+   assertFalse(gotBackFirst.hasWrappedUInt64());
+   // An explicitly written zero must still read back as present.
+   assertTrue(gotBackFirst.hasWrappedInt32());
+   assertFalse(gotBackFirst.hasWrappedUInt32());
+   assertEquals(0, gotBackFirst.getWrappedUInt32().getValue());
+   assertFalse(gotBackFirst.hasWrappedBool());
+   assertEquals("Good Will Hunting", gotBackFirst.getWrappedString().getValue());
+   assertFalse(gotBackFirst.hasWrappedBytes());
+ }
}
diff --git a/parquet-protobuf/src/test/resources/TestProto3.proto
b/parquet-protobuf/src/test/resources/TestProto3.proto
index fb4da1b0c..c303fd1f5 100644
--- a/parquet-protobuf/src/test/resources/TestProto3.proto
+++ b/parquet-protobuf/src/test/resources/TestProto3.proto
@@ -23,6 +23,11 @@ package TestProto3;
option java_package = "org.apache.parquet.proto.test";
+import "google/protobuf/timestamp.proto";
+import "google/protobuf/wrappers.proto";
+import "google/type/date.proto";
+import "google/type/timeofday.proto";
+
// original Dremel paper structures: Original paper used groups, not internal
// messages but groups were deprecated.
@@ -156,3 +161,21 @@ message FirstCustomClassMessage {
message SecondCustomClassMessage {
string string = 11;
}
+
+// Test message with one field each of google.protobuf.Timestamp,
+// google.type.Date and google.type.TimeOfDay.
+message DateTimeMessage {
+ google.protobuf.Timestamp timestamp = 1;
+ google.type.Date date = 2;
+ google.type.TimeOfDay time = 3;
+}
+
+// Test message with one field for each of nine google.protobuf *Value
+// wrapper types (double, float, int64, uint64, int32, uint32, bool, string,
+// bytes).
+message WrappedMessage {
+ google.protobuf.DoubleValue wrappedDouble = 1;
+ google.protobuf.FloatValue wrappedFloat = 2;
+ google.protobuf.Int64Value wrappedInt64 = 3;
+ google.protobuf.UInt64Value wrappedUInt64 = 4;
+ google.protobuf.Int32Value wrappedInt32 = 5;
+ google.protobuf.UInt32Value wrappedUInt32 = 6;
+ google.protobuf.BoolValue wrappedBool = 7;
+ google.protobuf.StringValue wrappedString = 8;
+ google.protobuf.BytesValue wrappedBytes = 9;
+}