This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 07e1e02  [BEAM-11861] Add methods to explicitly provide coder for ParquetIO's Parse and ParseFiles (#14078)
07e1e02 is described below

commit 07e1e02125082d9ec804428f139eb849d79a8ec8
Author: Anant Damle <[email protected]>
AuthorDate: Sat Feb 27 00:26:47 2021 +0800

    [BEAM-11861] Add methods to explicitly provide coder for ParquetIO's Parse and ParseFiles (#14078)
    
    * [BEAM-11861] Add withCoder() (backed by a setCoder builder setter) to ParquetIO.Parse<T> and ParquetIO.ParseFiles<T> so callers can supply an explicit output coder.
    
    * Added a unit test case (testReadFilesAsRowForUnknownSchemaFiles).
    
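    * BEAM-11527: fixed an oversight by forwarding the Hadoop configuration to readFiles() (adds ReadFiles.withConfiguration(SerializableConfiguration)).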
---
 .../org/apache/beam/sdk/io/parquet/ParquetIO.java  | 38 +++++++++++++++++++---
 .../apache/beam/sdk/io/parquet/ParquetIOTest.java  | 31 ++++++++++++++++++
 2 files changed, 65 insertions(+), 4 deletions(-)

diff --git 
a/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
 
b/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
index 6adbc59..5c2a19d 100644
--- 
a/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
+++ 
b/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java
@@ -405,12 +405,14 @@ public class ParquetIO {
                 .withSplit()
                 .withBeamSchemas(getInferBeamSchema())
                 .withAvroDataModel(getAvroDataModel())
-                .withProjection(getProjectionSchema(), getEncoderSchema()));
+                .withProjection(getProjectionSchema(), getEncoderSchema())
+                .withConfiguration(getConfiguration()));
       }
       return inputFiles.apply(
           readFiles(getSchema())
               .withBeamSchemas(getInferBeamSchema())
-              .withAvroDataModel(getAvroDataModel()));
+              .withAvroDataModel(getAvroDataModel())
+              .withConfiguration(getConfiguration()));
     }
 
     @Override
@@ -428,6 +430,8 @@ public class ParquetIO {
 
     abstract SerializableFunction<GenericRecord, T> getParseFn();
 
+    abstract @Nullable Coder<T> getCoder();
+
     abstract @Nullable SerializableConfiguration getConfiguration();
 
     abstract boolean isSplittable();
@@ -440,6 +444,8 @@ public class ParquetIO {
 
       abstract Builder<T> setParseFn(SerializableFunction<GenericRecord, T> 
parseFn);
 
+      abstract Builder<T> setCoder(Coder<T> coder);
+
       abstract Builder<T> setConfiguration(SerializableConfiguration 
configuration);
 
       abstract Builder<T> setSplittable(boolean splittable);
@@ -455,6 +461,11 @@ public class ParquetIO {
       return from(ValueProvider.StaticValueProvider.of(inputFiles));
     }
 
+    /** Specify the output coder to use for output of the {@code ParseFn}. */
+    public Parse<T> withCoder(Coder<T> coder) {
+      return (coder == null) ? this : toBuilder().setCoder(coder).build();
+    }
+
     /** Specify Hadoop configuration for ParquetReader. */
     public Parse<T> withConfiguration(Map<String, String> configuration) {
       return 
toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
@@ -474,6 +485,7 @@ public class ParquetIO {
           .apply(
               parseFilesGenericRecords(getParseFn())
                   .toBuilder()
+                  .setCoder(getCoder())
                   .setSplittable(isSplittable())
                   .build());
     }
@@ -486,6 +498,8 @@ public class ParquetIO {
 
     abstract SerializableFunction<GenericRecord, T> getParseFn();
 
+    abstract @Nullable Coder<T> getCoder();
+
     abstract @Nullable SerializableConfiguration getConfiguration();
 
     abstract boolean isSplittable();
@@ -496,6 +510,8 @@ public class ParquetIO {
     abstract static class Builder<T> {
       abstract Builder<T> setParseFn(SerializableFunction<GenericRecord, T> 
parseFn);
 
+      abstract Builder<T> setCoder(Coder<T> coder);
+
       abstract Builder<T> setConfiguration(SerializableConfiguration 
configuration);
 
       abstract Builder<T> setSplittable(boolean split);
@@ -503,6 +519,11 @@ public class ParquetIO {
       abstract ParseFiles<T> build();
     }
 
+    /** Specify the output coder to use for output of the {@code ParseFn}. */
+    public ParseFiles<T> withCoder(Coder<T> coder) {
+      return (coder == null) ? this : toBuilder().setCoder(coder).build();
+    }
+
     /** Specify Hadoop configuration for ParquetReader. */
     public ParseFiles<T> withConfiguration(Map<String, String> configuration) {
       return 
toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
@@ -537,7 +558,7 @@ public class ParquetIO {
     /**
      * Identifies the {@code Coder} to be used for the output PCollection.
      *
-     * <p>Returns {@link AvroCoder} if expected output is {@link 
GenericRecord}.
+     * <p>throws an exception if expected output is of type {@link 
GenericRecord}.
      *
      * @param coderRegistry the {@link org.apache.beam.sdk.Pipeline}'s 
CoderRegistry to identify
      *     Coder for expected output type of {@link #getParseFn()}
@@ -547,12 +568,17 @@ public class ParquetIO {
         throw new IllegalArgumentException("Parse can't be used for reading as 
GenericRecord.");
       }
 
+      // Use explicitly provided coder
+      if (getCoder() != null) {
+        return getCoder();
+      }
+
       // If not GenericRecord infer it from ParseFn.
       try {
         return coderRegistry.getCoder(TypeDescriptors.outputOf(getParseFn()));
       } catch (CannotProvideCoderException e) {
         throw new IllegalArgumentException(
-            "Unable to infer coder for output of parseFn. Specify it 
explicitly using withCoder().",
+            "Unable to infer coder for output of parseFn. Specify it 
explicitly using .withCoder().",
             e);
       }
     }
@@ -618,6 +644,10 @@ public class ParquetIO {
       return 
toBuilder().setConfiguration(SerializableConfiguration.fromMap(configuration)).build();
     }
 
+    public ReadFiles withConfiguration(SerializableConfiguration 
configuration) {
+      return toBuilder().setConfiguration(configuration).build();
+    }
+
     @Experimental(Kind.SCHEMAS)
     public ReadFiles withBeamSchemas(boolean inferBeamSchema) {
       return toBuilder().setInferBeamSchema(inferBeamSchema).build();
diff --git 
a/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
 
b/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
index f3406df..301d102 100644
--- 
a/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
+++ 
b/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java
@@ -43,6 +43,8 @@ import org.apache.beam.sdk.coders.AvroCoder;
 import org.apache.beam.sdk.io.FileIO;
 import org.apache.beam.sdk.io.parquet.ParquetIO.GenericRecordPassthroughFn;
 import org.apache.beam.sdk.io.range.OffsetRange;
+import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.utils.AvroUtils;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
@@ -50,6 +52,7 @@ import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
 import org.apache.parquet.hadoop.metadata.BlockMetaData;
 import org.junit.Rule;
 import org.junit.Test;
@@ -297,6 +300,34 @@ public class ParquetIOTest implements Serializable {
   }
 
   @Test
+  public void testReadFilesAsRowForUnknownSchemaFiles() {
+    List<GenericRecord> records = generateGenericRecords(1000);
+    List<Row> expectedRows =
+        records.stream().map(record -> AvroUtils.toBeamRowStrict(record, 
null)).collect(toList());
+
+    PCollection<Row> writeThenRead =
+        mainPipeline
+            .apply(Create.of(records).withCoder(AvroCoder.of(SCHEMA)))
+            .apply(
+                FileIO.<GenericRecord>write()
+                    .via(ParquetIO.sink(SCHEMA))
+                    .to(temporaryFolder.getRoot().getAbsolutePath()))
+            .getPerDestinationOutputFilenames()
+            .apply(Values.create())
+            .apply(FileIO.matchAll())
+            .apply(FileIO.readMatches())
+            .apply(
+                ParquetIO.parseFilesGenericRecords(
+                        (SerializableFunction<GenericRecord, Row>)
+                            record -> AvroUtils.toBeamRowStrict(record, null))
+                    
.withCoder(SchemaCoder.of(AvroUtils.toBeamSchema(SCHEMA))));
+
+    PAssert.that(writeThenRead).containsInAnyOrder(expectedRows);
+
+    mainPipeline.run().waitUntilFinish();
+  }
+
+  @Test
   @SuppressWarnings({"nullable", "ConstantConditions"} /* forced check. */)
   public void testReadFilesUnknownSchemaFilesForGenericRecordThrowException() {
     IllegalArgumentException illegalArgumentException =

Reply via email to