Rollback revert "Migrate TextIO.Write to a custom sink". Note: for user-requested sharding limits to be supported, each pipeline runner must support applying those sharding limits.
DirectPipelineRunner and Google Cloud Dataflow supports sharding limits. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=115500204 Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/639e9d95 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/639e9d95 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/639e9d95 Branch: refs/heads/master Commit: 639e9d95b61704ae1740a0a1f02f76c3d480fa48 Parents: 045e343 Author: lcwik <[email protected]> Authored: Wed Feb 24 15:24:22 2016 -0800 Committer: Davor Bonaci <[email protected]> Committed: Thu Feb 25 23:58:28 2016 -0800 ---------------------------------------------------------------------- .../cloud/dataflow/sdk/io/FileBasedSink.java | 17 +- .../google/cloud/dataflow/sdk/io/TextIO.java | 187 +++++++++---------- .../sdk/runners/DataflowPipelineRunner.java | 131 +++++++++++-- .../sdk/runners/DataflowPipelineTranslator.java | 6 - .../sdk/runners/DirectPipelineRunner.java | 89 +++++++++ .../sdk/runners/dataflow/TextIOTranslator.java | 91 --------- .../dataflow/sdk/io/FileBasedSinkTest.java | 29 ++- .../cloud/dataflow/sdk/io/TextIOTest.java | 22 --- .../sdk/runners/DataflowPipelineRunnerTest.java | 21 +-- .../runners/DataflowPipelineTranslatorTest.java | 4 +- .../sdk/runners/DirectPipelineRunnerTest.java | 80 +++++++- .../dataflow/sdk/runners/TransformTreeTest.java | 9 +- 12 files changed, 414 insertions(+), 272 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSink.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSink.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSink.java index 
7c30167..dda500c 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSink.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/FileBasedSink.java @@ -355,7 +355,7 @@ public abstract class FileBasedSink<T> extends Sink<T> { String baseOutputFilename = getSink().baseOutputFilename; String fileNamingTemplate = getSink().fileNamingTemplate; - String suffix = (extension.length() == 0) ? extension : ("." + extension); + String suffix = getFileExtension(extension); for (int i = 0; i < numFiles; i++) { destFilenames.add(IOChannelUtils.constructName( baseOutputFilename, fileNamingTemplate, suffix, i, numFiles)); @@ -364,6 +364,21 @@ public abstract class FileBasedSink<T> extends Sink<T> { } /** + * Returns the file extension to be used. If the user did not request a file + * extension then this method returns the empty string. Otherwise this method + * adds a {@code "."} to the beginning of the users extension if one is not present. + */ + private String getFileExtension(String usersExtension) { + if (usersExtension == null || usersExtension.isEmpty()) { + return ""; + } + if (usersExtension.startsWith(".")) { + return usersExtension; + } + return "." + usersExtension; + } + + /** * Removes temporary output files. Uses the temporary filename to find files to remove. * * <p>Can be called from subclasses that override {@link FileBasedWriteOperation#finalize}. 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java index 0bb2861..d342f25 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/TextIO.java @@ -26,11 +26,9 @@ import com.google.cloud.dataflow.sdk.io.Read.Bounded; import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; -import com.google.cloud.dataflow.sdk.runners.worker.TextSink; import com.google.cloud.dataflow.sdk.transforms.PTransform; import com.google.cloud.dataflow.sdk.util.IOChannelUtils; -import com.google.cloud.dataflow.sdk.util.WindowedValue; -import com.google.cloud.dataflow.sdk.util.common.worker.Sink; +import com.google.cloud.dataflow.sdk.util.MimeTypes; import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.dataflow.sdk.values.PDone; import com.google.cloud.dataflow.sdk.values.PInput; @@ -39,10 +37,13 @@ import com.google.common.base.Preconditions; import com.google.protobuf.ByteString; import java.io.IOException; +import java.io.OutputStream; import java.nio.ByteBuffer; +import java.nio.channels.Channels; import java.nio.channels.ReadableByteChannel; import java.nio.channels.SeekableByteChannel; -import java.util.List; +import java.nio.channels.WritableByteChannel; +import java.nio.charset.StandardCharsets; import java.util.NoSuchElementException; import java.util.regex.Pattern; @@ -66,7 +67,7 @@ import javax.annotation.Nullable; * * <p>See the following examples: * - * <pre> {@code + * <pre>{@code * Pipeline p = ...; * * // A simple Read of a local file (only runs 
locally): @@ -79,7 +80,7 @@ import javax.annotation.Nullable; * p.apply(TextIO.Read.named("ReadNumbers") * .from("gs://my_bucket/path/to/numbers-*.txt") * .withCoder(TextualIntegerCoder.of())); - * } </pre> + * }</pre> * * <p>To write a {@link PCollection} to one or more text files, use * {@link TextIO.Write}, specifying {@link TextIO.Write#to(String)} to specify @@ -94,7 +95,7 @@ import javax.annotation.Nullable; * will be overwritten. * * <p>For example: - * <pre> {@code + * <pre>{@code * // A simple Write to a local file (only runs locally): * PCollection<String> lines = ...; * lines.apply(TextIO.Write.to("/path/to/file.txt")); @@ -106,7 +107,7 @@ import javax.annotation.Nullable; * .to("gs://my_bucket/path/to/numbers") * .withSuffix(".txt") * .withCoder(TextualIntegerCoder.of())); - * } </pre> + * }</pre> * * <h3>Permissions</h3> * <p>When run using the {@link DirectPipelineRunner}, your pipeline can read and write text files @@ -477,9 +478,6 @@ public class TextIO { /** Requested number of shards. 0 for automatic. */ private final int numShards; - /** Insert a shuffle before writing to decouple parallelism when numShards != 0. */ - private final boolean forceReshard; - /** The shard template of each file written, combined with prefix and suffix. 
*/ private final String shardTemplate; @@ -487,17 +485,16 @@ public class TextIO { private final boolean validate; Bound(Coder<T> coder) { - this(null, null, "", coder, 0, true, ShardNameTemplate.INDEX_OF_MAX, true); + this(null, null, "", coder, 0, ShardNameTemplate.INDEX_OF_MAX, true); } private Bound(String name, String filenamePrefix, String filenameSuffix, Coder<T> coder, - int numShards, boolean forceReshard, String shardTemplate, boolean validate) { + int numShards, String shardTemplate, boolean validate) { super(name); this.coder = coder; this.filenamePrefix = filenamePrefix; this.filenameSuffix = filenameSuffix; this.numShards = numShards; - this.forceReshard = forceReshard; this.shardTemplate = shardTemplate; this.validate = validate; } @@ -510,7 +507,7 @@ public class TextIO { */ public Bound<T> named(String name) { return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, - forceReshard, shardTemplate, validate); + shardTemplate, validate); } /** @@ -523,7 +520,7 @@ public class TextIO { */ public Bound<T> to(String filenamePrefix) { validateOutputComponent(filenamePrefix); - return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, forceReshard, + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, shardTemplate, validate); } @@ -537,7 +534,7 @@ public class TextIO { */ public Bound<T> withSuffix(String nameExtension) { validateOutputComponent(nameExtension); - return new Bound<>(name, filenamePrefix, nameExtension, coder, numShards, forceReshard, + return new Bound<>(name, filenamePrefix, nameExtension, coder, numShards, shardTemplate, validate); } @@ -556,30 +553,8 @@ public class TextIO { * @see ShardNameTemplate */ public Bound<T> withNumShards(int numShards) { - return withNumShards(numShards, forceReshard); - } - - /** - * Returns a transform for writing to text files that's like this one but - * that uses the provided shard count. 
- * - * <p>Constraining the number of shards is likely to reduce - * the performance of a pipeline. If forceReshard is true, the output - * will be shuffled to obtain the desired sharding. If it is false, - * data will not be reshuffled, but parallelism of preceeding stages - * may be constrained. Setting this value is not recommended - * unless you require a specific number of output files. - * - * <p>Does not modify this object. - * - * @param numShards the number of shards to use, or 0 to let the system - * decide. - * @param forceReshard whether to force a reshard to obtain the desired sharding. - * @see ShardNameTemplate - */ - private Bound<T> withNumShards(int numShards, boolean forceReshard) { Preconditions.checkArgument(numShards >= 0); - return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, forceReshard, + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, shardTemplate, validate); } @@ -592,7 +567,7 @@ public class TextIO { * @see ShardNameTemplate */ public Bound<T> withShardNameTemplate(String shardTemplate) { - return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, forceReshard, + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, shardTemplate, validate); } @@ -610,25 +585,7 @@ public class TextIO { * <p>Does not modify this object. */ public Bound<T> withoutSharding() { - return withoutSharding(forceReshard); - } - - /** - * Returns a transform for writing to text files that's like this one but - * that forces a single file as output. - * - * <p>Constraining the number of shards is likely to reduce - * the performance of a pipeline. Using this setting is not recommended - * unless you truly require a single output file. - * - * <p>This is a shortcut for - * {@code .withNumShards(1, forceReshard).withShardNameTemplate("")} - * - * <p>Does not modify this object. 
- */ - private Bound<T> withoutSharding(boolean forceReshard) { - return new Bound<>(name, filenamePrefix, filenameSuffix, coder, 1, forceReshard, "", - validate); + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, 1, "", validate); } /** @@ -640,7 +597,7 @@ public class TextIO { * @param <X> the type of the elements of the input {@link PCollection} */ public <X> Bound<X> withCoder(Coder<X> coder) { - return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, forceReshard, + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, shardTemplate, validate); } @@ -655,7 +612,7 @@ public class TextIO { * <p>Does not modify this object. */ public Bound<T> withoutValidation() { - return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, forceReshard, + return new Bound<>(name, filenamePrefix, filenameSuffix, coder, numShards, shardTemplate, false); } @@ -665,14 +622,13 @@ public class TextIO { throw new IllegalStateException( "need to set the filename prefix of a TextIO.Write transform"); } - if (numShards > 0 && forceReshard) { - // Reshard and re-apply a version of this write without resharding. - return input - .apply(new FileBasedSink.ReshardForWrite<T>()) - .apply(withNumShards(numShards, false)); - } else { - return PDone.in(input.getPipeline()); - } + + // Note that custom sinks currently do not expose sharding controls. + // Thus pipeline runner writers need to individually add support internally to + // apply user requested sharding limits. 
+ return input.apply("Write", com.google.cloud.dataflow.sdk.io.Write.to( + new TextSink<>( + filenamePrefix, filenameSuffix, shardTemplate, coder))); } /** @@ -710,17 +666,6 @@ public class TextIO { public boolean needsValidation() { return validate; } - - static { - DirectPipelineRunner.registerDefaultTransformEvaluator( - Bound.class, new DirectPipelineRunner.TransformEvaluator<Bound>() { - @Override - public void evaluate( - Bound transform, DirectPipelineRunner.EvaluationContext context) { - evaluateWriteHelper(transform, context); - } - }); - } } } @@ -978,24 +923,70 @@ public class TextIO { } } - private static <T> void evaluateWriteHelper( - Write.Bound<T> transform, DirectPipelineRunner.EvaluationContext context) { - List<T> elems = context.getPCollection(context.getInput(transform)); - int numShards = transform.numShards; - if (numShards < 1) { - // System gets to choose. For direct mode, choose 1. - numShards = 1; + /** + * A {@link FileBasedSink} for text files. Produces text files with the new line separator + * {@code '\n'} represented in {@code UTF-8} format as the record separator. + * Each record (including the last) is terminated. 
+ */ + @VisibleForTesting + static class TextSink<T> extends FileBasedSink<T> { + private final Coder<T> coder; + + @VisibleForTesting + TextSink( + String baseOutputFilename, String extension, String fileNameTemplate, Coder<T> coder) { + super(baseOutputFilename, extension, fileNameTemplate); + this.coder = coder; } - TextSink<WindowedValue<T>> writer = TextSink.createForDirectPipelineRunner( - transform.filenamePrefix, transform.getShardNameTemplate(), transform.filenameSuffix, - numShards, true, null, null, transform.coder); - try (Sink.SinkWriter<WindowedValue<T>> sink = writer.writer()) { - for (T elem : elems) { - sink.add(WindowedValue.valueInGlobalWindow(elem)); - } - } catch (IOException exn) { - throw new RuntimeException( - "unable to write to output file \"" + transform.filenamePrefix + "\"", exn); + + @Override + public FileBasedSink.FileBasedWriteOperation<T> createWriteOperation(PipelineOptions options) { + return new TextWriteOperation<>(this, coder); + } + + /** + * A {@link com.google.cloud.dataflow.sdk.io.FileBasedSink.FileBasedWriteOperation + * FileBasedWriteOperation} for text files. + */ + private static class TextWriteOperation<T> extends FileBasedWriteOperation<T> { + private final Coder<T> coder; + + private TextWriteOperation(TextSink<T> sink, Coder<T> coder) { + super(sink); + this.coder = coder; + } + + @Override + public FileBasedWriter<T> createWriter(PipelineOptions options) throws Exception { + return new TextWriter<>(this, coder); + } + } + + /** + * A {@link com.google.cloud.dataflow.sdk.io.FileBasedSink.FileBasedWriter FileBasedWriter} + * for text files. 
+ */ + private static class TextWriter<T> extends FileBasedWriter<T> { + private static final byte[] NEWLINE = "\n".getBytes(StandardCharsets.UTF_8); + private final Coder<T> coder; + private OutputStream out; + + public TextWriter(FileBasedWriteOperation<T> writeOperation, Coder<T> coder) { + super(writeOperation); + this.mimeType = MimeTypes.TEXT; + this.coder = coder; + } + + @Override + protected void prepareWrite(WritableByteChannel channel) throws Exception { + out = Channels.newOutputStream(channel); + } + + @Override + public void write(T value) throws Exception { + coder.encode(value, out, Context.OUTER); + out.write(NEWLINE); + } } } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java index 54fadea..06b2295 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunner.java @@ -338,6 +338,7 @@ public class DataflowPipelineRunner extends PipelineRunner<DataflowPipelineJob> builder.put(Window.Bound.class, AssignWindows.class); builder.put(Write.Bound.class, BatchWrite.class); builder.put(AvroIO.Write.Bound.class, BatchAvroIOWrite.class); + builder.put(TextIO.Write.Bound.class, BatchTextIOWrite.class); if (options.getExperiments() == null || !options.getExperiments().contains("disable_ism_side_input")) { builder.put(View.AsMap.class, BatchViewAsMap.class); @@ -2001,6 +2002,111 @@ public class DataflowPipelineRunner extends PipelineRunner<DataflowPipelineJob> /** * Specialized implementation which overrides + * {@link com.google.cloud.dataflow.sdk.io.TextIO.Write.Bound TextIO.Write.Bound} with 
+ * a native sink instead of a custom sink as workaround until custom sinks + * have support for sharding controls. + */ + private static class BatchTextIOWrite<T> extends PTransform<PCollection<T>, PDone> { + private final TextIO.Write.Bound<T> transform; + /** + * Builds an instance of this class from the overridden transform. + */ + @SuppressWarnings("unused") // used via reflection in DataflowPipelineRunner#apply() + public BatchTextIOWrite(DataflowPipelineRunner runner, TextIO.Write.Bound<T> transform) { + this.transform = transform; + } + + @Override + public PDone apply(PCollection<T> input) { + if (transform.getNumShards() > 0) { + return input + .apply(new ReshardForWrite<T>()) + .apply(new BatchTextIONativeWrite<>(transform)); + } else { + return transform.apply(input); + } + } + } + + /** + * This {@link PTransform} is used by the {@link DataflowPipelineTranslator} as a way + * to provide the native definition of the Text sink. + */ + private static class BatchTextIONativeWrite<T> extends PTransform<PCollection<T>, PDone> { + private final TextIO.Write.Bound<T> transform; + public BatchTextIONativeWrite(TextIO.Write.Bound<T> transform) { + this.transform = transform; + } + + @Override + public PDone apply(PCollection<T> input) { + return PDone.in(input.getPipeline()); + } + + static { + DataflowPipelineTranslator.registerTransformTranslator( + BatchTextIONativeWrite.class, new BatchTextIONativeWriteTranslator()); + } + } + + /** + * TextIO.Write.Bound support code for the Dataflow backend when applying parallelism limits + * through user requested sharding limits. 
+ */ + private static class BatchTextIONativeWriteTranslator + implements TransformTranslator<BatchTextIONativeWrite<?>> { + @SuppressWarnings("unchecked") + @Override + public void translate(@SuppressWarnings("rawtypes") BatchTextIONativeWrite transform, + TranslationContext context) { + translateWriteHelper(transform, transform.transform, context); + } + + private <T> void translateWriteHelper( + BatchTextIONativeWrite<T> transform, + TextIO.Write.Bound<T> originalTransform, + TranslationContext context) { + // Note that the original transform can not be used during add step/add input + // and is only passed in to get properties from it. + + checkState(originalTransform.getNumShards() > 0, + "Native TextSink is expected to only be used when sharding controls are required."); + + context.addStep(transform, "ParallelWrite"); + context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform)); + + // TODO: drop this check when server supports alternative templates. + switch (originalTransform.getShardTemplate()) { + case ShardNameTemplate.INDEX_OF_MAX: + break; // supported by server + case "": + // Empty shard template allowed - forces single output. + Preconditions.checkArgument(originalTransform.getNumShards() <= 1, + "Num shards must be <= 1 when using an empty sharding template"); + break; + default: + throw new UnsupportedOperationException("Shard template " + + originalTransform.getShardTemplate() + + " not yet supported by Dataflow service"); + } + + // TODO: How do we want to specify format and + // format-specific properties? 
+ context.addInput(PropertyNames.FORMAT, "text"); + context.addInput(PropertyNames.FILENAME_PREFIX, originalTransform.getFilenamePrefix()); + context.addInput(PropertyNames.SHARD_NAME_TEMPLATE, + originalTransform.getShardNameTemplate()); + context.addInput(PropertyNames.FILENAME_SUFFIX, originalTransform.getFilenameSuffix()); + context.addInput(PropertyNames.VALIDATE_SINK, originalTransform.needsValidation()); + context.addInput(PropertyNames.NUM_SHARDS, (long) originalTransform.getNumShards()); + context.addEncodingInput( + WindowedValue.getValueOnlyCoder(originalTransform.getCoder())); + + } + } + + /** + * Specialized implementation which overrides * {@link com.google.cloud.dataflow.sdk.io.AvroIO.Write.Bound AvroIO.Write.Bound} with * a native sink instead of a custom sink as workaround until custom sinks * have support for sharding controls. @@ -2018,7 +2124,9 @@ public class DataflowPipelineRunner extends PipelineRunner<DataflowPipelineJob> @Override public PDone apply(PCollection<T> input) { if (transform.getNumShards() > 0) { - return input.apply(new ReshardForWrite<T>()).apply(new BatchAvroIONativeWrite<>(transform)); + return input + .apply(new ReshardForWrite<T>()) + .apply(new BatchAvroIONativeWrite<>(transform)); } else { return transform.apply(input); } @@ -2031,7 +2139,6 @@ public class DataflowPipelineRunner extends PipelineRunner<DataflowPipelineJob> */ private static class BatchAvroIONativeWrite<T> extends PTransform<PCollection<T>, PDone> { private final AvroIO.Write.Bound<T> transform; - public BatchAvroIONativeWrite(AvroIO.Write.Bound<T> transform) { this.transform = transform; } @@ -2055,8 +2162,7 @@ public class DataflowPipelineRunner extends PipelineRunner<DataflowPipelineJob> implements TransformTranslator<BatchAvroIONativeWrite<?>> { @SuppressWarnings("unchecked") @Override - public void translate( - @SuppressWarnings("rawtypes") BatchAvroIONativeWrite transform, + public void translate(@SuppressWarnings("rawtypes") BatchAvroIONativeWrite 
transform, TranslationContext context) { translateWriteHelper(transform, transform.transform, context); } @@ -2068,8 +2174,7 @@ public class DataflowPipelineRunner extends PipelineRunner<DataflowPipelineJob> // Note that the original transform can not be used during add step/add input // and is only passed in to get properties from it. - checkState( - originalTransform.getNumShards() > 0, + checkState(originalTransform.getNumShards() > 0, "Native AvroSink is expected to only be used when sharding controls are required."); context.addStep(transform, "ParallelWrite"); @@ -2078,18 +2183,16 @@ public class DataflowPipelineRunner extends PipelineRunner<DataflowPipelineJob> // TODO: drop this check when server supports alternative templates. switch (originalTransform.getShardTemplate()) { case ShardNameTemplate.INDEX_OF_MAX: - break; // supported by server + break; // supported by server case "": // Empty shard template allowed - forces single output. - Preconditions.checkArgument( - originalTransform.getNumShards() <= 1, + Preconditions.checkArgument(originalTransform.getNumShards() <= 1, "Num shards must be <= 1 when using an empty sharding template"); break; default: - throw new UnsupportedOperationException( - "Shard template " - + originalTransform.getShardTemplate() - + " not yet supported by Dataflow service"); + throw new UnsupportedOperationException("Shard template " + + originalTransform.getShardTemplate() + + " not yet supported by Dataflow service"); } context.addInput(PropertyNames.FORMAT, "avro"); @@ -2097,9 +2200,7 @@ public class DataflowPipelineRunner extends PipelineRunner<DataflowPipelineJob> context.addInput(PropertyNames.SHARD_NAME_TEMPLATE, originalTransform.getShardTemplate()); context.addInput(PropertyNames.FILENAME_SUFFIX, originalTransform.getFilenameSuffix()); context.addInput(PropertyNames.VALIDATE_SINK, originalTransform.needsValidation()); - context.addInput(PropertyNames.NUM_SHARDS, (long) originalTransform.getNumShards()); - 
context.addEncodingInput( WindowedValue.getValueOnlyCoder( AvroCoder.of(originalTransform.getType(), originalTransform.getSchema()))); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java index 885260e..ae3a403 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslator.java @@ -44,14 +44,12 @@ import com.google.cloud.dataflow.sdk.coders.IterableCoder; import com.google.cloud.dataflow.sdk.io.BigQueryIO; import com.google.cloud.dataflow.sdk.io.PubsubIO; import com.google.cloud.dataflow.sdk.io.Read; -import com.google.cloud.dataflow.sdk.io.TextIO; import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; import com.google.cloud.dataflow.sdk.options.StreamingOptions; import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner.GroupByKeyAndSortValuesOnly; import com.google.cloud.dataflow.sdk.runners.dataflow.BigQueryIOTranslator; import com.google.cloud.dataflow.sdk.runners.dataflow.PubsubIOTranslator; import com.google.cloud.dataflow.sdk.runners.dataflow.ReadTranslator; -import com.google.cloud.dataflow.sdk.runners.dataflow.TextIOTranslator; import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; import com.google.cloud.dataflow.sdk.transforms.Combine; import com.google.cloud.dataflow.sdk.transforms.Create; @@ -998,7 +996,6 @@ public class DataflowPipelineTranslator { } }); - registerTransformTranslator( Window.Bound.class, new DataflowPipelineTranslator.TransformTranslator<Window.Bound>() { @@ -1037,9 +1034,6 @@ public class DataflowPipelineTranslator 
{ DataflowPipelineRunner.StreamingPubsubIOWrite.class, new PubsubIOTranslator.WriteTranslator()); - registerTransformTranslator( - TextIO.Write.Bound.class, new TextIOTranslator.WriteTranslator()); - registerTransformTranslator(Read.Bounded.class, new ReadTranslator()); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java index 332a496..4543b5a 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java @@ -17,6 +17,7 @@ package com.google.cloud.dataflow.sdk.runners; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; @@ -24,6 +25,8 @@ import com.google.cloud.dataflow.sdk.PipelineResult; import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.coders.ListCoder; +import com.google.cloud.dataflow.sdk.io.FileBasedSink; +import com.google.cloud.dataflow.sdk.io.TextIO; import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.options.PipelineOptions.CheckEnabled; @@ -36,6 +39,8 @@ import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.transforms.PTransform; import 
com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Partition; +import com.google.cloud.dataflow.sdk.transforms.Partition.PartitionFn; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.util.AppliedCombineFn; import com.google.cloud.dataflow.sdk.util.IOChannelUtils; @@ -51,6 +56,7 @@ import com.google.cloud.dataflow.sdk.values.KV; import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.dataflow.sdk.values.PCollectionList; import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PDone; import com.google.cloud.dataflow.sdk.values.PInput; import com.google.cloud.dataflow.sdk.values.POutput; import com.google.cloud.dataflow.sdk.values.PValue; @@ -232,6 +238,8 @@ public class DirectPipelineRunner PTransform<InputT, OutputT> transform, InputT input) { if (transform instanceof Combine.GroupedValues) { return (OutputT) applyTestCombine((Combine.GroupedValues) transform, (PCollection) input); + } else if (transform instanceof TextIO.Write.Bound) { + return (OutputT) applyTextIOWrite((TextIO.Write.Bound) transform, (PCollection<?>) input); } else { return super.apply(transform, input); } @@ -253,6 +261,87 @@ public class DirectPipelineRunner return output; } + private static class ElementProcessingOrderPartitionFn<T> implements PartitionFn<T> { + private int elementNumber; + @Override + public int partitionFor(T elem, int numPartitions) { + return elementNumber++ % numPartitions; + } + } + + /** + * Applies TextIO.Write honoring user requested sharding controls (i.e. withNumShards) + * by applying a partition function based upon the number of shards the user requested. 
+ */ + private static class DirectTextIOWrite<T> extends PTransform<PCollection<T>, PDone> { + private final TextIO.Write.Bound<T> transform; + + private DirectTextIOWrite(TextIO.Write.Bound<T> transform) { + this.transform = transform; + } + + @Override + public PDone apply(PCollection<T> input) { + checkState(transform.getNumShards() > 1, + "DirectTextIOWrite is expected to only be used when sharding controls are required."); + + // Evenly distribute all the elements across the partitions. + PCollectionList<T> partitionedElements = + input.apply(Partition.of(transform.getNumShards(), + new ElementProcessingOrderPartitionFn<T>())); + + // For each input PCollection partition, create a write transform that represents + // one of the specific shards. + for (int i = 0; i < transform.getNumShards(); ++i) { + /* + * This logic mirrors the file naming strategy within + * {@link FileBasedSink#generateDestinationFilenames()} + */ + String outputFilename = IOChannelUtils.constructName( + transform.getFilenamePrefix(), + transform.getShardNameTemplate(), + getFileExtension(transform.getFilenameSuffix()), + i, + transform.getNumShards()); + + String transformName = String.format("%s(Shard:%s)", transform.getName(), i); + partitionedElements.get(i).apply(transformName, + transform.withNumShards(1).withShardNameTemplate("").withSuffix("").to(outputFilename)); + } + return PDone.in(input.getPipeline()); + } + } + + /** + * Returns the file extension to be used. If the user did not request a file + * extension then this method returns the empty string. Otherwise this method + * adds a {@code "."} to the beginning of the users extension if one is not present. + * + * <p>This is copied from {@link FileBasedSink} to not expose it. + */ + private static String getFileExtension(String usersExtension) { + if (usersExtension == null || usersExtension.isEmpty()) { + return ""; + } + if (usersExtension.startsWith(".")) { + return usersExtension; + } + return "." 
+ usersExtension; + } + + /** + * Apply the override for TextIO.Write.Bound if the user requested sharding controls + * greater than one. + */ + private <T> PDone applyTextIOWrite(TextIO.Write.Bound<T> transform, PCollection<T> input) { + if (transform.getNumShards() <= 1) { + // By default, the DirectPipelineRunner outputs to only 1 shard. Since the user never + // requested sharding controls greater than 1, we default to outputting to 1 file. + return super.apply(transform.withNumShards(1), input); + } + return input.apply(new DirectTextIOWrite<>(transform)); + } + /** * The implementation may split the {@link KeyedCombineFn} into ADD, MERGE and EXTRACT phases ( * see {@code com.google.cloud.dataflow.sdk.runners.worker.CombineValuesFn}). In order to emulate http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/TextIOTranslator.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/TextIOTranslator.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/TextIOTranslator.java deleted file mode 100644 index d6c96c3..0000000 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/dataflow/TextIOTranslator.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (C) 2015 Google Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License. You may obtain a copy of - * the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. 
- */ - -package com.google.cloud.dataflow.sdk.runners.dataflow; - -import com.google.cloud.dataflow.sdk.io.ShardNameTemplate; -import com.google.cloud.dataflow.sdk.io.TextIO; -import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TransformTranslator; -import com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator.TranslationContext; -import com.google.cloud.dataflow.sdk.util.PathValidator; -import com.google.cloud.dataflow.sdk.util.PropertyNames; -import com.google.cloud.dataflow.sdk.util.WindowedValue; -import com.google.common.base.Preconditions; - -/** - * TextIO transform support code for the Dataflow backend. - */ -public class TextIOTranslator { - /** - * Implements TextIO Write translation for the Dataflow backend. - */ - @SuppressWarnings({"rawtypes", "unchecked"}) - public static class WriteTranslator implements TransformTranslator<TextIO.Write.Bound> { - @Override - public void translate( - TextIO.Write.Bound transform, - TranslationContext context) { - translateWriteHelper(transform, context); - } - - private <T> void translateWriteHelper( - TextIO.Write.Bound<T> transform, - TranslationContext context) { - if (context.getPipelineOptions().isStreaming()) { - throw new IllegalArgumentException("TextIO not supported in streaming mode."); - } - - PathValidator validator = context.getPipelineOptions().getPathValidator(); - String filenamePrefix = validator.validateOutputFilePrefixSupported( - transform.getFilenamePrefix()); - - context.addStep(transform, "ParallelWrite"); - context.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform)); - - // TODO: drop this check when server supports alternative templates. - switch (transform.getShardTemplate()) { - case ShardNameTemplate.INDEX_OF_MAX: - break; // supported by server - case "": - // Empty shard template allowed - forces single output. 
- Preconditions.checkArgument(transform.getNumShards() <= 1, - "Num shards must be <= 1 when using an empty sharding template"); - break; - default: - throw new UnsupportedOperationException("Shard template " - + transform.getShardTemplate() - + " not yet supported by Dataflow service"); - } - - // TODO: How do we want to specify format and - // format-specific properties? - context.addInput(PropertyNames.FORMAT, "text"); - context.addInput(PropertyNames.FILENAME_PREFIX, filenamePrefix); - context.addInput(PropertyNames.SHARD_NAME_TEMPLATE, - transform.getShardNameTemplate()); - context.addInput(PropertyNames.FILENAME_SUFFIX, transform.getFilenameSuffix()); - context.addInput(PropertyNames.VALIDATE_SINK, transform.needsValidation()); - - long numShards = transform.getNumShards(); - if (numShards > 0) { - context.addInput(PropertyNames.NUM_SHARDS, numShards); - } - - context.addEncodingInput( - WindowedValue.getValueOnlyCoder(transform.getCoder())); - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSinkTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSinkTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSinkTest.java index 8236ae6..da23f3a 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSinkTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/FileBasedSinkTest.java @@ -56,7 +56,6 @@ public class FileBasedSinkTest { private String baseOutputFilename = "output"; private String baseTemporaryFilename = "temp"; - private String testExtension = "test"; private String appendToTempFolder(String filename) { return Paths.get(tmpFolder.getRoot().getPath(), filename).toString(); @@ -314,7 +313,7 @@ public class FileBasedSinkTest { public void testGenerateOutputFilenamesWithTemplate() { List<String> expected; List<String> 
actual; - SimpleSink sink = buildSink(".SS.of.NN"); + SimpleSink sink = new SimpleSink(getBaseOutputFilename(), "test", ".SS.of.NN"); SimpleSink.SimpleWriteOperation writeOp = new SimpleSink.SimpleWriteOperation(sink); expected = Arrays.asList(appendToTempFolder("output.00.of.03.test"), @@ -329,6 +328,23 @@ public class FileBasedSinkTest { expected = new ArrayList<>(); actual = writeOp.generateDestinationFilenames(0); assertEquals(expected, actual); + + // Also validate that we handle the case where the user specified "." that we do + // not prefix an additional "." making "..test" + sink = new SimpleSink(getBaseOutputFilename(), ".test", ".SS.of.NN"); + writeOp = new SimpleSink.SimpleWriteOperation(sink); + expected = Arrays.asList(appendToTempFolder("output.00.of.03.test"), + appendToTempFolder("output.01.of.03.test"), appendToTempFolder("output.02.of.03.test")); + actual = writeOp.generateDestinationFilenames(3); + assertEquals(expected, actual); + + expected = Arrays.asList(appendToTempFolder("output.00.of.01.test")); + actual = writeOp.generateDestinationFilenames(1); + assertEquals(expected, actual); + + expected = new ArrayList<>(); + actual = writeOp.generateDestinationFilenames(0); + assertEquals(expected, actual); } /** @@ -457,14 +473,7 @@ public class FileBasedSinkTest { * Build a SimpleSink with default options. */ private SimpleSink buildSink() { - return new SimpleSink(getBaseOutputFilename(), testExtension); - } - - /** - * Build a SimpleSink with default options and the given shard template. 
- */ - private SimpleSink buildSink(String shardTemplate) { - return new SimpleSink(getBaseOutputFilename(), testExtension, shardTemplate); + return new SimpleSink(getBaseOutputFilename(), "test"); } /** http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java index 6ad81e4..0a8e381 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java @@ -264,28 +264,6 @@ public class TextIOTest { } @Test - public void testWriteSharded() throws IOException { - File outFolder = tmpFolder.newFolder(); - String filename = outFolder.toPath().resolve("output").toString(); - - Pipeline p = TestPipeline.create(); - - PCollection<String> input = - p.apply(Create.of(Arrays.asList(LINES_ARRAY)) - .withCoder(StringUtf8Coder.of())); - - input.apply(TextIO.Write.to(filename).withNumShards(2).withSuffix(".txt")); - - p.run(); - - String[] files = outFolder.list(); - - assertThat(Arrays.asList(files), - containsInAnyOrder("output-00000-of-00002.txt", - "output-00001-of-00002.txt")); - } - - @Test public void testWriteNamed() { { PTransform<PCollection<String>, PDone> transform1 = http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java index c7175cb..c5f2d3f 100644 --- 
a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineRunnerTest.java @@ -453,16 +453,12 @@ public class DataflowPipelineRunnerTest { @Test public void testNonGcsFilePathInWriteFailure() throws IOException { - ArgumentCaptor<Job> jobCaptor = ArgumentCaptor.forClass(Job.class); - - Pipeline p = buildDataflowPipeline(buildPipelineOptions(jobCaptor)); - p.apply(TextIO.Read.named("ReadMyGcsFile").from("gs://bucket/object")) - .apply(TextIO.Write.named("WriteMyNonGcsFile").to("/tmp/file")); + Pipeline p = buildDataflowPipeline(buildPipelineOptions()); + PCollection<String> pc = p.apply(TextIO.Read.named("ReadMyGcsFile").from("gs://bucket/object")); thrown.expect(IllegalArgumentException.class); thrown.expectMessage(containsString("expected a valid 'gs://' path but was given")); - p.run(); - assertValidJob(jobCaptor.getValue()); + pc.apply(TextIO.Write.named("WriteMyNonGcsFile").to("/tmp/file")); } @Test @@ -482,17 +478,12 @@ public class DataflowPipelineRunnerTest { @Test public void testMultiSlashGcsFileWritePath() throws IOException { - ArgumentCaptor<Job> jobCaptor = ArgumentCaptor.forClass(Job.class); - - Pipeline p = buildDataflowPipeline(buildPipelineOptions(jobCaptor)); - p.apply(TextIO.Read.named("ReadMyGcsFile").from("gs://bucket/object")) - .apply(TextIO.Write.named("WriteInvalidGcsFile") - .to("gs://bucket/tmp//file")); + Pipeline p = buildDataflowPipeline(buildPipelineOptions()); + PCollection<String> pc = p.apply(TextIO.Read.named("ReadMyGcsFile").from("gs://bucket/object")); thrown.expect(IllegalArgumentException.class); thrown.expectMessage("consecutive slashes"); - p.run(); - assertValidJob(jobCaptor.getValue()); + pc.apply(TextIO.Write.named("WriteInvalidGcsFile").to("gs://bucket/tmp//file")); } @Test 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java index b9c94ad..72090a0 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DataflowPipelineTranslatorTest.java @@ -403,7 +403,7 @@ public class DataflowPipelineTranslatorTest { pipeline.apply(TextIO.Read.named("ReadMyFile").from("gs://bucket/in")) .apply(ParDo.of(new NoOpFn())) .apply(new EmbeddedTransform(predefinedStep.clone())) - .apply(TextIO.Write.named("WriteMyFile").to("gs://bucket/out")); + .apply(ParDo.of(new NoOpFn())); Job job = translator.translate( pipeline, pipeline.getRunner(), Collections.<DataflowPackage>emptyList()).getJob(); @@ -456,7 +456,7 @@ public class DataflowPipelineTranslatorTest { Job job = translator.translate( pipeline, pipeline.getRunner(), Collections.<DataflowPackage>emptyList()).getJob(); - assertEquals(3, job.getSteps().size()); + assertEquals(13, job.getSteps().size()); Step step = job.getSteps().get(1); assertEquals(stepName, getString(step.getProperties(), PropertyNames.USER_NAME)); return step; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java index 904e4bb..4a0f91c 100644 --- 
a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java @@ -16,35 +16,48 @@ package com.google.cloud.dataflow.sdk.runners; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.isA; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.io.ShardNameTemplate; +import com.google.cloud.dataflow.sdk.io.TextIO; import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions; import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; import com.google.cloud.dataflow.sdk.transforms.Create; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.util.IOChannelUtils; +import com.google.common.io.Files; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.io.File; import java.io.InputStream; import java.io.OutputStream; import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; /** Tests for {@link DirectPipelineRunner}. 
*/ @RunWith(JUnit4.class) public class DirectPipelineRunnerTest implements Serializable { - - @Rule - public transient ExpectedException expectedException = ExpectedException.none(); + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + @Rule public ExpectedException expectedException = ExpectedException.none(); @Test public void testToString() { @@ -54,6 +67,7 @@ public class DirectPipelineRunnerTest implements Serializable { runner.toString()); } + /** A {@link Coder} that fails during decoding. */ private static class CrashingCoder<T> extends AtomicCoder<T> { @Override public void encode(T value, OutputStream stream, Context context) throws CoderException { @@ -68,18 +82,21 @@ public class DirectPipelineRunnerTest implements Serializable { } } + /** A {@link DoFn} that outputs {@code 'hello'}. */ + private static class HelloDoFn extends DoFn<Integer, String> { + @Override + public void processElement(DoFn<Integer, String>.ProcessContext c) throws Exception { + c.output("hello"); + } + } + @Test public void testCoderException() { DirectPipeline pipeline = DirectPipeline.createForTest(); pipeline .apply("CreateTestData", Create.of(42)) - .apply("CrashDuringCoding", ParDo.of(new DoFn<Integer, String>() { - @Override - public void processElement(ProcessContext context) { - context.output("hello"); - } - })) + .apply("CrashDuringCoding", ParDo.of(new HelloDoFn())) .setCoder(new CrashingCoder<String>()); expectedException.expect(RuntimeException.class); @@ -92,4 +109,49 @@ public class DirectPipelineRunnerTest implements Serializable { DirectPipelineOptions options = PipelineOptionsFactory.create().as(DirectPipelineOptions.class); assertNull(options.getDirectPipelineRunnerRandomSeed()); } + + @Test + public void testTextIOWriteWithDefaultShardingStrategy() throws Exception { + String prefix = IOChannelUtils.resolve(Files.createTempDir().toString(), "output"); + Pipeline p = DirectPipeline.createForTest(); + String[] expectedElements = new String[]{ "a", 
"b", "c", "d", "e", "f", "g", "h", "i" }; + p.apply(Create.of(expectedElements)) + .apply(TextIO.Write.to(prefix).withSuffix("txt")); + p.run(); + + String filename = + IOChannelUtils.constructName(prefix, ShardNameTemplate.INDEX_OF_MAX, ".txt", 0, 1); + List<String> fileContents = + Files.readLines(new File(filename), StandardCharsets.UTF_8); + // Ensure that each file got at least one record + assertFalse(fileContents.isEmpty()); + + assertThat(fileContents, containsInAnyOrder(expectedElements)); + } + + @Test + public void testTextIOWriteWithLimitedNumberOfShards() throws Exception { + final int numShards = 3; + String prefix = IOChannelUtils.resolve(Files.createTempDir().toString(), "shardedOutput"); + Pipeline p = DirectPipeline.createForTest(); + String[] expectedElements = new String[]{ "a", "b", "c", "d", "e", "f", "g", "h", "i" }; + p.apply(Create.of(expectedElements)) + .apply(TextIO.Write.to(prefix).withNumShards(numShards).withSuffix("txt")); + p.run(); + + List<String> allContents = new ArrayList<>(); + for (int i = 0; i < numShards; ++i) { + String shardFileName = + IOChannelUtils.constructName(prefix, ShardNameTemplate.INDEX_OF_MAX, ".txt", i, 3); + List<String> shardFileContents = + Files.readLines(new File(shardFileName), StandardCharsets.UTF_8); + + // Ensure that each file got at least one record + assertFalse(shardFileContents.isEmpty()); + + allContents.addAll(shardFileContents); + } + + assertThat(allContents, containsInAnyOrder(expectedElements)); + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/639e9d95/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java index f1b7cd7..68e1db1 100644 --- 
a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/TransformTreeTest.java @@ -28,6 +28,7 @@ import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.coders.VoidCoder; import com.google.cloud.dataflow.sdk.io.Read; import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.io.Write; import com.google.cloud.dataflow.sdk.transforms.Count; import com.google.cloud.dataflow.sdk.transforms.Create; import com.google.cloud.dataflow.sdk.transforms.PTransform; @@ -133,9 +134,12 @@ public class TransformTreeTest { assertTrue(visited.add(TransformsSeen.SAMPLE_ANY)); assertNotNull(node.getEnclosingNode()); assertTrue(node.isCompositeNode()); + } else if (transform instanceof Write.Bound) { + assertTrue(visited.add(TransformsSeen.WRITE)); + assertNotNull(node.getEnclosingNode()); + assertTrue(node.isCompositeNode()); } assertThat(transform, not(instanceOf(Read.Bounded.class))); - assertThat(transform, not(instanceOf(TextIO.Write.Bound.class))); } @Override @@ -151,10 +155,9 @@ public class TransformTreeTest { PTransform<?, ?> transform = node.getTransform(); // Pick is a composite, should not be visited here. assertThat(transform, not(instanceOf(Sample.SampleAny.class))); + assertThat(transform, not(instanceOf(Write.Bound.class))); if (transform instanceof Read.Bounded) { assertTrue(visited.add(TransformsSeen.READ)); - } else if (transform instanceof TextIO.Write.Bound) { - assertTrue(visited.add(TransformsSeen.WRITE)); } }
