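[flink] improve InputFormat wrapper and ReadSourceITCase: serialize PipelineOptions to bytes in the constructor and restore them in configure(), start the reader in open(), use DefaultInputSplitAssigner, and replace the test's custom ReadSource with CountingInput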
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/6eac35e8 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/6eac35e8 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/6eac35e8 Branch: refs/heads/master Commit: 6eac35e81e93c25da4668fc1b0d30f7c942383f0 Parents: 7646384 Author: Maximilian Michels <m...@apache.org> Authored: Wed Mar 30 16:43:04 2016 +0200 Committer: Maximilian Michels <m...@apache.org> Committed: Mon Apr 18 16:36:43 2016 +0200 ---------------------------------------------------------------------- .../translation/wrappers/SourceInputFormat.java | 83 +++++++-------- .../beam/runners/flink/ReadSourceITCase.java | 100 ++----------------- 2 files changed, 43 insertions(+), 140 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6eac35e8/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java index 26e6297..4b11abc 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java @@ -23,20 +23,20 @@ import org.apache.beam.sdk.options.PipelineOptions; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.api.common.io.DefaultInputSplitAssigner; import org.apache.flink.api.common.io.InputFormat; import org.apache.flink.api.common.io.statistics.BaseStatistics; import org.apache.flink.configuration.Configuration; 
-import org.apache.flink.core.io.InputSplit; import org.apache.flink.core.io.InputSplitAssigner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.List; + /** * A Flink {@link org.apache.flink.api.common.io.InputFormat} that wraps a * Dataflow {@link org.apache.beam.sdk.io.Source}. @@ -45,37 +45,40 @@ public class SourceInputFormat<T> implements InputFormat<T, SourceInputSplit<T>> private static final Logger LOG = LoggerFactory.getLogger(SourceInputFormat.class); private final BoundedSource<T> initialSource; + private transient PipelineOptions options; + private final byte[] serializedOptions; - private BoundedSource.BoundedReader<T> reader = null; - private boolean reachedEnd = true; + private transient BoundedSource.BoundedReader<T> reader = null; + private boolean inputAvailable = true; public SourceInputFormat(BoundedSource<T> initialSource, PipelineOptions options) { this.initialSource = initialSource; this.options = options; - } - private void writeObject(ObjectOutputStream out) - throws IOException, ClassNotFoundException { - out.defaultWriteObject(); - ObjectMapper mapper = new ObjectMapper(); - mapper.writeValue(out, options); - } + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + new ObjectMapper().writeValue(baos, options); + serializedOptions = baos.toByteArray(); + } catch (Exception e) { + throw new RuntimeException("Couldn't serialize PipelineOptions.", e); + } - private void readObject(ObjectInputStream in) - throws IOException, ClassNotFoundException { - in.defaultReadObject(); - ObjectMapper mapper = new ObjectMapper(); - options = mapper.readValue(in, PipelineOptions.class); } @Override - public void configure(Configuration configuration) {} + public void configure(Configuration configuration) { + try { + options = new 
ObjectMapper().readValue(serializedOptions, PipelineOptions.class); + } catch (IOException e) { + throw new RuntimeException("Couldn't deserialize the PipelineOptions.", e); + } + } @Override public void open(SourceInputSplit<T> sourceInputSplit) throws IOException { reader = ((BoundedSource<T>) sourceInputSplit.getSource()).createReader(options); - reachedEnd = false; + inputAvailable = reader.start(); } @Override @@ -87,7 +90,6 @@ public class SourceInputFormat<T> implements InputFormat<T, SourceInputSplit<T>> @Override public long getTotalInputSize() { return estimatedSize; - } @Override @@ -110,17 +112,15 @@ public class SourceInputFormat<T> implements InputFormat<T, SourceInputSplit<T>> @Override @SuppressWarnings("unchecked") public SourceInputSplit<T>[] createInputSplits(int numSplits) throws IOException { - long desiredSizeBytes; try { - desiredSizeBytes = initialSource.getEstimatedSizeBytes(options) / numSplits; - List<? extends Source<T>> shards = initialSource.splitIntoBundles(desiredSizeBytes, - options); - List<SourceInputSplit<T>> splits = new ArrayList<>(); - int splitCount = 0; - for (Source<T> shard: shards) { - splits.add(new SourceInputSplit<>(shard, splitCount++)); + long desiredSizeBytes = initialSource.getEstimatedSizeBytes(options) / numSplits; + List<? 
extends Source<T>> shards = initialSource.splitIntoBundles(desiredSizeBytes, options); + int numShards = shards.size(); + SourceInputSplit<T>[] sourceInputSplits = new SourceInputSplit[numShards]; + for (int i = 0; i < numShards; i++) { + sourceInputSplits[i] = new SourceInputSplit<>(shards.get(i), i); } - return splits.toArray(new SourceInputSplit[splits.size()]); + return sourceInputSplits; } catch (Exception e) { throw new IOException("Could not create input splits from Source.", e); } @@ -128,33 +128,24 @@ public class SourceInputFormat<T> implements InputFormat<T, SourceInputSplit<T>> @Override public InputSplitAssigner getInputSplitAssigner(final SourceInputSplit[] sourceInputSplits) { - return new InputSplitAssigner() { - private int index = 0; - private final SourceInputSplit[] splits = sourceInputSplits; - @Override - public InputSplit getNextInputSplit(String host, int taskId) { - if (index < splits.length) { - return splits[index++]; - } else { - return null; - } - } - }; + return new DefaultInputSplitAssigner(sourceInputSplits); } @Override public boolean reachedEnd() throws IOException { - return reachedEnd; + return !inputAvailable; } @Override public T nextRecord(T t) throws IOException { - - reachedEnd = !reader.advance(); - if (!reachedEnd) { - return reader.getCurrent(); + if (inputAvailable) { + final T current = reader.getCurrent(); + // advance reader to have a record ready next time + inputAvailable = reader.advance(); + return current; } + return null; } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6eac35e8/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java index bcad6f1..4f63925 100644 --- 
a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java @@ -23,21 +23,14 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; import com.google.common.base.Joiner; -import com.google.common.base.Preconditions; import org.apache.flink.test.util.JavaProgramTestBase; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - - public class ReadSourceITCase extends JavaProgramTestBase { protected String resultPath; @@ -45,12 +38,13 @@ public class ReadSourceITCase extends JavaProgramTestBase { public ReadSourceITCase(){ } - static final String[] EXPECTED_RESULT = new String[] { - "1", "2", "3", "4", "5", "6", "7", "8", "9"}; + private static final String[] EXPECTED_RESULT = new String[] { + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}; @Override protected void preSubmit() throws Exception { resultPath = getTempDirPath("result"); + System.out.println(resultPath); } @Override @@ -68,8 +62,8 @@ public class ReadSourceITCase extends JavaProgramTestBase { Pipeline p = FlinkTestPipeline.createForBatch(); PCollection<String> result = p - .apply(Read.from(new ReadSource(1, 10))) - .apply(ParDo.of(new DoFn<Integer, String>() { + .apply(CountingInput.upTo(10)) + .apply(ParDo.of(new DoFn<Long, String>() { @Override public void processElement(ProcessContext c) throws Exception { c.output(c.element().toString()); @@ -77,90 +71,8 @@ public class ReadSourceITCase extends JavaProgramTestBase { })); result.apply(TextIO.Write.to(resultPath)); - p.run(); - } - - private static class ReadSource extends BoundedSource<Integer> { - final int 
from; - final int to; - - ReadSource(int from, int to) { - this.from = from; - this.to = to; - } - - @Override - public List<ReadSource> splitIntoBundles(long desiredShardSizeBytes, PipelineOptions options) - throws Exception { - List<ReadSource> res = new ArrayList<>(); - FlinkPipelineOptions flinkOptions = options.as(FlinkPipelineOptions.class); - int numWorkers = flinkOptions.getParallelism(); - Preconditions.checkArgument(numWorkers > 0, "Number of workers should be larger than 0."); - - float step = 1.0f * (to - from) / numWorkers; - for (int i = 0; i < numWorkers; ++i) { - res.add(new ReadSource(Math.round(from + i * step), Math.round(from + (i + 1) * step))); - } - return res; - } - - @Override - public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { - return 8 * (to - from); - } - - @Override - public boolean producesSortedKeys(PipelineOptions options) throws Exception { - return true; - } - - @Override - public BoundedReader<Integer> createReader(PipelineOptions options) throws IOException { - return new RangeReader(this); - } - - @Override - public void validate() {} - - @Override - public Coder<Integer> getDefaultOutputCoder() { - return BigEndianIntegerCoder.of(); - } - - private class RangeReader extends BoundedReader<Integer> { - private int current; - - public RangeReader(ReadSource source) { - this.current = source.from - 1; - } - - @Override - public boolean start() throws IOException { - return true; - } - - @Override - public boolean advance() throws IOException { - current++; - return (current < to); - } - - @Override - public Integer getCurrent() { - return current; - } - - @Override - public void close() throws IOException { - // Nothing - } - - @Override - public BoundedSource<Integer> getCurrentSource() { - return ReadSource.this; - } - } + p.run(); } }