[flink] improve InputFormat wrapper and ReadSourceITCase

Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/6eac35e8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/6eac35e8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/6eac35e8

Branch: refs/heads/master
Commit: 6eac35e81e93c25da4668fc1b0d30f7c942383f0
Parents: 7646384
Author: Maximilian Michels <m...@apache.org>
Authored: Wed Mar 30 16:43:04 2016 +0200
Committer: Maximilian Michels <m...@apache.org>
Committed: Mon Apr 18 16:36:43 2016 +0200

----------------------------------------------------------------------
 .../translation/wrappers/SourceInputFormat.java |  83 +++++++--------
 .../beam/runners/flink/ReadSourceITCase.java    | 100 ++-----------------
 2 files changed, 43 insertions(+), 140 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6eac35e8/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
index 26e6297..4b11abc 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
@@ -23,20 +23,20 @@ import org.apache.beam.sdk.options.PipelineOptions;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 
+import org.apache.flink.api.common.io.DefaultInputSplitAssigner;
 import org.apache.flink.api.common.io.InputFormat;
 import org.apache.flink.api.common.io.statistics.BaseStatistics;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.core.io.InputSplit;
 import org.apache.flink.core.io.InputSplitAssigner;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
 import java.util.ArrayList;
 import java.util.List;
 
+
 /**
  * A Flink {@link org.apache.flink.api.common.io.InputFormat} that wraps a
  * Dataflow {@link org.apache.beam.sdk.io.Source}.
@@ -45,37 +45,40 @@ public class SourceInputFormat<T> implements InputFormat<T, SourceInputSplit<T>>
   private static final Logger LOG = 
LoggerFactory.getLogger(SourceInputFormat.class);
 
   private final BoundedSource<T> initialSource;
+
   private transient PipelineOptions options;
+  private final byte[] serializedOptions;
 
-  private BoundedSource.BoundedReader<T> reader = null;
-  private boolean reachedEnd = true;
+  private transient BoundedSource.BoundedReader<T> reader = null;
+  private boolean inputAvailable = true;
 
   public SourceInputFormat(BoundedSource<T> initialSource, PipelineOptions 
options) {
     this.initialSource = initialSource;
     this.options = options;
-  }
 
-  private void writeObject(ObjectOutputStream out)
-      throws IOException, ClassNotFoundException {
-    out.defaultWriteObject();
-    ObjectMapper mapper = new ObjectMapper();
-    mapper.writeValue(out, options);
-  }
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    try {
+      new ObjectMapper().writeValue(baos, options);
+      serializedOptions = baos.toByteArray();
+    } catch (Exception e) {
+      throw new RuntimeException("Couldn't serialize PipelineOptions.", e);
+    }
 
-  private void readObject(ObjectInputStream in)
-      throws IOException, ClassNotFoundException {
-    in.defaultReadObject();
-    ObjectMapper mapper = new ObjectMapper();
-    options = mapper.readValue(in, PipelineOptions.class);
   }
 
   @Override
-  public void configure(Configuration configuration) {}
+  public void configure(Configuration configuration) {
+    try {
+      options = new ObjectMapper().readValue(serializedOptions, 
PipelineOptions.class);
+    } catch (IOException e) {
+      throw new RuntimeException("Couldn't deserialize the PipelineOptions.", 
e);
+    }
+  }
 
   @Override
   public void open(SourceInputSplit<T> sourceInputSplit) throws IOException {
     reader = ((BoundedSource<T>) 
sourceInputSplit.getSource()).createReader(options);
-    reachedEnd = false;
+    inputAvailable = reader.start();
   }
 
   @Override
@@ -87,7 +90,6 @@ public class SourceInputFormat<T> implements InputFormat<T, SourceInputSplit<T>>
         @Override
         public long getTotalInputSize() {
           return estimatedSize;
-
         }
 
         @Override
@@ -110,17 +112,15 @@ public class SourceInputFormat<T> implements InputFormat<T, SourceInputSplit<T>>
   @Override
   @SuppressWarnings("unchecked")
   public SourceInputSplit<T>[] createInputSplits(int numSplits) throws 
IOException {
-    long desiredSizeBytes;
     try {
-      desiredSizeBytes = initialSource.getEstimatedSizeBytes(options) / 
numSplits;
-      List<? extends Source<T>> shards = 
initialSource.splitIntoBundles(desiredSizeBytes,
-          options);
-      List<SourceInputSplit<T>> splits = new ArrayList<>();
-      int splitCount = 0;
-      for (Source<T> shard: shards) {
-        splits.add(new SourceInputSplit<>(shard, splitCount++));
+      long desiredSizeBytes = initialSource.getEstimatedSizeBytes(options) / 
numSplits;
+      List<? extends Source<T>> shards = 
initialSource.splitIntoBundles(desiredSizeBytes, options);
+      int numShards = shards.size();
+      SourceInputSplit<T>[] sourceInputSplits = new 
SourceInputSplit[numShards];
+      for (int i = 0; i < numShards; i++) {
+        sourceInputSplits[i] = new SourceInputSplit<>(shards.get(i), i);
       }
-      return splits.toArray(new SourceInputSplit[splits.size()]);
+      return sourceInputSplits;
     } catch (Exception e) {
       throw new IOException("Could not create input splits from Source.", e);
     }
@@ -128,33 +128,24 @@ public class SourceInputFormat<T> implements InputFormat<T, SourceInputSplit<T>>
 
   @Override
   public InputSplitAssigner getInputSplitAssigner(final SourceInputSplit[] 
sourceInputSplits) {
-    return new InputSplitAssigner() {
-      private int index = 0;
-      private final SourceInputSplit[] splits = sourceInputSplits;
-      @Override
-      public InputSplit getNextInputSplit(String host, int taskId) {
-        if (index < splits.length) {
-          return splits[index++];
-        } else {
-          return null;
-        }
-      }
-    };
+    return new DefaultInputSplitAssigner(sourceInputSplits);
   }
 
 
   @Override
   public boolean reachedEnd() throws IOException {
-    return reachedEnd;
+    return !inputAvailable;
   }
 
   @Override
   public T nextRecord(T t) throws IOException {
-
-    reachedEnd = !reader.advance();
-    if (!reachedEnd) {
-      return reader.getCurrent();
+    if (inputAvailable) {
+      final T current = reader.getCurrent();
+      // advance reader to have a record ready next time
+      inputAvailable = reader.advance();
+      return current;
     }
+
     return null;
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6eac35e8/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java
index bcad6f1..4f63925 100644
--- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java
+++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/ReadSourceITCase.java
@@ -23,21 +23,14 @@ import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.io.TextIO;
-import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.PCollection;
 
 import com.google.common.base.Joiner;
-import com.google.common.base.Preconditions;
 
 import org.apache.flink.test.util.JavaProgramTestBase;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-
 public class ReadSourceITCase extends JavaProgramTestBase {
 
   protected String resultPath;
@@ -45,12 +38,13 @@ public class ReadSourceITCase extends JavaProgramTestBase {
   public ReadSourceITCase(){
   }
 
-  static final String[] EXPECTED_RESULT = new String[] {
-      "1", "2", "3", "4", "5", "6", "7", "8", "9"};
+  private static final String[] EXPECTED_RESULT = new String[] {
+     "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"};
 
   @Override
   protected void preSubmit() throws Exception {
     resultPath = getTempDirPath("result");
+    System.out.println(resultPath);
   }
 
   @Override
@@ -68,8 +62,8 @@ public class ReadSourceITCase extends JavaProgramTestBase {
     Pipeline p = FlinkTestPipeline.createForBatch();
 
     PCollection<String> result = p
-        .apply(Read.from(new ReadSource(1, 10)))
-        .apply(ParDo.of(new DoFn<Integer, String>() {
+        .apply(CountingInput.upTo(10))
+        .apply(ParDo.of(new DoFn<Long, String>() {
           @Override
           public void processElement(ProcessContext c) throws Exception {
             c.output(c.element().toString());
@@ -77,90 +71,8 @@ public class ReadSourceITCase extends JavaProgramTestBase {
         }));
 
     result.apply(TextIO.Write.to(resultPath));
-    p.run();
-  }
-
 
-  private static class ReadSource extends BoundedSource<Integer> {
-    final int from;
-    final int to;
-
-    ReadSource(int from, int to) {
-      this.from = from;
-      this.to = to;
-    }
-
-    @Override
-    public List<ReadSource> splitIntoBundles(long desiredShardSizeBytes, 
PipelineOptions options)
-        throws Exception {
-      List<ReadSource> res = new ArrayList<>();
-      FlinkPipelineOptions flinkOptions = 
options.as(FlinkPipelineOptions.class);
-      int numWorkers = flinkOptions.getParallelism();
-      Preconditions.checkArgument(numWorkers > 0, "Number of workers should be 
larger than 0.");
-
-      float step = 1.0f * (to - from) / numWorkers;
-      for (int i = 0; i < numWorkers; ++i) {
-        res.add(new ReadSource(Math.round(from + i * step), Math.round(from + 
(i + 1) * step)));
-      }
-      return res;
-    }
-
-    @Override
-    public long getEstimatedSizeBytes(PipelineOptions options) throws 
Exception {
-      return 8 * (to - from);
-    }
-
-    @Override
-    public boolean producesSortedKeys(PipelineOptions options) throws 
Exception {
-      return true;
-    }
-
-    @Override
-    public BoundedReader<Integer> createReader(PipelineOptions options) throws 
IOException {
-      return new RangeReader(this);
-    }
-
-    @Override
-    public void validate() {}
-
-    @Override
-    public Coder<Integer> getDefaultOutputCoder() {
-      return BigEndianIntegerCoder.of();
-    }
-
-    private class RangeReader extends BoundedReader<Integer> {
-      private int current;
-
-      public RangeReader(ReadSource source) {
-        this.current = source.from - 1;
-      }
-
-      @Override
-      public boolean start() throws IOException {
-        return true;
-      }
-
-      @Override
-      public boolean advance() throws IOException {
-        current++;
-        return (current < to);
-      }
-
-      @Override
-      public Integer getCurrent() {
-        return current;
-      }
-
-      @Override
-      public void close() throws IOException {
-        // Nothing
-      }
-
-      @Override
-      public BoundedSource<Integer> getCurrentSource() {
-        return ReadSource.this;
-      }
-    }
+    p.run();
   }
 }
 

Reply via email to