This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 07e1e02 [BEAM-11861] Add methods to explicitly provide coder for
ParquetIO's Parse and ParseFiles (#14078)
07e1e02 is described below
commit 07e1e02125082d9ec804428f139eb849d79a8ec8
Author: Anant Damle <[email protected]>
AuthorDate: Sat Feb 27 00:26:47 2021 +0800
[BEAM-11861] Add methods to explicitly provide coder for ParquetIO's Parse
and ParseFiles (#14078)
* [BEAM-11861] Provide setCoder and withCoder methods in ParquetIO.Parse<T>
and ParquetIO.ParseFiles<T> to supply an explicit coder.
* Added a unit test case
* [BEAM-11527] Fixed an oversight.
---
.../org/apache/beam/sdk/io/parquet/ParquetIO.java | 38 +++++++++++++++++++---
.../apache/beam/sdk/io/parquet/ParquetIOTest.java | 31 ++++++++++++++++++
2 files changed, 65 insertions(+), 4 deletions(-)
diff --git
a/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
b/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
index 6adbc59..5c2a19d 100644
---
a/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
+++
b/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
@@ -405,12 +405,14 @@ public class ParquetIO {
.withSplit()
.withBeamSchemas(getInferBeamSchema())
.withAvroDataModel(getAvroDataModel())
- .withProjection(getProjectionSchema(), getEncoderSchema()));
+ .withProjection(getProjectionSchema(), getEncoderSchema())
+ .withConfiguration(getConfiguration()));
}
return inputFiles.apply(
readFiles(getSchema())
.withBeamSchemas(getInferBeamSchema())
- .withAvroDataModel(getAvroDataModel()));
+ .withAvroDataModel(getAvroDataModel())
+ .withConfiguration(getConfiguration()));
}
@Override
@@ -428,6 +430,8 @@ public class ParquetIO {
abstract SerializableFunction<GenericRecord, T> getParseFn();
+ abstract @Nullable Coder<T> getCoder();
+
abstract @Nullable SerializableConfiguration getConfiguration();
abstract boolean isSplittable();
@@ -440,6 +444,8 @@ public class ParquetIO {
abstract Builder<T> setParseFn(SerializableFunction<GenericRecord, T>
parseFn);
+ abstract Builder<T> setCoder(Coder<T> coder);
+
abstract Builder<T> setConfiguration(SerializableConfiguration
configuration);
abstract Builder<T> setSplittable(boolean splittable);
@@ -455,6 +461,11 @@ public class ParquetIO {
return from(ValueProvider.StaticValueProvider.of(inputFiles));
}
+ /** Specify the output coder to use for output of the {@code ParseFn}. */
+ public Parse<T> withCoder(Coder<T> coder) {
+ return (coder == null) ? this : toBuilder().setCoder(coder).build();
+ }
+
/** Specify Hadoop configuration for ParquetReader. */
public Parse<T> withConfiguration(Map<String, String> configuration) {
return
toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
@@ -474,6 +485,7 @@ public class ParquetIO {
.apply(
parseFilesGenericRecords(getParseFn())
.toBuilder()
+ .setCoder(getCoder())
.setSplittable(isSplittable())
.build());
}
@@ -486,6 +498,8 @@ public class ParquetIO {
abstract SerializableFunction<GenericRecord, T> getParseFn();
+ abstract @Nullable Coder<T> getCoder();
+
abstract @Nullable SerializableConfiguration getConfiguration();
abstract boolean isSplittable();
@@ -496,6 +510,8 @@ public class ParquetIO {
abstract static class Builder<T> {
abstract Builder<T> setParseFn(SerializableFunction<GenericRecord, T>
parseFn);
+ abstract Builder<T> setCoder(Coder<T> coder);
+
abstract Builder<T> setConfiguration(SerializableConfiguration
configuration);
abstract Builder<T> setSplittable(boolean split);
@@ -503,6 +519,11 @@ public class ParquetIO {
abstract ParseFiles<T> build();
}
+ /** Specify the output coder to use for output of the {@code ParseFn}. */
+ public ParseFiles<T> withCoder(Coder<T> coder) {
+ return (coder == null) ? this : toBuilder().setCoder(coder).build();
+ }
+
/** Specify Hadoop configuration for ParquetReader. */
public ParseFiles<T> withConfiguration(Map<String, String> configuration) {
return
toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
@@ -537,7 +558,7 @@ public class ParquetIO {
/**
* Identifies the {@code Coder} to be used for the output PCollection.
*
- * <p>Returns {@link AvroCoder} if expected output is {@link
GenericRecord}.
+ * <p>throws an exception if expected output is of type {@link
GenericRecord}.
*
* @param coderRegistry the {@link org.apache.beam.sdk.Pipeline}'s
CoderRegistry to identify
* Coder for expected output type of {@link #getParseFn()}
@@ -547,12 +568,17 @@ public class ParquetIO {
throw new IllegalArgumentException("Parse can't be used for reading as
GenericRecord.");
}
+ // Use explicitly provided coder
+ if (getCoder() != null) {
+ return getCoder();
+ }
+
// If not GenericRecord infer it from ParseFn.
try {
return coderRegistry.getCoder(TypeDescriptors.outputOf(getParseFn()));
} catch (CannotProvideCoderException e) {
throw new IllegalArgumentException(
- "Unable to infer coder for output of parseFn. Specify it
explicitly using withCoder().",
+ "Unable to infer coder for output of parseFn. Specify it
explicitly using .withCoder().",
e);
}
}
@@ -618,6 +644,10 @@ public class ParquetIO {
return
toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
}
+ public ReadFiles withConfiguration(SerializableConfiguration
configuration) {
+ return toBuilder().setConfiguration(configuration).build();
+ }
+
@Experimental(Kind.SCHEMAS)
public ReadFiles withBeamSchemas(boolean inferBeamSchema) {
return toBuilder().setInferBeamSchema(inferBeamSchema).build();
diff --git
a/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
b/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
index f3406df..301d102 100644
---
a/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
+++
b/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
@@ -43,6 +43,8 @@ import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.parquet.ParquetIO.GenericRecordPassthroughFn;
import org.apache.beam.sdk.io.range.OffsetRange;
+import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.utils.AvroUtils;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
@@ -50,6 +52,7 @@ import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
import org.apache.parquet.hadoop.metadata.BlockMetaData;
import org.junit.Rule;
import org.junit.Test;
@@ -297,6 +300,34 @@ public class ParquetIOTest implements Serializable {
}
@Test
+ public void testReadFilesAsRowForUnknownSchemaFiles() {
+ List<GenericRecord> records = generateGenericRecords(1000);
+ List<Row> expectedRows =
+ records.stream().map(record -> AvroUtils.toBeamRowStrict(record,
null)).collect(toList());
+
+ PCollection<Row> writeThenRead =
+ mainPipeline
+ .apply(Create.of(records).withCoder(AvroCoder.of(SCHEMA)))
+ .apply(
+ FileIO.<GenericRecord>write()
+ .via(ParquetIO.sink(SCHEMA))
+ .to(temporaryFolder.getRoot().getAbsolutePath()))
+ .getPerDestinationOutputFilenames()
+ .apply(Values.create())
+ .apply(FileIO.matchAll())
+ .apply(FileIO.readMatches())
+ .apply(
+ ParquetIO.parseFilesGenericRecords(
+ (SerializableFunction<GenericRecord, Row>)
+ record -> AvroUtils.toBeamRowStrict(record, null))
+
.withCoder(SchemaCoder.of(AvroUtils.toBeamSchema(SCHEMA))));
+
+ PAssert.that(writeThenRead).containsInAnyOrder(expectedRows);
+
+ mainPipeline.run().waitUntilFinish();
+ }
+
+ @Test
@SuppressWarnings({"nullable", "ConstantConditions"} /* forced check. */)
public void testReadFilesUnknownSchemaFilesForGenericRecordThrowException() {
IllegalArgumentException illegalArgumentException =