Fix bug where values written to the output in DoFn#startBundle and
DoFn#finishBundle were being ignored. Introduced in 62830a0.

Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/76815589
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/76815589
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/76815589

Branch: refs/heads/master
Commit: 76815589f5d4b96868b8438f1820d17e0a5822ab
Parents: 27349ad
Author: Tom White <t...@cloudera.com>
Authored: Tue Jul 14 16:44:15 2015 +0100
Committer: Tom White <t...@cloudera.com>
Committed: Thu Mar 10 11:15:15 2016 +0000

----------------------------------------------------------------------
 .../cloudera/dataflow/spark/DoFnFunction.java   |  2 +-
 .../dataflow/spark/SparkProcessContext.java     | 22 +++++---
 .../cloudera/dataflow/spark/DoFnOutputTest.java | 57 ++++++++++++++++++++
 3 files changed, 73 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76815589/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
index ae3dd79..542f2ec 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
@@ -51,8 +51,8 @@ class DoFnFunction<I, O> implements FlatMapFunction<Iterator<I>, O> {
   @Override
   public Iterable<O> call(Iterator<I> iter) throws Exception {
     ProcCtxt ctxt = new ProcCtxt(mFunction, mRuntimeContext, mSideInputs);
-    mFunction.startBundle(ctxt);
     ctxt.setup();
+    mFunction.startBundle(ctxt);
     return ctxt.getOutputIterable(iter, mFunction);
   }
 

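The one-line swap above matters because, as the SparkProcessContext change below suggests, ctxt.setup() (re)initializes the buffer that output() writes into; when startBundle() ran first, anything it emitted was discarded along with the pre-setup state. A minimal sketch of that ordering hazard, using hypothetical names rather than the runner's real classes:

    import java.util.ArrayList;
    import java.util.List;

    // Hypothetical sketch of the ordering hazard; not the runner's classes.
    class Ctxt {
      private List<String> outputs = new ArrayList<>();

      void setup() {
        // Re-creates the backing store: anything emitted before setup()
        // ran (e.g. from startBundle()) is discarded with the old list.
        outputs = new ArrayList<>();
      }

      void output(String value) {
        outputs.add(value);
      }

      List<String> drain() {
        return outputs;
      }
    }
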
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76815589/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkProcessContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkProcessContext.java
index d0e9d6a..bda838c 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkProcessContext.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkProcessContext.java
@@ -211,10 +211,12 @@ abstract class SparkProcessContext<I, O, V> extends DoFn<I, O>.ProcessContext {
     private final Iterator<I> inputIterator;
     private final DoFn<I, O> doFn;
     private Iterator<V> outputIterator;
+    private boolean calledFinish = false;
 
     public ProcCtxtIterator(Iterator<I> iterator, DoFn<I, O> doFn) {
       this.inputIterator = iterator;
       this.doFn = doFn;
+      this.outputIterator = getOutputIterator();
     }
 
     @Override
@@ -225,10 +227,9 @@ abstract class SparkProcessContext<I, O, V> extends DoFn<I, O>.ProcessContext {
       // collection only holds the output values for each call to processElement, rather
       // than for the whole partition (which would use too much memory).
       while (true) {
-        if (outputIterator != null && outputIterator.hasNext()) {
+        if (outputIterator.hasNext()) {
           return outputIterator.next();
-        }
-        if (inputIterator.hasNext()) {
+        } else if (inputIterator.hasNext()) {
           clearOutput();
           element = inputIterator.next();
           try {
@@ -239,10 +240,17 @@ abstract class SparkProcessContext<I, O, V> extends DoFn<I, O>.ProcessContext {
           outputIterator = getOutputIterator();
           continue; // try to consume outputIterator from start of loop
         } else {
-          try {
-            doFn.finishBundle(SparkProcessContext.this);
-          } catch (Exception e) {
-            throw new IllegalStateException(e);
+          // no more input to consume, but finishBundle can produce more output
+          if (!calledFinish) {
+            clearOutput();
+            try {
+              calledFinish = true;
+              doFn.finishBundle(SparkProcessContext.this);
+            } catch (Exception e) {
+              throw new IllegalStateException(e);
+            }
+            outputIterator = getOutputIterator();
+            continue; // try to consume outputIterator from start of loop
           }
           return endOfData();
         }

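The reworked loop in ProcCtxtIterator amounts to a three-phase state machine: drain the current output iterator; when it is empty, pull the next input element and run processElement; and once input is exhausted, run finishBundle exactly once (hence the calledFinish flag) so that its outputs are also drained before endOfData(). A standalone sketch of that control flow, with simplified stand-in types in place of the runner's classes:

    import java.util.Collections;
    import java.util.Iterator;
    import java.util.NoSuchElementException;
    import java.util.function.Function;

    // Hypothetical rendering of the patched loop. processElement and the
    // finish-phase outputs are passed in rather than coming from a real DoFn.
    class BundleIterator<I, O> implements Iterator<O> {
      private final Iterator<I> input;
      private final Function<I, Iterator<O>> processElement;
      private final Iterable<O> finishOutputs;
      private Iterator<O> output = Collections.emptyIterator();
      private boolean calledFinish = false;
      private O nextValue;
      private boolean done = false;

      BundleIterator(Iterator<I> input,
                     Function<I, Iterator<O>> processElement,
                     Iterable<O> finishOutputs) {
        this.input = input;
        this.processElement = processElement;
        this.finishOutputs = finishOutputs;
        advance();
      }

      // Drain current outputs; else process the next element; else run the
      // finish phase exactly once; else signal end of data.
      private void advance() {
        while (true) {
          if (output.hasNext()) {
            nextValue = output.next();
            return;
          } else if (input.hasNext()) {
            output = processElement.apply(input.next());
          } else if (!calledFinish) {
            calledFinish = true;              // guard: finish runs only once
            output = finishOutputs.iterator();
          } else {
            done = true;
            return;
          }
        }
      }

      @Override public boolean hasNext() { return !done; }

      @Override public O next() {
        if (done) throw new NoSuchElementException();
        O result = nextValue;
        advance();
        return result;
      }
    }

Before the fix, the end-of-input branch called finishBundle() and then immediately returned endOfData(), so anything finishBundle() emitted was buffered but never surfaced; the calledFinish guard also prevents a second invocation.
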
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76815589/runners/spark/src/test/java/com/cloudera/dataflow/spark/DoFnOutputTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/com/cloudera/dataflow/spark/DoFnOutputTest.java b/runners/spark/src/test/java/com/cloudera/dataflow/spark/DoFnOutputTest.java
new file mode 100644
index 0000000..2b0947f
--- /dev/null
+++ b/runners/spark/src/test/java/com/cloudera/dataflow/spark/DoFnOutputTest.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package com.cloudera.dataflow.spark;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import java.io.Serializable;
+import org.junit.Test;
+
+public class DoFnOutputTest implements Serializable {
+  @Test
+  public void test() throws Exception {
+    SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+    options.setRunner(SparkPipelineRunner.class);
+    Pipeline pipeline = Pipeline.create(options);
+
+    PCollection<String> strings = pipeline.apply(Create.of("a"));
+    // Test that values written from startBundle() and finishBundle() are written to
+    // the output
+    PCollection<String> output = strings.apply(ParDo.of(new DoFn<String, String>() {
+      @Override
+      public void startBundle(Context c) throws Exception {
+        c.output("start");
+      }
+      @Override
+      public void processElement(ProcessContext c) throws Exception {
+        c.output(c.element());
+      }
+      @Override
+      public void finishBundle(Context c) throws Exception {
+        c.output("finish");
+      }
+    }));
+
+    DataflowAssert.that(output).containsInAnyOrder("start", "a", "finish");
+
+    EvaluationResult res = SparkPipelineRunner.create().run(pipeline);
+    res.close();
+  }
+}
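
A note on the test's shape: DataflowAssert wires its assertions into the pipeline as transforms, so containsInAnyOrder() is declared up front but only actually checked when SparkPipelineRunner.create().run(pipeline) executes; res.close() then tears down the evaluation (presumably stopping the underlying Spark context).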
