johnyangk closed pull request #63: [NEMO-124] Support DoFn#output(tag, output)
URL: https://github.com/apache/incubator-nemo/pull/63
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff once a PR is merged, the diff is
displayed below for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/common/src/main/java/edu/snu/nemo/common/ContextImpl.java 
b/common/src/main/java/edu/snu/nemo/common/ContextImpl.java
index 823090ca7..d1774d3b6 100644
--- a/common/src/main/java/edu/snu/nemo/common/ContextImpl.java
+++ b/common/src/main/java/edu/snu/nemo/common/ContextImpl.java
@@ -25,14 +25,16 @@
  */
 public final class ContextImpl implements Transform.Context {
   private final Map sideInputs;
+  private final Map additionalTagOutputs;
   private String data;
 
   /**
    * Constructor of Context Implementation.
    * @param sideInputs side inputs.
    */
-  public ContextImpl(final Map sideInputs) {
+  public ContextImpl(final Map sideInputs, final Map additionalTagOutputs) {
     this.sideInputs = sideInputs;
+    this.additionalTagOutputs = additionalTagOutputs;
     this.data = null;
   }
 
@@ -41,6 +43,11 @@ public Map getSideInputs() {
     return this.sideInputs;
   }
 
+  @Override
+  public Map getAdditionalTagOutputs() {
+    return this.additionalTagOutputs;
+  }
+
   @Override
   public void setSerializedData(final String serializedData) {
     this.data = serializedData;
diff --git a/common/src/main/java/edu/snu/nemo/common/dag/Edge.java 
b/common/src/main/java/edu/snu/nemo/common/dag/Edge.java
index 49c0df5fb..2e6982778 100644
--- a/common/src/main/java/edu/snu/nemo/common/dag/Edge.java
+++ b/common/src/main/java/edu/snu/nemo/common/dag/Edge.java
@@ -53,7 +53,7 @@ public final String getId() {
   }
 
   /**
-   * @return the numeric ID of the edge.
+   * @return the numeric ID of the edge. (for edge id "edge-2", this method 
returns 2)
    */
   public final Integer getNumericId() {
     return Integer.parseInt(id.replaceAll("[^\\d.]", ""));
diff --git a/common/src/main/java/edu/snu/nemo/common/ir/OutputCollector.java 
b/common/src/main/java/edu/snu/nemo/common/ir/OutputCollector.java
index 4b1a6a2cd..b833d221c 100644
--- a/common/src/main/java/edu/snu/nemo/common/ir/OutputCollector.java
+++ b/common/src/main/java/edu/snu/nemo/common/ir/OutputCollector.java
@@ -37,5 +37,5 @@
    * @param dstVertexId destination vertex id.
    * @param output value.
    */
-  void emit(String dstVertexId, Object output);
+  <T> void emit(String dstVertexId, T output);
 }
diff --git 
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/AdditionalTagOutputProperty.java
 
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/AdditionalTagOutputProperty.java
new file mode 100644
index 000000000..dd99100d1
--- /dev/null
+++ 
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/AdditionalTagOutputProperty.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.common.ir.vertex.executionproperty;
+
+import edu.snu.nemo.common.ir.executionproperty.VertexExecutionProperty;
+
+import java.util.HashMap;
+
+/**
+ * AdditionalTagOutput execution property for a vertex that emits additional tagged
outputs.
+ */
+public final class AdditionalTagOutputProperty extends 
VertexExecutionProperty<HashMap<String, String>> {
+  /**
+   * Private constructor; use {@link #of(HashMap)} instead.
+   * @param value map from additional output tag id to the id of the IRVertex consuming that output.
+   */
+  private AdditionalTagOutputProperty(final HashMap<String, String> value) {
+    super(value);
+  }
+
+  /**
+   * Creates a new AdditionalTagOutputProperty.
+   * @param value map from additional output tag id to the id of the IRVertex consuming that output.
+   * @return the newly created execution property.
+   */
+  public static AdditionalTagOutputProperty of(final HashMap<String, String> 
value) {
+    return new AdditionalTagOutputProperty(value);
+  }
+}
diff --git 
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java 
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
index 448f76a55..db1927e63 100644
--- 
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
+++ 
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
@@ -61,6 +61,7 @@ default Object getTag() {
      * @return sideInputs.
      */
     Map getSideInputs();
+    Map getAdditionalTagOutputs();
 
     /**
      * Put serialized data to send to the executor.
diff --git a/common/src/test/java/edu/snu/nemo/common/ContextImplTest.java 
b/common/src/test/java/edu/snu/nemo/common/ContextImplTest.java
index 5a89d78a6..b98489e3c 100644
--- a/common/src/test/java/edu/snu/nemo/common/ContextImplTest.java
+++ b/common/src/test/java/edu/snu/nemo/common/ContextImplTest.java
@@ -33,16 +33,18 @@
 public class ContextImplTest {
   private Transform.Context context;
   private final Map sideInputs = new HashMap();
+  private final Map taggedOutputs = new HashMap();
 
   @Before
   public void setUp() {
     sideInputs.put("a", "b");
-    this.context = new ContextImpl(sideInputs);
+    this.context = new ContextImpl(sideInputs, taggedOutputs);
   }
 
   @Test
   public void testContextImpl() {
     assertEquals(this.sideInputs, this.context.getSideInputs());
+    assertEquals(this.taggedOutputs, this.context.getAdditionalTagOutputs());
 
     final String sampleText = "sample_text";
 
diff --git 
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
 
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
index b85057b5b..39f14bbbf 100644
--- 
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
+++ 
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
@@ -18,6 +18,7 @@
 import edu.snu.nemo.common.Pair;
 import edu.snu.nemo.common.ir.edge.executionproperty.DecoderProperty;
 import edu.snu.nemo.common.ir.edge.executionproperty.EncoderProperty;
+import 
edu.snu.nemo.common.ir.vertex.executionproperty.AdditionalTagOutputProperty;
 import edu.snu.nemo.common.ir.vertex.transform.Transform;
 import edu.snu.nemo.compiler.frontend.beam.coder.BeamDecoderFactory;
 import edu.snu.nemo.compiler.frontend.beam.coder.BeamEncoderFactory;
@@ -43,6 +44,7 @@
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PCollectionViews;
 import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
 
 import java.util.HashMap;
 import java.util.List;
@@ -59,6 +61,7 @@
   // loopVertexStack keeps track of where the beam program is: whether it is 
inside a composite transform or it is not.
   private final Stack<LoopVertex> loopVertexStack;
   private final Map<PValue, Pair<BeamEncoderFactory, BeamDecoderFactory>> 
pValueToCoder;
+  private final Map<PValue, TupleTag> pValueToTag;
 
   /**
    * Constructor of the BEAM Visitor.
@@ -72,6 +75,7 @@ public NemoPipelineVisitor(final DAGBuilder<IRVertex, IREdge> 
builder, final Pip
     this.options = options;
     this.loopVertexStack = new Stack<>();
     this.pValueToCoder = new HashMap<>();
+    this.pValueToTag = new HashMap<>();
   }
 
   @Override
@@ -96,12 +100,8 @@ public void leaveCompositeTransform(final 
TransformHierarchy.Node beamNode) {
   public void visitPrimitiveTransform(final TransformHierarchy.Node beamNode) {
 //    Print if needed for development
 //    LOG.info("visitp " + beamNode.getTransform());
-    if (beamNode.getOutputs().size() > 1) {
-      throw new UnsupportedOperationException(beamNode.toString());
-    }
-
     final IRVertex irVertex =
-        convertToVertex(beamNode, builder, pValueToVertex, pValueToCoder, 
options, loopVertexStack);
+        convertToVertex(beamNode, builder, pValueToVertex, pValueToCoder, 
pValueToTag, options, loopVertexStack);
     beamNode.getOutputs().values().stream().filter(v -> v instanceof 
PCollection).map(v -> (PCollection) v)
         .forEach(output -> pValueToCoder.put(output,
             Pair.of(new BeamEncoderFactory(output.getCoder()), new 
BeamDecoderFactory(output.getCoder()))));
@@ -118,6 +118,22 @@ public void visitPrimitiveTransform(final 
TransformHierarchy.Node beamNode) {
           edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
           this.builder.connectVertices(edge);
         });
+
+    // This exclusively updates execution property of vertices with additional 
tagged outputs.
+    beamNode.getInputs().values().stream().filter(pValueToTag::containsKey)
+        .forEach(pValue -> {
+          final IRVertex src = pValueToVertex.get(pValue);
+          final TupleTag tag = pValueToTag.get(pValue);
+          final HashMap<String, String> tagToVertex = new HashMap<>();
+          tagToVertex.put(tag.getId(), irVertex.getId());
+          if 
(!src.getPropertyValue(AdditionalTagOutputProperty.class).isPresent()) {
+            src.setProperty(AdditionalTagOutputProperty.of(tagToVertex));
+          } else {
+            final HashMap<String, String> prev = 
src.getPropertyValue(AdditionalTagOutputProperty.class).get();
+            prev.putAll(tagToVertex);
+            src.setProperty(AdditionalTagOutputProperty.of(prev));
+          }
+        });
   }
 
   /**
@@ -127,6 +143,7 @@ public void visitPrimitiveTransform(final 
TransformHierarchy.Node beamNode) {
    * @param builder         the DAG builder to add the vertex to.
    * @param pValueToVertex  PValue to Vertex map.
    * @param pValueToCoder   PValue to EncoderFactory and DecoderFactory map.
+   * @param pValueToTag     PValue to Tag map.
    * @param options         pipeline options.
    * @param loopVertexStack Stack to get the current loop vertex that the 
operator vertex will be assigned to.
    * @param <I>             input type.
@@ -138,6 +155,7 @@ public void visitPrimitiveTransform(final 
TransformHierarchy.Node beamNode) {
                   final DAGBuilder<IRVertex, IREdge> builder,
                   final Map<PValue, IRVertex> pValueToVertex,
                   final Map<PValue, Pair<BeamEncoderFactory, 
BeamDecoderFactory>> pValueToCoder,
+                  final Map<PValue, TupleTag> pValueToTag,
                   final PipelineOptions options,
                   final Stack<LoopVertex> loopVertexStack) {
     final PTransform beamTransform = beamNode.getTransform();
@@ -184,6 +202,11 @@ public void visitPrimitiveTransform(final 
TransformHierarchy.Node beamNode) {
       final ParDo.MultiOutput<I, O> parDo = (ParDo.MultiOutput<I, O>) 
beamTransform;
       final DoTransform transform = new DoTransform(parDo.getFn(), options);
       irVertex = new OperatorVertex(transform);
+      if (parDo.getAdditionalOutputTags().size() > 0) {
+        beamNode.getOutputs().entrySet().stream()
+            .filter(kv -> !kv.getKey().equals(parDo.getMainOutputTag()))
+            .forEach(kv -> pValueToTag.put(kv.getValue(), kv.getKey()));
+      }
       builder.addVertex(irVertex, loopVertexStack);
       connectSideInputs(builder, parDo.getSideInputs(), pValueToVertex, 
pValueToCoder, irVertex);
     } else if (beamTransform instanceof Flatten.PCollections) {
diff --git 
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java
 
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java
index b56db4e04..9d7636793 100644
--- 
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java
+++ 
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java
@@ -72,7 +72,8 @@ public void prepare(final Context context, final 
OutputCollector<O> oc) {
     this.outputCollector = oc;
     this.startBundleContext = new StartBundleContext(doFn, serializedOptions);
     this.finishBundleContext = new FinishBundleContext(doFn, outputCollector, 
serializedOptions);
-    this.processContext = new ProcessContext(doFn, outputCollector, 
context.getSideInputs(), serializedOptions);
+    this.processContext = new ProcessContext(doFn, outputCollector,
+        context.getSideInputs(), context.getAdditionalTagOutputs(), 
serializedOptions);
     this.invoker = DoFnInvokers.invokerFor(doFn);
     invoker.invokeSetup();
     invoker.invokeStartBundle(startBundleContext);
@@ -192,6 +193,7 @@ public void output(final O output, final Instant instant, 
final BoundedWindow bo
     private I input;
     private final OutputCollector<O> outputCollector;
     private final Map sideInputs;
+    private final Map additionalOutputs;
     private final ObjectMapper mapper;
     private final PipelineOptions options;
 
@@ -201,15 +203,18 @@ public void output(final O output, final Instant instant, 
final BoundedWindow bo
      * @param fn                Dofn.
      * @param outputCollector   OutputCollector.
      * @param sideInputs        Map for SideInputs.
+     * @param additionalOutputs     Map for TaggedOutputs.
      * @param serializedOptions Options, serialized.
      */
     ProcessContext(final DoFn<I, O> fn,
                    final OutputCollector<O> outputCollector,
                    final Map sideInputs,
+                   final Map additionalOutputs,
                    final String serializedOptions) {
       fn.super();
       this.outputCollector = outputCollector;
       this.sideInputs = sideInputs;
+      this.additionalOutputs = additionalOutputs;
       this.mapper = new ObjectMapper();
       try {
         this.options = mapper.readValue(serializedOptions, 
PipelineOptions.class);
@@ -269,7 +274,7 @@ public void outputWithTimestamp(final O output, final 
Instant timestamp) {
 
     @Override
     public <T> void output(final TupleTag<T> tupleTag, final T t) {
-      throw new UnsupportedOperationException("output(TupleTag, T) in 
ProcessContext under DoTransform");
+      outputCollector.emit((String) additionalOutputs.get(tupleTag.getId()), 
t);
     }
 
     @Override
diff --git 
a/examples/beam/src/main/java/edu/snu/nemo/examples/beam/PartitionWordsByLength.java
 
b/examples/beam/src/main/java/edu/snu/nemo/examples/beam/PartitionWordsByLength.java
new file mode 100644
index 000000000..5613e3ca6
--- /dev/null
+++ 
b/examples/beam/src/main/java/edu/snu/nemo/examples/beam/PartitionWordsByLength.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.examples.beam;
+
+import edu.snu.nemo.compiler.frontend.beam.NemoPipelineOptions;
+import edu.snu.nemo.compiler.frontend.beam.NemoPipelineRunner;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.FlatMapElements;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.*;
+
+import java.util.Arrays;
+
+/**
+ * Example that splits the words of the input into three tagged outputs by word length.
+ */
+public final class PartitionWordsByLength {
+  /**
+   * Private Constructor.
+   */
+  private PartitionWordsByLength() {
+  }
+
+  /**
+   * Main function for the PartitionWordsByLength BEAM program.
+   *
+   * @param args arguments: the input file path, then the output file path prefix.
+   */
+  public static void main(final String[] args) {
+    final String inputFilePath = args[0];
+    final String outputFilePath = args[1];
+    final PipelineOptions options = 
PipelineOptionsFactory.create().as(NemoPipelineOptions.class);
+    options.setRunner(NemoPipelineRunner.class);
+    options.setJobName("PartitionWordsByLength");
+
+    // {} here is required for preserving type information.
+    // Please see https://stackoverflow.com/a/48431397 for details.
+    final TupleTag<String> shortWordsTag = new TupleTag<String>() {
+    };
+    final TupleTag<Integer> longWordsTag = new TupleTag<Integer>() {
+    };
+    final TupleTag<String> veryLongWordsTag = new TupleTag<String>() {
+    };
+
+    final Pipeline p = Pipeline.create(options);
+    final PCollection<String> lines = GenericSourceSink.read(p, inputFilePath);
+
+    PCollectionTuple results = lines
+        .apply(FlatMapElements
+            .into(TypeDescriptors.strings())
+            .via(line -> Arrays.asList(line.split(" "))))
+        .apply(ParDo.of(new DoFn<String, String>() {
+          @ProcessElement
+          public void processElement(final ProcessContext c) {
+            String word = c.element();
+            if (word.length() < 5) {
+              c.output(shortWordsTag, word);
+            } else if (word.length() < 8) {
+              c.output(longWordsTag, word.length());
+            } else {
+              c.output(veryLongWordsTag, word);
+            }
+          }
+        }).withOutputTags(veryLongWordsTag, TupleTagList
+            .of(longWordsTag)
+            .and(shortWordsTag)));
+
+    PCollection<String> shortWords = results.get(shortWordsTag);
+    PCollection<String> longWordLengths = results
+        .get(longWordsTag)
+        .apply(MapElements.into(TypeDescriptors.strings()).via(i -> 
Integer.toString(i)));
+    PCollection<String> veryLongWords = results.get(veryLongWordsTag);
+
+    GenericSourceSink.write(shortWords, outputFilePath + "_short");
+    GenericSourceSink.write(longWordLengths, outputFilePath + "_long");
+    GenericSourceSink.write(veryLongWords, outputFilePath + "_very_long");
+    p.run();
+  }
+}
diff --git 
a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PartitionWordsByLengthITCase.java
 
b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PartitionWordsByLengthITCase.java
new file mode 100644
index 000000000..31c40ef90
--- /dev/null
+++ 
b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PartitionWordsByLengthITCase.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.examples.beam;
+
+import edu.snu.nemo.client.JobLauncher;
+import edu.snu.nemo.common.test.ArgBuilder;
+import edu.snu.nemo.common.test.ExampleTestUtil;
+import edu.snu.nemo.examples.beam.policy.DefaultPolicyParallelismFive;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+/**
+ * Test PartitionWordsByLength program with JobLauncher.
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest(JobLauncher.class)
+public final class PartitionWordsByLengthITCase {
+  private static final int TIMEOUT = 120000;
+  private static ArgBuilder builder;
+  private static final String fileBasePath = System.getProperty("user.dir") + 
"/../resources/";
+
+  private static final String inputFileName = "sample_input_tag";
+  private static final String outputFileName = "sample_output_tag";
+  private static final String testResourceFileName = "test_output_tag";
+  private static final String executorResourceFileName = fileBasePath + 
"beam_sample_executor_resources.json";
+  private static final String inputFilePath =  fileBasePath + inputFileName;
+  private static final String outputFilePath =  fileBasePath + outputFileName;
+
+  @Before
+  public void setUp() throws Exception {
+    builder = new ArgBuilder()
+      .addResourceJson(executorResourceFileName)
+      .addUserMain(PartitionWordsByLength.class.getCanonicalName())
+      .addUserArgs(inputFilePath, outputFilePath);
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    try {
+      ExampleTestUtil.ensureOutputValidity(fileBasePath, outputFileName + 
"_short", testResourceFileName + "_short");
+      ExampleTestUtil.ensureOutputValidity(fileBasePath, outputFileName + 
"_long", testResourceFileName + "_long");
+      ExampleTestUtil.ensureOutputValidity(fileBasePath, outputFileName + 
"_very_long", testResourceFileName + "_very_long");
+    } finally {
+      ExampleTestUtil.deleteOutputFile(fileBasePath, outputFileName); // NOTE(review): outputs carry _short/_long/_very_long suffixes; confirm this deletes by prefix
+    }
+  }
+
+  @Test (timeout = TIMEOUT)
+  public void test() throws Exception {
+    JobLauncher.main(builder
+      .addJobId(PartitionWordsByLength.class.getSimpleName())
+      
.addOptimizationPolicy(DefaultPolicyParallelismFive.class.getCanonicalName())
+      .build());
+  }
+}
diff --git a/examples/resources/sample_input_tag 
b/examples/resources/sample_input_tag
new file mode 100644
index 000000000..0cd417beb
--- /dev/null
+++ b/examples/resources/sample_input_tag
@@ -0,0 +1,7 @@
+foo
+bar
+foobar
+barbaz
+foobarbaz
+ipsumlorem
+qux
\ No newline at end of file
diff --git a/examples/resources/test_output_tag_long 
b/examples/resources/test_output_tag_long
new file mode 100644
index 000000000..91dea2c76
--- /dev/null
+++ b/examples/resources/test_output_tag_long
@@ -0,0 +1,2 @@
+6
+6
diff --git a/examples/resources/test_output_tag_short 
b/examples/resources/test_output_tag_short
new file mode 100644
index 000000000..72594ed96
--- /dev/null
+++ b/examples/resources/test_output_tag_short
@@ -0,0 +1,3 @@
+foo
+bar
+qux
diff --git a/examples/resources/test_output_tag_very_long 
b/examples/resources/test_output_tag_very_long
new file mode 100644
index 000000000..22a28156a
--- /dev/null
+++ b/examples/resources/test_output_tag_very_long
@@ -0,0 +1,2 @@
+foobarbaz
+ipsumlorem
diff --git 
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
 
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
index b9b979596..58c3fc6f7 100644
--- 
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
+++ 
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
@@ -25,6 +25,7 @@
 import 
edu.snu.nemo.common.ir.vertex.executionproperty.DynamicOptimizationProperty;
 import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
 import edu.snu.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
+import 
edu.snu.nemo.common.ir.vertex.executionproperty.AdditionalTagOutputProperty;
 import edu.snu.nemo.conf.JobConf;
 import edu.snu.nemo.common.dag.DAG;
 import edu.snu.nemo.common.dag.DAGBuilder;
@@ -58,6 +59,7 @@ private PhysicalPlanGenerator(final StagePartitioner 
stagePartitioner,
     this.dagDirectory = dagDirectory;
     this.stagePartitioner = stagePartitioner;
     stagePartitioner.addIgnoredPropertyKey(DynamicOptimizationProperty.class);
+    stagePartitioner.addIgnoredPropertyKey(AdditionalTagOutputProperty.class);
   }
 
   /**
diff --git 
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
 
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
index 84e52ad80..32b9352ff 100644
--- 
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
+++ 
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
@@ -18,6 +18,9 @@
 import edu.snu.nemo.common.ir.OutputCollector;
 
 import java.util.ArrayDeque;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.List;
 import java.util.Queue;
 
 /**
@@ -26,23 +29,41 @@
  * @param <O> output type.
  */
 public final class OutputCollectorImpl<O> implements OutputCollector<O> {
-  private final Queue<O> outputQueue;
+  private final Queue<O> mainTagOutputQueue;
+  private final Map<String, Queue<Object>> additionalTagOutputQueues;
 
   /**
    * Constructor of a new OutputCollectorImpl.
    */
   public OutputCollectorImpl() {
-    this.outputQueue = new ArrayDeque<>(1);
+    this.mainTagOutputQueue = new ArrayDeque<>(1);
+    this.additionalTagOutputQueues = new HashMap<>();
+  }
+
+  /**
+   * Constructor of a new OutputCollectorImpl with tagged outputs.
+   * @param taggedChildren tagged children
+   */
+  public OutputCollectorImpl(final List<String> taggedChildren) {
+    this.mainTagOutputQueue = new ArrayDeque<>(1);
+    this.additionalTagOutputQueues = new HashMap<>();
+    taggedChildren.forEach(child -> this.additionalTagOutputQueues.put(child, 
new ArrayDeque<>(1)));
   }
 
   @Override
   public void emit(final O output) {
-    outputQueue.add(output);
+    mainTagOutputQueue.add(output);
   }
 
   @Override
-  public void emit(final String dstVertexId, final Object output) {
-    throw new UnsupportedOperationException("emit(dstVertexId, output) in 
OutputCollectorImpl.");
+  public <T> void emit(final String dstVertexId, final T output) {
+    if (this.additionalTagOutputQueues.get(dstVertexId) == null) {
+      // Not a registered additional-output destination (including null): route to the main output queue
+      emit((O) output); // unchecked cast: relies on callers emitting an O for the main output
+    } else {
+      // Note that String#hashCode() can be cached, thus accessing additional 
output queues can be fast.
+      this.additionalTagOutputQueues.get(dstVertexId).add(output);
+    }
   }
 
   /**
@@ -52,7 +73,25 @@ public void emit(final String dstVertexId, final Object 
output) {
    * @return the first element of this list
    */
   public O remove() {
-    return outputQueue.remove();
+    return mainTagOutputQueue.remove();
+  }
+
+  /**
+   * Removes and returns the first element queued under the given key.
+   * Falls back to the main output queue when no additional queue is registered for the key.
+   *
+   * @param tag key of an additional output queue (the destination vertex id, as passed to emit)
+   * @return the first element of the corresponding queue (Queue#remove throws if it is empty)
+   */
+  public Object remove(final String tag) {
+    if (this.additionalTagOutputQueues.get(tag) == null) {
+      // No additional queue registered for this key: it refers to the main output
+      return remove();
+    } else {
+      // Note that String#hashCode() can be cached, thus accessing additional 
output queues can be fast.
+      return this.additionalTagOutputQueues.get(tag).remove();
+    }
+
   }
 
   /**
@@ -61,6 +100,21 @@ public O remove() {
    * @return true if this OutputCollector is empty.
    */
   public boolean isEmpty() {
-    return outputQueue.isEmpty();
+    return mainTagOutputQueue.isEmpty();
+  }
+
+  /**
+   * Check if the output queue registered under the given key is empty.
+   *
+   * @param tag key of an additional output queue (destination vertex id); unknown keys refer to the main queue
+   * @return true if the corresponding queue is empty.
+   */
+  public boolean isEmpty(final String tag) {
+    if (this.additionalTagOutputQueues.get(tag) == null) {
+      return isEmpty();
+    } else {
+      // Note that String#hashCode() can be cached, thus accessing additional 
output queues can be fast.
+      return this.additionalTagOutputQueues.get(tag).isEmpty();
+    }
   }
 }
diff --git 
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
 
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
index c320785a0..78974b458 100644
--- 
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
+++ 
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
@@ -21,6 +21,7 @@
 import edu.snu.nemo.common.dag.DAG;
 import edu.snu.nemo.common.ir.Readable;
 import edu.snu.nemo.common.ir.vertex.*;
+import 
edu.snu.nemo.common.ir.vertex.executionproperty.AdditionalTagOutputProperty;
 import edu.snu.nemo.common.ir.vertex.transform.Transform;
 import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
 import edu.snu.nemo.runtime.common.comm.ControlMessage;
@@ -151,11 +152,21 @@ public TaskExecutor(final Task task,
           .map(RuntimeEdge::isSideInput)
           .collect(Collectors.toList());
 
+      final Map<String, String> additionalOutputMap = irVertex
+          .getPropertyValue(AdditionalTagOutputProperty.class).orElse(new 
HashMap<>());
+      final List<Boolean> isToAdditionalTagOutputs = children.stream()
+          .map(harness -> harness.getIRVertex().getId())
+          .map(additionalOutputMap::containsValue)
+          .collect(Collectors.toList());
+
       // Handle writes
       final List<OutputWriter> childrenTaskWriters = getChildrenTaskWriters(
           taskIndex, irVertex, task.getTaskOutgoingEdges(), 
dataTransferFactory); // Children-task write
-      final VertexHarness vertexHarness = new VertexHarness(irVertex, new 
OutputCollectorImpl(), children,
-          isToSideInputs, childrenTaskWriters, new ContextImpl(sideInputMap)); 
// Intra-vertex write
+      final List<String> additionalOutputVertices = new 
ArrayList<>(additionalOutputMap.values());
+      final OutputCollectorImpl oci = new 
OutputCollectorImpl(additionalOutputVertices);
+      final VertexHarness vertexHarness = new VertexHarness(irVertex, oci, 
children,
+          isToSideInputs, isToAdditionalTagOutputs,
+          childrenTaskWriters, new ContextImpl(sideInputMap, 
additionalOutputMap)); // Intra-vertex write
       prepareTransform(vertexHarness);
       vertexIdToHarness.put(irVertex.getId(), vertexHarness);
 
@@ -202,11 +213,19 @@ private void processElementRecursively(final 
VertexHarness vertexHarness, final
     }
 
     // Given a single input element, a vertex can produce many output elements.
-    // Here, we recursively process all of the output elements.
+    // Here, we recursively process all of the main output elements.
     while (!outputCollector.isEmpty()) {
       final Object element = outputCollector.remove();
-      handleOutputElement(vertexHarness, element); // Recursion
+      handleMainOutputElement(vertexHarness, element); // Recursion
     }
+
+    // Recursively process all of the additional output elements.
+    vertexHarness.getAdditionalTagOutputChildren().keySet().forEach(tag -> {
+      while (!outputCollector.isEmpty(tag)) {
+        final Object element = outputCollector.remove(tag);
+        handleAdditionalOutputElement(vertexHarness, element, tag); // 
Recursion
+      }
+    });
   }
 
   /**
@@ -285,6 +304,8 @@ private void doExecute() {
         .flatMap(child -> 
getAllReachables(child).stream()).collect(Collectors.toList()));
     result.addAll(src.getSideInputChildren().stream()
         .flatMap(child -> 
getAllReachables(child).stream()).collect(Collectors.toList()));
+    result.addAll(src.getAdditionalTagOutputChildren().values().stream()
+        .flatMap(child -> 
getAllReachables(child).stream()).collect(Collectors.toList()));
     return result;
   }
 
@@ -292,17 +313,27 @@ private void finalizeVertex(final VertexHarness 
vertexHarness) {
     closeTransform(vertexHarness);
     while (!vertexHarness.getOutputCollector().isEmpty()) {
       final Object element = vertexHarness.getOutputCollector().remove();
-      handleOutputElement(vertexHarness, element);
+      handleMainOutputElement(vertexHarness, element);
     }
     finalizeOutputWriters(vertexHarness);
   }
 
-  private void handleOutputElement(final VertexHarness vertexHarness, final 
Object element) {
-    vertexHarness.getWritersToChildrenTasks().forEach(outputWriter -> 
outputWriter.write(element));
-    if (vertexHarness.getSideInputChildren().size() > 0) {
-      sideInputMap.put(((OperatorVertex) 
vertexHarness.getIRVertex()).getTransform().getTag(), element);
+  private void handleMainOutputElement(final VertexHarness harness, final 
Object element) {
+    harness.getWritersToChildrenTasks().forEach(outputWriter -> 
outputWriter.write(element));
+    if (harness.getSideInputChildren().size() > 0) {
+      sideInputMap.put(((OperatorVertex) 
harness.getIRVertex()).getTransform().getTag(), element);
     }
-    vertexHarness.getNonSideInputChildren().forEach(child -> 
processElementRecursively(child, element));
+    harness.getNonSideInputChildren().forEach(child -> 
processElementRecursively(child, element));
+  }
+
+  private void handleAdditionalOutputElement(final VertexHarness harness, 
final Object element, final String tag) {
+    // Inter-task writes are currently not supported.
+    if (harness.getSideInputChildren().size() > 0) {
+      sideInputMap.put(((OperatorVertex) 
harness.getIRVertex()).getTransform().getTag(), element);
+    }
+    harness.getAdditionalTagOutputChildren().entrySet().stream()
+        .filter(kv -> kv.getKey().equals(tag))
+        .forEach(kv -> processElementRecursively(kv.getValue(), element));
   }
 
   /**
@@ -462,5 +493,4 @@ private void finalizeOutputWriters(final VertexHarness 
vertexHarness) {
     }
     metricCollector.endMeasurement(irVertex.getId(), metric);
   }
-
 }
diff --git 
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
 
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
index 2d915c446..c5f9a7850 100644
--- 
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
+++ 
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
@@ -21,7 +21,9 @@
 import edu.snu.nemo.runtime.executor.datatransfer.OutputWriter;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 /**
  * Captures the relationship between a non-source IRVertex's outputCollector, 
and children vertices.
@@ -35,31 +37,40 @@
   // These lists can be empty
   private final List<VertexHarness> sideInputChildren;
   private final List<VertexHarness> nonSideInputChildren;
+  private final Map<String, VertexHarness> additionalTagOutputChildren;
   private final List<OutputWriter> writersToChildrenTasks;
 
   VertexHarness(final IRVertex irVertex,
                 final OutputCollectorImpl outputCollector,
                 final List<VertexHarness> children,
                 final List<Boolean> isSideInputs,
+                final List<Boolean> isAdditionalTagOutputs,
                 final List<OutputWriter> writersToChildrenTasks,
                 final Transform.Context context) {
     this.irVertex = irVertex;
     this.outputCollector = outputCollector;
-    if (children.size() != isSideInputs.size()) {
+    if (children.size() != isSideInputs.size() || children.size() != 
isAdditionalTagOutputs.size()) {
       throw new IllegalStateException(irVertex.toString());
     }
+    final Map<String, String> taggedOutputMap = 
context.getAdditionalTagOutputs();
     final List<VertexHarness> sides = new ArrayList<>();
     final List<VertexHarness> nonSides = new ArrayList<>();
+    final Map<String, VertexHarness> tagged = new HashMap<>();
     for (int i = 0; i < children.size(); i++) {
       final VertexHarness child = children.get(i);
-      if (isSideInputs.get(0)) {
+      if (isSideInputs.get(i)) {
         sides.add(child);
+      } else if (isAdditionalTagOutputs.get(i)) {
+        taggedOutputMap.entrySet().stream()
+            .filter(kv -> child.getIRVertex().getId().equals(kv.getValue()))
+            .forEach(kv -> tagged.put(kv.getValue(), child));
       } else {
         nonSides.add(child);
       }
     }
     this.sideInputChildren = sides;
     this.nonSideInputChildren = nonSides;
+    this.additionalTagOutputChildren = tagged;
     this.writersToChildrenTasks = writersToChildrenTasks;
     this.context = context;
   }
@@ -92,6 +103,13 @@ OutputCollectorImpl getOutputCollector() {
     return sideInputChildren;
   }
 
+  /**
+   * @return map of tagged output children. (empty if none exists)
+   */
+  public Map<String, VertexHarness> getAdditionalTagOutputChildren() {
+    return additionalTagOutputChildren;
+  }
+
   /**
    * @return OutputWriters of this irVertex. (empty if none exists)
    */
diff --git 
a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
 
b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
index 96b755d27..b6b5bc4fd 100644
--- 
a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
+++ 
b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
@@ -24,6 +24,7 @@
 import edu.snu.nemo.common.ir.executionproperty.VertexExecutionProperty;
 import edu.snu.nemo.common.ir.vertex.InMemorySourceVertex;
 import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import 
edu.snu.nemo.common.ir.vertex.executionproperty.AdditionalTagOutputProperty;
 import edu.snu.nemo.common.ir.vertex.transform.Transform;
 import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
 import edu.snu.nemo.common.ir.vertex.IRVertex;
@@ -261,6 +262,67 @@ public void testTwoOperatorsWithSideInput() throws 
Exception {
     assertTrue(pairs.stream().map(Pair::left).allMatch(sideInput -> 
checkEqualElements(sideInput, values)));
   }
 
+  /**
+   * The DAG of the task to test looks like:
+   * parent vertex 1 --+-- vertex 2 (main tag)
+   *                   +-- vertex 3 (additional tag 1)
+   *                   +-- vertex 4 (additional tag 2)
+   *
+   * emit(element) and emit(dstVertexId, element) used together. emit(element) 
routes results to main output children,
+   * and emit(dstVertexId, element) routes results to corresponding additional 
output children.
+   */
+  @Test(timeout = 5000)
+  public void testAdditionalOutputs() throws Exception {
+    final IRVertex routerVertex = new OperatorVertex(new RoutingTransform());
+    final IRVertex mainVertex= new OperatorVertex(new RelayTransform());
+    final IRVertex bonusVertex1 = new OperatorVertex(new RelayTransform());
+    final IRVertex bonusVertex2 = new OperatorVertex(new RelayTransform());
+
+    // Tag to vertex map. Mock tags are used.
+    HashMap<String, String> tagToVertex = new HashMap<>();
+    tagToVertex.put("bonus1", bonusVertex1.getId());
+    tagToVertex.put("bonus2", bonusVertex2.getId());
+
+    routerVertex.setProperty(AdditionalTagOutputProperty.of(tagToVertex));
+
+    final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new 
DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>()
+        .addVertex(routerVertex)
+        .addVertex(mainVertex)
+        .addVertex(bonusVertex1)
+        .addVertex(bonusVertex2)
+        .connectVertices(createEdge(routerVertex, mainVertex, false, "edge-1"))
+        .connectVertices(createEdge(routerVertex, bonusVertex1, false, 
"edge-2"))
+        .connectVertices(createEdge(routerVertex, bonusVertex2, false, 
"edge-3"))
+        .buildWithoutSourceSinkCheck();
+
+    final Task task = new Task(
+        "testAdditionalOutputs",
+        generateTaskId(),
+        0,
+        TASK_EXECUTION_PROPERTY_MAP,
+        new byte[0],
+        Collections.singletonList(mockStageEdgeTo(routerVertex)),
+        Arrays.asList(mockStageEdgeFrom(mainVertex),
+            mockStageEdgeFrom(bonusVertex1),
+            mockStageEdgeFrom(bonusVertex2)),
+        Collections.emptyMap());
+
+    // Execute the task.
+    final TaskExecutor taskExecutor = new TaskExecutor(
+        task, taskDag, taskStateManager, dataTransferFactory, 
metricMessageSender, persistentConnectionToMasterMap);
+    taskExecutor.execute();
+
+    // Check the output.
+    final List<Integer> mainOutputs = 
vertexIdToOutputData.get(mainVertex.getId());
+    final List<Integer> bonusOutputs1 = 
vertexIdToOutputData.get(bonusVertex1.getId());
+    final List<Integer> bonusOutputs2 = 
vertexIdToOutputData.get(bonusVertex1.getId());
+    List<Integer> even = elements.stream().filter(i -> i % 2 == 
0).collect(Collectors.toList());
+    List<Integer> odd = elements.stream().filter(i -> i % 2 != 
0).collect(Collectors.toList());
+    assertTrue(checkEqualElements(even, mainOutputs));
+    assertTrue(checkEqualElements(odd, bonusOutputs1));
+    assertTrue(checkEqualElements(odd, bonusOutputs2));
+  }
+
   private RuntimeEdge<IRVertex> createEdge(final IRVertex src,
                                            final IRVertex dst,
                                            final boolean isSideInput) {
@@ -271,6 +333,16 @@ public void testTwoOperatorsWithSideInput() throws 
Exception {
 
   }
 
+  private RuntimeEdge<IRVertex> createEdge(final IRVertex src,
+                                           final IRVertex dst,
+                                           final boolean isSideInput,
+                                           final String runtimeIREdgeId) {
+    ExecutionPropertyMap edgeProperties = new 
ExecutionPropertyMap(runtimeIREdgeId);
+    
edgeProperties.put(InterTaskDataStoreProperty.of(InterTaskDataStoreProperty.Value.MemoryStore));
+    return new RuntimeEdge<>(runtimeIREdgeId, edgeProperties, src, dst, 
isSideInput);
+
+  }
+
   private StageEdge mockStageEdgeFrom(final IRVertex irVertex) {
     final StageEdge edge = mock(StageEdge.class);
     when(edge.getSrcIRVertex()).thenReturn(irVertex);
@@ -416,6 +488,37 @@ public void close() {
     }
   }
 
+  /**
+   * Simple conditional identity function for testing additional outputs.
+   */
+  private class RoutingTransform implements Transform<Integer, Integer> {
+    private OutputCollector<Integer> outputCollector;
+    private Map<String, String> tagToVertex;
+
+    @Override
+    public void prepare(final Context context, OutputCollector<Integer> 
outputCollector) {
+      this.outputCollector = outputCollector;
+      this.tagToVertex = context.getAdditionalTagOutputs();
+    }
+
+    @Override
+    public void onData(final Integer element) {
+      final int i = element;
+      if (i % 2 == 0) {
+        // route to all main outputs. Invoked if user calls c.output(element)
+        outputCollector.emit(i);
+      } else {
+        // route to all additional outputs. Invoked if user calls 
c.output(tupleTag, element)
+        tagToVertex.values().forEach(vertex -> outputCollector.emit(vertex, 
i));
+      }
+    }
+
+    @Override
+    public void close() {
+      // Do nothing.
+    }
+  }
+
   /**
    * Gets a list of integer pair elements in range.
    * @param start value of the range (inclusive).
diff --git 
a/tests/src/test/java/edu/snu/nemo/runtime/common/plan/StagePartitionerTest.java
 
b/tests/src/test/java/edu/snu/nemo/runtime/common/plan/StagePartitionerTest.java
index 515bd9512..bcdc96a25 100644
--- 
a/tests/src/test/java/edu/snu/nemo/runtime/common/plan/StagePartitionerTest.java
+++ 
b/tests/src/test/java/edu/snu/nemo/runtime/common/plan/StagePartitionerTest.java
@@ -25,6 +25,7 @@
 import 
edu.snu.nemo.common.ir.vertex.executionproperty.ExecutorPlacementProperty;
 import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
 import edu.snu.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
+import 
edu.snu.nemo.common.ir.vertex.executionproperty.AdditionalTagOutputProperty;
 import edu.snu.nemo.common.ir.vertex.transform.Transform;
 import edu.snu.nemo.common.test.EmptyComponents;
 import org.apache.reef.tang.Tang;
@@ -52,6 +53,7 @@
   public void setup() throws InjectionException {
     stagePartitioner = 
Tang.Factory.getTang().newInjector().getInstance(StagePartitioner.class);
     stagePartitioner.addIgnoredPropertyKey(DynamicOptimizationProperty.class);
+    stagePartitioner.addIgnoredPropertyKey(AdditionalTagOutputProperty.class);
   }
 
   /**


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to