Repository: beam Updated Branches: refs/heads/master bea4f5aec -> 92d1a6635
add TensorFlow TFRecordIO Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/68d42f9b Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/68d42f9b Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/68d42f9b Branch: refs/heads/master Commit: 68d42f9b941dab420947a8aaa616070b1f3fb5f8 Parents: bea4f5a Author: Neville Li <[email protected]> Authored: Tue Feb 21 15:51:18 2017 -0500 Committer: Chamikara Jayalath <[email protected]> Committed: Tue Mar 21 11:23:02 2017 -0700 ---------------------------------------------------------------------- .../apache/beam/sdk/io/CompressedSource.java | 13 +- .../java/org/apache/beam/sdk/io/TFRecordIO.java | 905 +++++++++++++++++++ .../java/org/apache/beam/sdk/io/TextIO.java | 2 +- .../org/apache/beam/sdk/io/TFRecordIOTest.java | 368 ++++++++ .../java/org/apache/beam/sdk/io/TextIOTest.java | 2 +- 5 files changed, 1285 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java index 6de22f9..ecd0fd9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java @@ -334,8 +334,14 @@ public class CompressedSource<T> extends FileBasedSource<T> { super(filePatternOrSpec, minBundleSize, startOffset, endOffset); this.sourceDelegate = sourceDelegate; this.channelFactory = channelFactory; + boolean splittable = false; + try { + splittable = isSplittable(); + } catch (Exception e) { + throw new RuntimeException("Failed to determine if the source is splittable", e); + } checkArgument( - isSplittable() || startOffset == 0, + splittable || startOffset == 0, "CompressedSources must start reading at offset 0. Requested offset: " + startOffset); } @@ -366,11 +372,12 @@ public class CompressedSource<T> extends FileBasedSource<T> { * from the requested file name that the file is not compressed. */ @Override - protected final boolean isSplittable() { + protected final boolean isSplittable() throws Exception { if (channelFactory instanceof FileNameBasedDecompressingChannelFactory) { FileNameBasedDecompressingChannelFactory fileNameBasedChannelFactory = (FileNameBasedDecompressingChannelFactory) channelFactory; - return !fileNameBasedChannelFactory.isCompressed(getFileOrPatternSpec()); + return !fileNameBasedChannelFactory.isCompressed(getFileOrPatternSpec()) + && sourceDelegate.isSplittable(); } return false; } http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java new file mode 100644 index 0000000..243506c --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java @@ -0,0 +1,905 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.NoSuchElementException; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.Read.Bounded; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.util.IOChannelUtils; +import org.apache.beam.sdk.util.MimeTypes; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; + +/** + * {@link PTransform}s for reading and writing TensorFlow TFRecord files. + */ +public class TFRecordIO { + /** The default coder, which returns each record of the input file as a byte array. */ + public static final Coder<byte[]> DEFAULT_BYTE_ARRAY_CODER = ByteArrayCoder.of(); + + /** + * A {@link PTransform} that reads from a TFRecord file (or multiple TFRecord + * files matching a pattern) and returns a {@link PCollection} containing + * the decoding of each of the records of the TFRecord file(s) as a byte array. + */ + public static class Read { + + /** + * Returns a transform for reading TFRecord files that reads from the file(s) + * with the given filename or filename pattern. This can be a local path (if running locally), + * or a Google Cloud Storage filename or filename pattern of the form + * {@code "gs://<bucket>/<filepath>"} (if running locally or via the Google Cloud Dataflow + * service). Standard <a href="http://docs.oracle.com/javase/tutorial/essential/io/find.html" + * >Java Filesystem glob patterns</a> ("*", "?", "[..]") are supported. + */ + public static Bound from(String filepattern) { + return new Bound().from(filepattern); + } + + /** + * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}. 
+ */ + public static Bound from(ValueProvider<String> filepattern) { + return new Bound().from(filepattern); + } + /** + * Returns a transform for reading TFRecord files that has GCS path validation on + * pipeline creation disabled. + * + * <p>This can be useful in the case where the GCS input does not + * exist at the pipeline creation time, but is expected to be + * available at execution time. + */ + public static Bound withoutValidation() { + return new Bound().withoutValidation(); + } + + /** + * Returns a transform for reading TFRecord files that decompresses all input files + * using the specified compression type. + * + * <p>If no compression type is specified, the default is + * {@link TFRecordIO.CompressionType#AUTO}. + * In this mode, the compression type of the file is determined by its extension + * (e.g., {@code *.gz} is gzipped, {@code *.zlib} is zlib compressed, and all other + * extensions are uncompressed). + */ + public static Bound withCompressionType(TFRecordIO.CompressionType compressionType) { + return new Bound().withCompressionType(compressionType); + } + + /** + * A {@link PTransform} that reads from one or more TFRecord files and returns a bounded + * {@link PCollection} containing one element for each record of the input files. + */ + public static class Bound extends PTransform<PBegin, PCollection<byte[]>> { + /** The filepattern to read from. */ + @Nullable private final ValueProvider<String> filepattern; + + /** An option to indicate if input validation is desired. Default is true. */ + private final boolean validate; + + /** Option to indicate the input source's compression type. Default is AUTO. */ + private final TFRecordIO.CompressionType compressionType; + + private Bound() { + this(null, null, true, TFRecordIO.CompressionType.AUTO); + } + + private Bound( + @Nullable String name, + @Nullable ValueProvider<String> filepattern, + boolean validate, + TFRecordIO.CompressionType compressionType) { + super(name); + this.filepattern = filepattern; + this.validate = validate; + this.compressionType = compressionType; + } + + /** + * Returns a new transform for reading from TFRecord files that's like this one but that + * reads from the file(s) with the given name or pattern. See {@link TFRecordIO.Read#from} + * for a description of filepatterns. + * + * <p>Does not modify this object. + + */ + public Bound from(String filepattern) { + checkNotNull(filepattern, "Filepattern cannot be empty."); + return new Bound(name, StaticValueProvider.of(filepattern), validate, compressionType); + } + + /** + * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}. + */ + public Bound from(ValueProvider<String> filepattern) { + checkNotNull(filepattern, "Filepattern cannot be empty."); + return new Bound(name, filepattern, validate, compressionType); + } + + /** + * Returns a new transform for reading from TFRecord files that's like this one but + * that has GCS path validation on pipeline creation disabled. + * + * <p>This can be useful in the case where the GCS input does not + * exist at the pipeline creation time, but is expected to be + * available at execution time. + * + * <p>Does not modify this object. + */ + public Bound withoutValidation() { + return new Bound(name, filepattern, false, compressionType); + } + + /** + * Returns a new transform for reading from TFRecord files that's like this one but + * reads from input sources using the specified compression type. 
+ * + * <p>If no compression type is specified, the default is + * {@link TFRecordIO.CompressionType#AUTO}. + * See {@link TFRecordIO.Read#withCompressionType} for more details. + * + * <p>Does not modify this object. + */ + public Bound withCompressionType(TFRecordIO.CompressionType compressionType) { + return new Bound(name, filepattern, validate, compressionType); + } + + @Override + public PCollection<byte[]> expand(PBegin input) { + if (filepattern == null) { + throw new IllegalStateException( + "Need to set the filepattern of a TFRecordIO.Read transform"); + } + + if (validate) { + checkState(filepattern.isAccessible(), "Cannot validate with a RVP."); + try { + checkState( + !IOChannelUtils.getFactory(filepattern.get()).match(filepattern.get()).isEmpty(), + "Unable to find any files matching %s", + filepattern); + } catch (IOException e) { + throw new IllegalStateException( + String.format("Failed to validate %s", filepattern.get()), e); + } + } + + final Bounded<byte[]> read = org.apache.beam.sdk.io.Read.from(getSource()); + PCollection<byte[]> pcol = input.getPipeline().apply("Read", read); + // Honor the default output coder that would have been used by this PTransform. + pcol.setCoder(getDefaultOutputCoder()); + return pcol; + } + + // Helper to create a source specific to the requested compression type. + protected FileBasedSource<byte[]> getSource() { + switch (compressionType) { + case NONE: + return new TFRecordSource(filepattern); + case AUTO: + return CompressedSource.from(new TFRecordSource(filepattern)); + case GZIP: + return + CompressedSource.from(new TFRecordSource(filepattern)) + .withDecompression(CompressedSource.CompressionMode.GZIP); + case ZLIB: + return + CompressedSource.from(new TFRecordSource(filepattern)) + .withDecompression(CompressedSource.CompressionMode.DEFLATE); + default: + throw new IllegalArgumentException("Unknown compression type: " + compressionType); + } + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + + String filepatternDisplay = filepattern.isAccessible() + ? filepattern.get() : filepattern.toString(); + builder + .add(DisplayData.item("compressionType", compressionType.toString()) + .withLabel("Compression Type")) + .addIfNotDefault(DisplayData.item("validation", validate) + .withLabel("Validation Enabled"), true) + .addIfNotNull(DisplayData.item("filePattern", filepatternDisplay) + .withLabel("File Pattern")); + } + + @Override + protected Coder<byte[]> getDefaultOutputCoder() { + return DEFAULT_BYTE_ARRAY_CODER; + } + + public String getFilepattern() { + return filepattern.get(); + } + + public boolean needsValidation() { + return validate; + } + + public TFRecordIO.CompressionType getCompressionType() { + return compressionType; + } + } + + /** Disallow construction of utility classes. */ + private Read() {} + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A {@link PTransform} that writes a {@link PCollection} to TFRecord file (or + * multiple TFRecord files matching a sharding pattern), with each + * element of the input collection encoded into its own record. + */ + public static class Write { + + /** + * Returns a transform for writing to TFRecord files that writes to the file(s) + * with the given prefix. This can be a local filename + * (if running locally), or a Google Cloud Storage filename of + * the form {@code "gs://<bucket>/<filepath>"} + * (if running locally or via the Google Cloud Dataflow service). 
+ * + * <p>The files written will begin with this prefix, followed by + * a shard identifier (see {@link TFRecordIO.Write.Bound#withNumShards(int)}, and end + * in a common extension, if given by {@link TFRecordIO.Write.Bound#withSuffix(String)}. + */ + public static Bound to(String prefix) { + return new Bound().to(prefix); + } + + /** + * Like {@link #to(String)}, but with a {@link ValueProvider}. + */ + public static Bound to(ValueProvider<String> prefix) { + return new Bound().to(prefix); + } + + /** + * Returns a transform for writing to TFRecord files that appends the specified suffix + * to the created files. + */ + public static Bound withSuffix(String nameExtension) { + return new Bound().withSuffix(nameExtension); + } + + /** + * Returns a transform for writing to TFRecord files that uses the provided shard count. + * + * <p>Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + * + * @param numShards the number of shards to use, or 0 to let the system + * decide. + */ + public static Bound withNumShards(int numShards) { + return new Bound().withNumShards(numShards); + } + + /** + * Returns a transform for writing to TFRecord files that uses the given shard name + * template. + * + * <p>See {@link ShardNameTemplate} for a description of shard templates. + */ + public static Bound withShardNameTemplate(String shardTemplate) { + return new Bound().withShardNameTemplate(shardTemplate); + } + + /** + * Returns a transform for writing to TFRecord files that forces a single file as + * output. + */ + public static Bound withoutSharding() { + return new Bound().withoutSharding(); + } + + /** + * Returns a transform for writing to text files that has GCS path validation on + * pipeline creation disabled. + * + * <p>This can be useful in the case where the GCS output location does + * not exist at the pipeline creation time, but is expected to be available + * at execution time. + */ + public static Bound withoutValidation() { + return new Bound().withoutValidation(); + } + + /** + * Returns a transform for writing to TFRecord files like this one but writes to output files + * using the specified compression type. + * + * <p>If no compression type is specified, the default is + * {@link TFRecordIO.CompressionType#NONE}. + * See {@link TFRecordIO.Read#withCompressionType} for more details. + */ + public static Bound withCompressionType(CompressionType compressionType) { + return new Bound().withCompressionType(compressionType); + } + + /** + * A PTransform that writes a bounded PCollection to a TFRecord file (or + * multiple TFRecord files matching a sharding pattern), with each + * PCollection element being encoded into its own record. + */ + public static class Bound extends PTransform<PCollection<byte[]>, PDone> { + private static final String DEFAULT_SHARD_TEMPLATE = ShardNameTemplate.INDEX_OF_MAX; + + /** The prefix of each file written, combined with suffix and shardTemplate. */ + private final ValueProvider<String> filenamePrefix; + /** The suffix of each file written, combined with prefix and shardTemplate. */ + private final String filenameSuffix; + + /** Requested number of shards. 0 for automatic. */ + private final int numShards; + + /** The shard template of each file written, combined with prefix and suffix. */ + private final String shardTemplate; + + /** An option to indicate if output validation is desired. Default is true. 
*/ + private final boolean validate; + + /** Option to indicate the output sink's compression type. Default is NONE. */ + private final TFRecordIO.CompressionType compressionType; + + private Bound() { + this(null, null, "", 0, DEFAULT_SHARD_TEMPLATE, true, TFRecordIO.CompressionType.NONE); + } + + private Bound(String name, ValueProvider<String> filenamePrefix, String filenameSuffix, + int numShards, String shardTemplate, boolean validate, + CompressionType compressionType) { + super(name); + this.filenamePrefix = filenamePrefix; + this.filenameSuffix = filenameSuffix; + this.numShards = numShards; + this.shardTemplate = shardTemplate; + this.validate = validate; + this.compressionType = compressionType; + } + + /** + * Returns a transform for writing to TFRecord files that's like this one but + * that writes to the file(s) with the given filename prefix. + * + * <p>See {@link TFRecordIO.Write#to(String) Write.to(String)} for more information. + * + * <p>Does not modify this object. + */ + public Bound to(String filenamePrefix) { + validateOutputComponent(filenamePrefix); + return new Bound(name, StaticValueProvider.of(filenamePrefix), filenameSuffix, numShards, + shardTemplate, validate, compressionType); + } + + /** + * Like {@link #to(String)}, but with a {@link ValueProvider}. + */ + public Bound to(ValueProvider<String> filenamePrefix) { + return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate, + compressionType); + } + + /** + * Returns a transform for writing to TFRecord files that that's like this one but + * that writes to the file(s) with the given filename suffix. + * + * <p>Does not modify this object. + * + * @see ShardNameTemplate + */ + public Bound withSuffix(String nameExtension) { + validateOutputComponent(nameExtension); + return new Bound(name, filenamePrefix, nameExtension, numShards, shardTemplate, validate, + compressionType); + } + + /** + * Returns a transform for writing to TFRecord files that's like this one but + * that uses the provided shard count. + * + * <p>Constraining the number of shards is likely to reduce + * the performance of a pipeline. Setting this value is not recommended + * unless you require a specific number of output files. + * + * <p>Does not modify this object. + * + * @param numShards the number of shards to use, or 0 to let the system + * decide. + * @see ShardNameTemplate + */ + public Bound withNumShards(int numShards) { + checkArgument(numShards >= 0); + return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate, + compressionType); + } + + /** + * Returns a transform for writing to TFRecord files that's like this one but + * that uses the given shard name template. + * + * <p>Does not modify this object. + * + * @see ShardNameTemplate + */ + public Bound withShardNameTemplate(String shardTemplate) { + return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate, + compressionType); + } + + /** + * Returns a transform for writing to TFRecord files that's like this one but + * that forces a single file as output. + * + * <p>Constraining the number of shards is likely to reduce + * the performance of a pipeline. Using this setting is not recommended + * unless you truly require a single output file. + * + * <p>This is a shortcut for + * {@code .withNumShards(1).withShardNameTemplate("")} + * + * <p>Does not modify this object. 
+ */ + public Bound withoutSharding() { + return new Bound(name, filenamePrefix, filenameSuffix, 1, "", + validate, compressionType); + } + + /** + * Returns a transform for writing to TFRecord files that's like this one but + * that has GCS output path validation on pipeline creation disabled. + * + * <p>This can be useful in the case where the GCS output location does + * not exist at the pipeline creation time, but is expected to be + * available at execution time. + * + * <p>Does not modify this object. + */ + public Bound withoutValidation() { + return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, false, + compressionType); + } + + /** + * Returns a transform for writing to TFRecord files like this one but writes to output files + * using the specified compression type. + * + * <p>If no compression type is specified, the default is + * {@link TFRecordIO.CompressionType#NONE}. + * See {@link TFRecordIO.Read#withCompressionType} for more details. + * + * <p>Does not modify this object. + */ + public Bound withCompressionType(CompressionType compressionType) { + return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate, + compressionType); + } + + @Override + public PDone expand(PCollection<byte[]> input) { + if (filenamePrefix == null) { + throw new IllegalStateException( + "need to set the filename prefix of a TFRecordIO.Write transform"); + } + org.apache.beam.sdk.io.Write<byte[]> write = + org.apache.beam.sdk.io.Write.to( + new TFRecordSink(filenamePrefix, filenameSuffix, shardTemplate, compressionType)); + if (getNumShards() > 0) { + write = write.withNumShards(getNumShards()); + } + return input.apply("Write", write); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + + String prefixString = filenamePrefix.isAccessible() + ? filenamePrefix.get() : filenamePrefix.toString(); + builder + .addIfNotNull(DisplayData.item("filePrefix", prefixString) + .withLabel("Output File Prefix")) + .addIfNotDefault(DisplayData.item("fileSuffix", filenameSuffix) + .withLabel("Output File Suffix"), "") + .addIfNotDefault(DisplayData.item("shardNameTemplate", shardTemplate) + .withLabel("Output Shard Name Template"), + DEFAULT_SHARD_TEMPLATE) + .addIfNotDefault(DisplayData.item("validation", validate) + .withLabel("Validation Enabled"), true) + .addIfNotDefault(DisplayData.item("numShards", numShards) + .withLabel("Maximum Output Shards"), 0) + .add(DisplayData + .item("compressionType", compressionType.toString()) + .withLabel("Compression Type")); + } + + /** + * Returns the current shard name template string. + */ + public String getShardNameTemplate() { + return shardTemplate; + } + + @Override + protected Coder<Void> getDefaultOutputCoder() { + return VoidCoder.of(); + } + + public String getFilenamePrefix() { + return filenamePrefix.get(); + } + + public String getShardTemplate() { + return shardTemplate; + } + + public int getNumShards() { + return numShards; + } + + public String getFilenameSuffix() { + return filenameSuffix; + } + + public boolean needsValidation() { + return validate; + } + } + } + + /** + * Possible TFRecord file compression types. + */ + public enum CompressionType { + /** + * Automatically determine the compression type based on filename extension. + */ + AUTO(""), + /** + * Uncompressed. + */ + NONE(""), + /** + * GZipped. + */ + GZIP(".gz"), + /** + * ZLIB compressed. 
+ */ + ZLIB(".zlib"); + + private String filenameSuffix; + + CompressionType(String suffix) { + this.filenameSuffix = suffix; + } + + /** + * Determine if a given filename matches a compression type based on its extension. + * @param filename the filename to match + * @return true iff the filename ends with the compression type's known extension. + */ + public boolean matches(String filename) { + return filename.toLowerCase().endsWith(filenameSuffix.toLowerCase()); + } + } + + // Pattern which matches old-style shard output patterns, which are now + // disallowed. + private static final Pattern SHARD_OUTPUT_PATTERN = Pattern.compile("@([0-9]+|\\*)"); + + private static void validateOutputComponent(String partialFilePattern) { + checkArgument( + !SHARD_OUTPUT_PATTERN.matcher(partialFilePattern).find(), + "Output name components are not allowed to contain @* or @N patterns: " + + partialFilePattern); + } + + ////////////////////////////////////////////////////////////////////////////// + + /** Disable construction of utility class. */ + private TFRecordIO() {} + + /** + * A {@link FileBasedSource} which can decode records in TFRecord files. + */ + @VisibleForTesting + static class TFRecordSource extends FileBasedSource<byte[]> { + @VisibleForTesting + TFRecordSource(String fileSpec) { + super(fileSpec, 1L); + } + + @VisibleForTesting + TFRecordSource(ValueProvider<String> fileSpec) { + super(fileSpec, Long.MAX_VALUE); + } + + private TFRecordSource(String fileName, long start, long end) { + super(fileName, Long.MAX_VALUE, start, end); + } + + @Override + protected FileBasedSource<byte[]> createForSubrangeOfFile( + String fileName, + long start, + long end) { + checkArgument(start == 0, "TFRecordSource is not splittable"); + return new TFRecordSource(fileName, start, end); + } + + @Override + protected FileBasedReader<byte[]> createSingleFileReader(PipelineOptions options) { + return new TFRecordReader(this); + } + + @Override + public Coder<byte[]> getDefaultOutputCoder() { + return DEFAULT_BYTE_ARRAY_CODER; + } + + @Override + protected boolean isSplittable() throws Exception { + // TFRecord files are not splittable + return false; + } + + /** + * A {@link org.apache.beam.sdk.io.FileBasedSource.FileBasedReader FileBasedReader} + * which can decode records in TFRecord files. + * + * <p>See {@link TFRecordIO.TFRecordSource} for further details. 
+ */ + @VisibleForTesting + static class TFRecordReader extends FileBasedReader<byte[]> { + private long startOfRecord; + private volatile long startOfNextRecord; + private volatile boolean elementIsPresent; + private byte[] currentValue; + private ReadableByteChannel inChannel; + private TFRecordCodec codec; + + private TFRecordReader(TFRecordSource source) { + super(source); + } + + @Override + protected long getCurrentOffset() throws NoSuchElementException { + if (!elementIsPresent) { + throw new NoSuchElementException(); + } + return startOfRecord; + } + + @Override + public byte[] getCurrent() throws NoSuchElementException { + if (!elementIsPresent) { + throw new NoSuchElementException(); + } + return currentValue; + } + + @Override + protected void startReading(ReadableByteChannel channel) throws IOException { + this.inChannel = channel; + this.codec = new TFRecordCodec(); + } + + @Override + protected boolean readNextRecord() throws IOException { + startOfRecord = startOfNextRecord; + currentValue = codec.read(inChannel); + if (currentValue != null) { + elementIsPresent = true; + startOfNextRecord = startOfRecord + codec.recordLength(currentValue); + return true; + } else { + elementIsPresent = false; + return false; + } + } + } + } + + /** + * A {@link FileBasedSink} for TFRecord files. Produces TFRecord files. + */ + @VisibleForTesting + static class TFRecordSink extends FileBasedSink<byte[]> { + @VisibleForTesting + TFRecordSink(ValueProvider<String> baseOutputFilename, + String extension, + String fileNameTemplate, + TFRecordIO.CompressionType compressionType) { + super(baseOutputFilename, extension, fileNameTemplate, + writableByteChannelFactory(compressionType)); + } + + @Override + public FileBasedWriteOperation<byte[]> createWriteOperation(PipelineOptions options) { + return new TFRecordWriteOperation(this); + } + + private static WritableByteChannelFactory writableByteChannelFactory( + TFRecordIO.CompressionType compressionType) { + switch (compressionType) { + case AUTO: + throw new IllegalArgumentException("Unsupported compression type AUTO"); + case NONE: + return CompressionType.UNCOMPRESSED; + case GZIP: + return CompressionType.GZIP; + case ZLIB: + return CompressionType.DEFLATE; + } + return CompressionType.UNCOMPRESSED; + } + + /** + * A {@link org.apache.beam.sdk.io.FileBasedSink.FileBasedWriteOperation + * FileBasedWriteOperation} for TFRecord files. + */ + private static class TFRecordWriteOperation extends FileBasedWriteOperation<byte[]> { + private TFRecordWriteOperation(TFRecordSink sink) { + super(sink); + } + + @Override + public FileBasedWriter<byte[]> createWriter(PipelineOptions options) throws Exception { + return new TFRecordWriter(this); + } + } + + /** + * A {@link org.apache.beam.sdk.io.FileBasedSink.FileBasedWriter FileBasedWriter} + * for TFRecord files. + */ + private static class TFRecordWriter extends FileBasedWriter<byte[]> { + private WritableByteChannel outChannel; + private TFRecordCodec codec; + + private TFRecordWriter(FileBasedWriteOperation<byte[]> writeOperation) { + super(writeOperation); + this.mimeType = MimeTypes.BINARY; + } + + @Override + protected void prepareWrite(WritableByteChannel channel) throws Exception { + this.outChannel = channel; + this.codec = new TFRecordCodec(); + } + + @Override + public void write(byte[] value) throws Exception { + codec.write(outChannel, value); + } + } + } + + ////////////////////////////////////////////////////////////////////////////// + + /** + * Codec for TFRecords file format. 
+ * See https://www.tensorflow.org/api_guides/python/python_io#TFRecords_Format_Details + */ + private static class TFRecordCodec { + private static final int HEADER_LEN = (Long.SIZE + Integer.SIZE) / Byte.SIZE; + private static final int FOOTER_LEN = Integer.SIZE / Byte.SIZE; + private static HashFunction crc32c = Hashing.crc32c(); + + private ByteBuffer header = ByteBuffer.allocate(HEADER_LEN).order(ByteOrder.LITTLE_ENDIAN); + private ByteBuffer footer = ByteBuffer.allocate(FOOTER_LEN).order(ByteOrder.LITTLE_ENDIAN); + + private int mask(int crc) { + return ((crc >>> 15) | (crc << 17)) + 0xa282ead8; + } + + private int hashLong(long x) { + return mask(crc32c.hashLong(x).asInt()); + } + + private int hashBytes(byte[] x) { + return mask(crc32c.hashBytes(x).asInt()); + } + + public int recordLength(byte[] data) { + return HEADER_LEN + data.length + FOOTER_LEN; + } + + public byte[] read(ReadableByteChannel inChannel) throws IOException { + header.clear(); + int headerBytes = inChannel.read(header); + if (headerBytes <= 0) { + return null; + } + checkState( + headerBytes == HEADER_LEN, + "Not a valid TFRecord. Fewer than 12 bytes."); + header.rewind(); + long length = header.getLong(); + int maskedCrc32OfLength = header.getInt(); + checkState( + hashLong(length) == maskedCrc32OfLength, + "Mismatch of length mask"); + + ByteBuffer data = ByteBuffer.allocate((int) length); + checkState(inChannel.read(data) == length, "Invalid data"); + + footer.clear(); + inChannel.read(footer); + footer.rewind(); + int maskedCrc32OfData = footer.getInt(); + + checkState( + hashBytes(data.array()) == maskedCrc32OfData, + "Mismatch of data mask"); + return data.array(); + } + + public void write(WritableByteChannel outChannel, byte[] data) throws IOException { + int maskedCrc32OfLength = hashLong(data.length); + int maskedCrc32OfData = hashBytes(data); + + header.clear(); + header.putLong(data.length).putInt(maskedCrc32OfLength); + header.rewind(); + outChannel.write(header); + + outChannel.write(ByteBuffer.wrap(data)); + + footer.clear(); + footer.putInt(maskedCrc32OfData); + footer.rewind(); + outChannel.write(footer); + } + } + +} http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java index fe8d0fd..f8943a5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java @@ -682,7 +682,7 @@ public class TextIO { .addIfNotNull(DisplayData.item("filePrefix", prefixString) .withLabel("Output File Prefix")) .addIfNotDefault(DisplayData.item("fileSuffix", filenameSuffix) - .withLabel("Output Fix Suffix"), "") + .withLabel("Output File Suffix"), "") .addIfNotDefault(DisplayData.item("shardNameTemplate", shardTemplate) .withLabel("Output Shard Name Template"), DEFAULT_SHARD_TEMPLATE) http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java new file mode 100644 index 0000000..70620fb --- /dev/null +++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io; + +import static org.apache.beam.sdk.io.TFRecordIO.CompressionType; +import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.AUTO; +import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.GZIP; +import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.NONE; +import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.ZLIB; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static org.hamcrest.Matchers.isIn; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.common.collect.Lists; +import com.google.common.io.BaseEncoding; +import com.google.common.io.ByteStreams; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.attribute.BasicFileAttributes; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.PCollection; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for TFRecordIO Read and Write transforms. + */ +@RunWith(JUnit4.class) +public class TFRecordIOTest { + + /* + From https://github.com/apache/beam/blob/master/sdks/python/apache_beam/io/tfrecordio_test.py + Created by running following code in python: + >>> import tensorflow as tf + >>> import base64 + >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord') + >>> writer.write('foo') + >>> writer.close() + >>> with open('/tmp/python_foo.tfrecord', 'rb') as f: + ... data = base64.b64encode(f.read()) + ... 
print data + */ + private static final String FOO_RECORD_BASE64 = "AwAAAAAAAACwmUkOZm9vYYq+/g=="; + + // Same as above but containing two records ['foo', 'bar'] + private static final String FOO_BAR_RECORD_BASE64 = + "AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg="; + private static final String BAR_FOO_RECORD_BASE64 = + "AwAAAAAAAACwmUkOYmFyRgDlyAMAAAAAAAAAsJlJDmZvb2GKvv4="; + + private static final String[] FOO_RECORDS = {"foo"}; + private static final String[] FOO_BAR_RECORDS = {"foo", "bar"}; + + private static final Iterable<String> EMPTY = Collections.emptyList(); + private static final Iterable<String> LARGE = makeLines(5000); + + private static Path tempFolder; + + @Rule + public TestPipeline p = TestPipeline.create(); + + @Rule + public TestPipeline p2 = TestPipeline.create(); + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @BeforeClass + public static void setupClass() throws IOException { + tempFolder = Files.createTempDirectory("TFRecordIOTest"); + } + + @AfterClass + public static void teardownClass() throws IOException { + Files.walkFileTree(tempFolder, new SimpleFileVisitor<Path>() { + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { + Files.delete(file); + return FileVisitResult.CONTINUE; + } + + @Override + public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException { + Files.delete(dir); + return FileVisitResult.CONTINUE; + } + }); + } + + @Test + public void testReadNamed() { + p.enableAbandonedNodeEnforcement(false); + + assertEquals( + "TFRecordIO.Read/Read.out", + p.apply(TFRecordIO.Read.withoutValidation().from("foo.*")).getName()); + assertEquals( + "MyRead/Read.out", + p.apply("MyRead", TFRecordIO.Read.withoutValidation().from("foo.*")).getName()); + } + + @Test + public void testReadDisplayData() { + TFRecordIO.Read.Bound read = TFRecordIO.Read + .from("foo.*") + .withCompressionType(GZIP) + .withoutValidation(); + + DisplayData displayData = DisplayData.from(read); + + assertThat(displayData, hasDisplayItem("filePattern", "foo.*")); + assertThat(displayData, hasDisplayItem("compressionType", GZIP.toString())); + assertThat(displayData, hasDisplayItem("validation", false)); + } + + @Test + public void testWriteDisplayData() { + TFRecordIO.Write.Bound write = TFRecordIO.Write + .to("foo") + .withSuffix("bar") + .withShardNameTemplate("-SS-of-NN-") + .withNumShards(100) + .withCompressionType(GZIP) + .withoutValidation(); + + DisplayData displayData = DisplayData.from(write); + + assertThat(displayData, hasDisplayItem("filePrefix", "foo")); + assertThat(displayData, hasDisplayItem("fileSuffix", "bar")); + assertThat(displayData, hasDisplayItem("shardNameTemplate", "-SS-of-NN-")); + assertThat(displayData, hasDisplayItem("numShards", 100)); + assertThat(displayData, hasDisplayItem("compressionType", GZIP.toString())); + assertThat(displayData, hasDisplayItem("validation", false)); + } + + @Test + @Category(NeedsRunner.class) + public void testReadOne() throws Exception { + runTestRead(FOO_RECORD_BASE64, FOO_RECORDS); + } + + @Test + @Category(NeedsRunner.class) + public void testReadTwo() throws Exception { + runTestRead(FOO_BAR_RECORD_BASE64, FOO_BAR_RECORDS); + } + + @Test + @Category(NeedsRunner.class) + public void testWriteOne() throws Exception { + runTestWrite(FOO_RECORDS, FOO_RECORD_BASE64); + } + + @Test + @Category(NeedsRunner.class) + public void testWriteTwo() throws Exception { + runTestWrite(FOO_BAR_RECORDS, 
FOO_BAR_RECORD_BASE64, BAR_FOO_RECORD_BASE64); + } + + @Test + @Category(NeedsRunner.class) + public void testReadInvalidRecord() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("Not a valid TFRecord. Fewer than 12 bytes."); + System.out.println("abr".getBytes().length); + runTestRead("bar".getBytes(), new String[0]); + } + + @Test + @Category(NeedsRunner.class) + public void testReadInvalidLengthMask() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("Mismatch of length mask"); + byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64); + data[9] += 1; + runTestRead(data, FOO_RECORDS); + } + + @Test + @Category(NeedsRunner.class) + public void testReadInvalidDataMask() throws Exception { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("Mismatch of data mask"); + byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64); + data[16] += 1; + runTestRead(data, FOO_RECORDS); + } + + private void runTestRead(String base64, String[] expected) throws IOException { + runTestRead(BaseEncoding.base64().decode(base64), expected); + } + + private void runTestRead(byte[] data, String[] expected) throws IOException { + File tmpFile = Files.createTempFile(tempFolder, "file", ".tfrecords").toFile(); + String filename = tmpFile.getPath(); + + FileOutputStream fos = new FileOutputStream(tmpFile); + fos.write(data); + fos.close(); + + TFRecordIO.Read.Bound read = TFRecordIO.Read.from(filename); + PCollection<String> output = p.apply(read).apply(ParDo.of(new ByteArrayToString())); + + PAssert.that(output).containsInAnyOrder(expected); + p.run(); + } + + private void runTestWrite(String[] elems, String ...base64) throws IOException { + File tmpFile = Files.createTempFile(tempFolder, "file", ".tfrecords").toFile(); + String filename = tmpFile.getPath(); + + PCollection<byte[]> input = p.apply(Create.of(Arrays.asList(elems))) + .apply(ParDo.of(new StringToByteArray())); + + TFRecordIO.Write.Bound write = TFRecordIO.Write.to(filename).withoutSharding(); + input.apply(write); + + p.run(); + + FileInputStream fis = new FileInputStream(tmpFile); + String written = BaseEncoding.base64().encode(ByteStreams.toByteArray(fis)); + // bytes written may vary depending the order of elems + assertThat(written, isIn(base64)); + } + + @Test + @Category(NeedsRunner.class) + public void runTestRoundTrip() throws IOException { + runTestRoundTrip(LARGE, 10, ".tfrecords", NONE, NONE); + } + + @Test + @Category(NeedsRunner.class) + public void runTestRoundTripWithEmptyData() throws IOException { + runTestRoundTrip(EMPTY, 10, ".tfrecords", NONE, NONE); + } + + @Test + @Category(NeedsRunner.class) + public void runTestRoundTripWithOneShards() throws IOException { + runTestRoundTrip(LARGE, 1, ".tfrecords", NONE, NONE); + } + + @Test + @Category(NeedsRunner.class) + public void runTestRoundTripWithSuffix() throws IOException { + runTestRoundTrip(LARGE, 10, ".suffix", NONE, NONE); + } + + @Test + @Category(NeedsRunner.class) + public void runTestRoundTripGzip() throws IOException { + runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, GZIP); + } + + @Test + @Category(NeedsRunner.class) + public void runTestRoundTripZlib() throws IOException { + runTestRoundTrip(LARGE, 10, ".tfrecords", ZLIB, ZLIB); + } + + @Test + @Category(NeedsRunner.class) + public void runTestRoundTripUncompressedFilesWithAuto() throws IOException { + runTestRoundTrip(LARGE, 10, ".tfrecords", NONE, 
AUTO); + } + + @Test + @Category(NeedsRunner.class) + public void runTestRoundTripGzipFilesWithAuto() throws IOException { + runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, AUTO); + } + + @Test + @Category(NeedsRunner.class) + public void runTestRoundTripZlibFilesWithAuto() throws IOException { + runTestRoundTrip(LARGE, 10, ".tfrecords", ZLIB, AUTO); + } + + private void runTestRoundTrip(Iterable<String> elems, + int numShards, + String suffix, + CompressionType writeCompressionType, + CompressionType readCompressionType) throws IOException { + String outputName = "file"; + Path baseDir = Files.createTempDirectory(tempFolder, "test-rt"); + String baseFilename = baseDir.resolve(outputName).toString(); + + TFRecordIO.Write.Bound write = TFRecordIO.Write.to(baseFilename) + .withNumShards(numShards) + .withSuffix(suffix) + .withCompressionType(writeCompressionType); + p.apply(Create.of(elems).withCoder(StringUtf8Coder.of())) + .apply(ParDo.of(new StringToByteArray())) + .apply(write); + p.run(); + + TFRecordIO.Read.Bound read = TFRecordIO.Read.from(baseFilename + "*") + .withCompressionType(readCompressionType); + PCollection<String> output = p2.apply(read).apply(ParDo.of(new ByteArrayToString())); + + PAssert.that(output).containsInAnyOrder(elems); + p2.run(); + } + + private static Iterable<String> makeLines(int n) { + List<String> ret = Lists.newArrayList(); + for (int i = 0; i < n; ++i) { + ret.add("word" + i); + } + return ret; + } + + static class ByteArrayToString extends DoFn<byte[], String> { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(new String(c.element())); + } + } + + static class StringToByteArray extends DoFn<String, byte[]> { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element().getBytes()); + } + } + +} http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java index cd94dc5..713cb71 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java @@ -206,7 +206,7 @@ public class TextIOTest { } @AfterClass - public static void testdownClass() throws IOException { + public static void teardownClass() throws IOException { Files.walkFileTree(tempFolder, new SimpleFileVisitor<Path>() { @Override public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
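
For readers unfamiliar with the format: each TFRecord is framed as a 12-byte little-endian header (a uint64 payload length plus a masked CRC32C of that length), the payload bytes, and a 4-byte footer holding a masked CRC32C of the payload, exactly as implemented by TFRecordCodec in the diff above. Below is a minimal usage sketch of the new transforms, not part of this commit; the pipeline object `p` and the gs:// paths are placeholders.

    // Read every record of the matched TFRecord files as a byte[].
    // With the default AUTO compression type, *.gz inputs are gunzipped and
    // *.zlib inputs are inflated, based on their filename extensions.
    PCollection<byte[]> records = p.apply(
        TFRecordIO.Read.from("gs://my-bucket/input/part-*.tfrecords"));

    // Write the records back out as a single uncompressed TFRecord file
    // (withoutSharding() forces one shard with an empty shard template).
    records.apply(
        TFRecordIO.Write.to("gs://my-bucket/output/records")
            .withSuffix(".tfrecords")
            .withoutSharding());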