Fix bug where values written to the output in DoFn#startBundle and DoFn#finishBundle were being ignored. Introduced in 62830a0.
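In other words, a bundle-lifecycle DoFn like the sketch below (a condensed, hypothetical form of the DoFnOutputTest added in this commit) must see all three values reach the output, not just the one written in processElement:

import com.google.cloud.dataflow.sdk.transforms.DoFn;

// Emits a value from every lifecycle method. Before this fix the Spark runner
// cleared the bundle's output buffer before processing each element (dropping
// the startBundle() value) and never re-read the buffer after finishBundle(),
// so the output contained only "a" instead of "start", "a", "finish".
DoFn<String, String> fn = new DoFn<String, String>() {
  @Override
  public void startBundle(Context c) throws Exception {
    c.output("start");
  }
  @Override
  public void processElement(ProcessContext c) throws Exception {
    c.output(c.element());
  }
  @Override
  public void finishBundle(Context c) throws Exception {
    c.output("finish");
  }
};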
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/76815589
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/76815589
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/76815589

Branch: refs/heads/master
Commit: 76815589f5d4b96868b8438f1820d17e0a5822ab
Parents: 27349ad
Author: Tom White <t...@cloudera.com>
Authored: Tue Jul 14 16:44:15 2015 +0100
Committer: Tom White <t...@cloudera.com>
Committed: Thu Mar 10 11:15:15 2016 +0000

----------------------------------------------------------------------
 .../cloudera/dataflow/spark/DoFnFunction.java   |  2 +-
 .../dataflow/spark/SparkProcessContext.java     | 22 +++++---
 .../cloudera/dataflow/spark/DoFnOutputTest.java | 57 ++++++++++++++++++++
 3 files changed, 73 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76815589/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
index ae3dd79..542f2ec 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/DoFnFunction.java
@@ -51,8 +51,8 @@ class DoFnFunction<I, O> implements FlatMapFunction<Iterator<I>, O> {
   @Override
   public Iterable<O> call(Iterator<I> iter) throws Exception {
     ProcCtxt ctxt = new ProcCtxt(mFunction, mRuntimeContext, mSideInputs);
-    mFunction.startBundle(ctxt);
     ctxt.setup();
+    mFunction.startBundle(ctxt);
     return ctxt.getOutputIterable(iter, mFunction);
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76815589/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkProcessContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkProcessContext.java
index d0e9d6a..bda838c 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkProcessContext.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkProcessContext.java
@@ -211,10 +211,12 @@ abstract class SparkProcessContext<I, O, V> extends DoFn<I, O>.ProcessContext {
     private final Iterator<I> inputIterator;
     private final DoFn<I, O> doFn;
     private Iterator<V> outputIterator;
+    private boolean calledFinish = false;

     public ProcCtxtIterator(Iterator<I> iterator, DoFn<I, O> doFn) {
       this.inputIterator = iterator;
       this.doFn = doFn;
+      this.outputIterator = getOutputIterator();
     }

     @Override
@@ -225,10 +227,9 @@ abstract class SparkProcessContext<I, O, V> extends DoFn<I, O>.ProcessContext {
       // collection only holds the output values for each call to processElement, rather
       // than for the whole partition (which would use too much memory).
       while (true) {
-        if (outputIterator != null && outputIterator.hasNext()) {
+        if (outputIterator.hasNext()) {
           return outputIterator.next();
-        }
-        if (inputIterator.hasNext()) {
+        } else if (inputIterator.hasNext()) {
           clearOutput();
           element = inputIterator.next();
           try {
@@ -239,10 +240,17 @@ abstract class SparkProcessContext<I, O, V> extends DoFn<I, O>.ProcessContext {
           outputIterator = getOutputIterator();
           continue; // try to consume outputIterator from start of loop
         } else {
-          try {
-            doFn.finishBundle(SparkProcessContext.this);
-          } catch (Exception e) {
-            throw new IllegalStateException(e);
+          // no more input to consume, but finishBundle can produce more output
+          if (!calledFinish) {
+            clearOutput();
+            try {
+              calledFinish = true;
+              doFn.finishBundle(SparkProcessContext.this);
+            } catch (Exception e) {
+              throw new IllegalStateException(e);
+            }
+            outputIterator = getOutputIterator();
+            continue; // try to consume outputIterator from start of loop
           }
           return endOfData();
         }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76815589/runners/spark/src/test/java/com/cloudera/dataflow/spark/DoFnOutputTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/com/cloudera/dataflow/spark/DoFnOutputTest.java b/runners/spark/src/test/java/com/cloudera/dataflow/spark/DoFnOutputTest.java
new file mode 100644
index 0000000..2b0947f
--- /dev/null
+++ b/runners/spark/src/test/java/com/cloudera/dataflow/spark/DoFnOutputTest.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
+ *
+ * Cloudera, Inc. licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"). You may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for
+ * the specific language governing permissions and limitations under the
+ * License.
+ */
+
+package com.cloudera.dataflow.spark;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+
+import java.io.Serializable;
+
+import org.junit.Test;
+
+public class DoFnOutputTest implements Serializable {
+  @Test
+  public void test() throws Exception {
+    SparkPipelineOptions options = SparkPipelineOptionsFactory.create();
+    options.setRunner(SparkPipelineRunner.class);
+    Pipeline pipeline = Pipeline.create(options);
+
+    PCollection<String> strings = pipeline.apply(Create.of("a"));
+
+    // Test that values written from startBundle() and finishBundle() are written to
+    // the output
+    PCollection<String> output = strings.apply(ParDo.of(new DoFn<String, String>() {
+      @Override
+      public void startBundle(Context c) throws Exception {
+        c.output("start");
+      }
+      @Override
+      public void processElement(ProcessContext c) throws Exception {
+        c.output(c.element());
+      }
+      @Override
+      public void finishBundle(Context c) throws Exception {
+        c.output("finish");
+      }
+    }));
+
+    DataflowAssert.that(output).containsInAnyOrder("start", "a", "finish");
+
+    EvaluationResult res = SparkPipelineRunner.create().run(pipeline);
+    res.close();
+  }
+}
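----------------------------------------------------------------------

The corrected ProcCtxtIterator above is a lazy state machine in the style of Guava's AbstractIterator: drain the current output buffer, otherwise process the next input element, otherwise run finishBundle() exactly once, then signal end of data. Below is a self-contained sketch of the same control flow; BundleIterator and the functional stand-ins for the DoFn callbacks are hypothetical names, not the runner's actual API:

import com.google.common.collect.AbstractIterator;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

// Lazily maps an input iterator through a per-element "process" step and a
// one-time "finish" step; the buffer only ever holds the output of a single
// callback, not of the whole partition.
final class BundleIterator<I, O> extends AbstractIterator<O> {
  private final Iterator<I> input;
  private final BiConsumer<I, Consumer<O>> processElement; // stand-in for DoFn#processElement
  private final Consumer<Consumer<O>> finishBundle;        // stand-in for DoFn#finishBundle
  private final List<O> buffer = new ArrayList<>();
  private Iterator<O> output;
  private boolean calledFinish = false;

  BundleIterator(Iterator<I> input,
                 BiConsumer<I, Consumer<O>> processElement,
                 Consumer<Consumer<O>> finishBundle) {
    this.input = input;
    this.processElement = processElement;
    this.finishBundle = finishBundle;
    this.output = buffer.iterator(); // initialized eagerly, as the fix now does
  }

  @Override
  protected O computeNext() {
    while (true) {
      if (output.hasNext()) {
        return output.next();                    // drain current output first
      } else if (input.hasNext()) {
        buffer.clear();                          // clearOutput()
        processElement.accept(input.next(), buffer::add);
        output = buffer.iterator();              // getOutputIterator()
      } else if (!calledFinish) {
        buffer.clear();
        calledFinish = true;                     // guard: finish exactly once
        finishBundle.accept(buffer::add);        // finishBundle may emit more output
        output = buffer.iterator();
      } else {
        return endOfData();                      // input and finish output both exhausted
      }
    }
  }
}

The calledFinish flag is the heart of the fix: without it the loop either ended before surfacing finishBundle()'s buffered values (the old behaviour) or would invoke finishBundle() again on every subsequent advance.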