Repository: incubator-gobblin
Updated Branches:
  refs/heads/master 1f6e46f87 -> 95eff37ce


[GOBBLIN-616] Add ability to fork jobs when concatenating Dags.

Closes #2483 from sv2000/dependents


Project: http://git-wip-us.apache.org/repos/asf/incubator-gobblin/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-gobblin/commit/95eff37c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-gobblin/tree/95eff37c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-gobblin/diff/95eff37c

Branch: refs/heads/master
Commit: 95eff37ce927ea070594909b58360fd4f16cc1cb
Parents: 1f6e46f
Author: suvasude <[email protected]>
Authored: Mon Oct 22 13:42:43 2018 -0700
Committer: Hung Tran <[email protected]>
Committed: Mon Oct 22 13:42:43 2018 -0700

----------------------------------------------------------------------
 .../configuration/ConfigurationKeys.java        |   1 +
 .../service/modules/flow/FlowGraphPath.java     |  47 +++++--
 .../gobblin/service/modules/flowgraph/Dag.java  |  56 ++++++--
 .../service/modules/flow/FlowGraphPathTest.java | 135 +++++++++++++++++++
 .../service/modules/flowgraph/DagTest.java      |  95 ++++++++-----
 5 files changed, 276 insertions(+), 58 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-gobblin/blob/95eff37c/gobblin-api/src/main/java/org/apache/gobblin/configuration/ConfigurationKeys.java
----------------------------------------------------------------------
diff --git a/gobblin-api/src/main/java/org/apache/gobblin/configuration/ConfigurationKeys.java b/gobblin-api/src/main/java/org/apache/gobblin/configuration/ConfigurationKeys.java
index 480796f..40d2f44 100644
--- a/gobblin-api/src/main/java/org/apache/gobblin/configuration/ConfigurationKeys.java
+++ b/gobblin-api/src/main/java/org/apache/gobblin/configuration/ConfigurationKeys.java
@@ -176,6 +176,7 @@ public class ConfigurationKeys {
   public static final String WORK_UNIT_ENABLE_TRACKING_LOGS = 
"workunit.enableTrackingLogs";
 
   public static final String JOB_DEPENDENCIES = "job.dependencies";
+  public static final String JOB_FORK_ON_CONCAT = "job.forkOnConcat";
   public static final String JOB_RUN_ONCE_KEY = "job.runonce";
   public static final String JOB_DISABLED_KEY = "job.disabled";
   public static final String JOB_JAR_FILES_KEY = "job.jars";

http://git-wip-us.apache.org/repos/asf/incubator-gobblin/blob/95eff37c/gobblin-service/src/main/java/org/apache/gobblin/service/modules/flow/FlowGraphPath.java
----------------------------------------------------------------------
diff --git a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/flow/FlowGraphPath.java b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/flow/FlowGraphPath.java
index 0c28c3b..02004b6 100644
--- a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/flow/FlowGraphPath.java
+++ b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/flow/FlowGraphPath.java
@@ -21,15 +21,17 @@ import java.net.URI;
 import java.net.URISyntaxException;
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 import org.apache.hadoop.fs.Path;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Joiner;
 import com.google.common.base.Optional;
-import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 import com.google.common.io.Files;
 import com.typesafe.config.Config;
@@ -45,6 +47,7 @@ import org.apache.gobblin.runtime.api.SpecExecutor;
 import org.apache.gobblin.runtime.api.SpecNotFoundException;
 import org.apache.gobblin.service.modules.dataset.DatasetDescriptor;
 import org.apache.gobblin.service.modules.flowgraph.Dag;
+import org.apache.gobblin.service.modules.flowgraph.Dag.DagNode;
 import org.apache.gobblin.service.modules.flowgraph.FlowEdge;
 import org.apache.gobblin.service.modules.spec.JobExecutionPlan;
 import org.apache.gobblin.service.modules.spec.JobExecutionPlanDagFactory;
@@ -95,22 +98,44 @@ public class FlowGraphPath {
    * @param dagRight The child dag.
    * @return The concatenated dag with modified {@link 
ConfigurationKeys#JOB_DEPENDENCIES}.
    */
-  private Dag<JobExecutionPlan> concatenate(Dag<JobExecutionPlan> dagLeft, 
Dag<JobExecutionPlan> dagRight) {
-    List<Dag.DagNode<JobExecutionPlan>> endNodes = dagLeft.getEndNodes();
-    List<Dag.DagNode<JobExecutionPlan>> startNodes = dagRight.getStartNodes();
+  @VisibleForTesting
+  static Dag<JobExecutionPlan> concatenate(Dag<JobExecutionPlan> dagLeft, 
Dag<JobExecutionPlan> dagRight) {
+    List<DagNode<JobExecutionPlan>> endNodes = dagLeft.getEndNodes();
+    List<DagNode<JobExecutionPlan>> startNodes = dagRight.getStartNodes();
     List<String> dependenciesList = Lists.newArrayList();
-    for (Dag.DagNode<JobExecutionPlan> dagNode: endNodes) {
-      
dependenciesList.add(dagNode.getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY));
+    //List of nodes with no dependents in the concatenated dag.
+    Set<DagNode<JobExecutionPlan>> forkNodes = new HashSet<>();
+
+    for (DagNode<JobExecutionPlan> dagNode: endNodes) {
+      if (isNodeForkable(dagNode)) {
+        //If node is forkable, then add its parents (if any) to the 
dependencies list.
+        forkNodes.add(dagNode);
+        List<DagNode<JobExecutionPlan>> parentNodes = 
dagLeft.getParents(dagNode);
+        for (DagNode<JobExecutionPlan> parentNode: parentNodes) {
+          
dependenciesList.add(parentNode.getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY));
+        }
+      } else {
+        
dependenciesList.add(dagNode.getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY));
+      }
     }
-    String dependencies = Joiner.on(",").join(dependenciesList);
 
-    for (Dag.DagNode<JobExecutionPlan> childNode: startNodes) {
-      JobSpec jobSpec = childNode.getValue().getJobSpec();
-      
jobSpec.setConfig(jobSpec.getConfig().withValue(ConfigurationKeys.JOB_DEPENDENCIES,
 ConfigValueFactory.fromAnyRef(dependencies)));
+    if (!dependenciesList.isEmpty()) {
+      String dependencies = Joiner.on(",").join(dependenciesList);
+
+      for (DagNode<JobExecutionPlan> childNode : startNodes) {
+        JobSpec jobSpec = childNode.getValue().getJobSpec();
+        
jobSpec.setConfig(jobSpec.getConfig().withValue(ConfigurationKeys.JOB_DEPENDENCIES,
 ConfigValueFactory.fromAnyRef(dependencies)));
+      }
     }
 
-    return dagLeft.concatenate(dagRight);
+    return dagLeft.concatenate(dagRight, forkNodes);
   }
+
+  private static boolean isNodeForkable(DagNode<JobExecutionPlan> dagNode) {
+    Config jobConfig = dagNode.getValue().getJobSpec().getConfig();
+    return ConfigUtils.getBoolean(jobConfig, 
ConfigurationKeys.JOB_FORK_ON_CONCAT, false);
+  }
+
   /**
    * Given an instance of {@link FlowEdge}, this method returns a {@link Dag < 
JobExecutionPlan >} that moves data
    * from the source of the {@link FlowEdge} to the destination of the {@link 
FlowEdge}.

http://git-wip-us.apache.org/repos/asf/incubator-gobblin/blob/95eff37c/gobblin-service/src/main/java/org/apache/gobblin/service/modules/flowgraph/Dag.java
----------------------------------------------------------------------
diff --git a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/flowgraph/Dag.java b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/flowgraph/Dag.java
index b41d1d0..0f5691e 100644
--- a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/flowgraph/Dag.java
+++ b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/flowgraph/Dag.java
@@ -20,8 +20,10 @@ package org.apache.gobblin.service.modules.flowgraph;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 import com.google.common.collect.Lists;
 
@@ -101,24 +103,57 @@ public class Dag<T> {
    * @return the concatenated dag
    */
   public Dag<T> concatenate(Dag<T> other) {
+    return concatenate(other, new HashSet<>());
+  }
+
+  /**
+   * Concatenate two dags together. Join the "other" dag to "this" dag and 
return "this" dag.
+   * The concatenate method ensures that all the jobs of "this" dag (which may 
have multiple end nodes)
+   * are completed before starting any job of the "other" dag. This is done by 
adding each endNode of this dag, which is
+   * not a fork node, as a parent of every startNode of the other dag.
+   *
+   * @param other dag to concatenate to this dag
+   * @param forkNodes a set of nodes from this dag which are marked as 
forkable nodes. Each of these nodes will be added
+   *                  to the list of end nodes of the concatenated dag. 
Essentially, a forkable node has no dependents
+   *                  in the concatenated dag.
+   * @return the concatenated dag
+   */
+  public Dag<T> concatenate(Dag<T> other, Set<DagNode<T>> forkNodes) {
     if (other == null || other.isEmpty()) {
       return this;
     }
     if (this.isEmpty()) {
       return other;
     }
+
     for (DagNode node : this.endNodes) {
-      this.parentChildMap.put(node, Lists.newArrayList());
-      for (DagNode otherNode : other.startNodes) {
-        this.parentChildMap.get(node).add(otherNode);
-        otherNode.addParentNode(node);
+      //Create a dependency for non-forkable nodes
+      if (!forkNodes.contains(node)) {
+        this.parentChildMap.put(node, Lists.newArrayList());
+        for (DagNode otherNode : other.startNodes) {
+          this.parentChildMap.get(node).add(otherNode);
+          otherNode.addParentNode(node);
+        }
+      } else {
+        for (DagNode otherNode: other.startNodes) {
+          List<DagNode<T>> parentNodes = this.getParents(node);
+          parentNodes.forEach(parentNode -> 
this.parentChildMap.get(parentNode).add(otherNode));
+          parentNodes.forEach(otherNode::addParentNode);
+        }
       }
-      this.endNodes = other.endNodes;
     }
+    //Each node which is a forkable node is added to list of end nodes of the 
concatenated dag
+    other.endNodes.addAll(forkNodes);
+    this.endNodes = other.endNodes;
+
     //Append all the entries from the other dag's parentChildMap to this dag's 
parentChildMap
-    for (Map.Entry<DagNode, List<DagNode<T>>> entry: 
other.parentChildMap.entrySet()) {
-      this.parentChildMap.put(entry.getKey(), entry.getValue());
-    }
+    this.parentChildMap.putAll(other.parentChildMap);
+
+    //If there exists a node in the other dag with no parent nodes, add it to 
the list of start nodes of the
+    // concatenated dag.
+    other.startNodes.stream().filter(node -> other.getParents(node).isEmpty())
+        .forEach(node -> this.startNodes.add(node));
+
     this.nodes.addAll(other.nodes);
     return this;
   }
@@ -177,10 +212,7 @@ public class Dag<T> {
         return false;
       }
       DagNode that = (DagNode) o;
-      if (!this.getValue().equals(that.getValue())) {
-        return false;
-      }
-      return true;
+      return this.getValue().equals(that.getValue());
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/incubator-gobblin/blob/95eff37c/gobblin-service/src/test/java/org/apache/gobblin/service/modules/flow/FlowGraphPathTest.java
----------------------------------------------------------------------
diff --git a/gobblin-service/src/test/java/org/apache/gobblin/service/modules/flow/FlowGraphPathTest.java b/gobblin-service/src/test/java/org/apache/gobblin/service/modules/flow/FlowGraphPathTest.java
new file mode 100644
index 0000000..01074f8
--- /dev/null
+++ b/gobblin-service/src/test/java/org/apache/gobblin/service/modules/flow/FlowGraphPathTest.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gobblin.service.modules.flow;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import com.typesafe.config.Config;
+import com.typesafe.config.ConfigValueFactory;
+
+import org.apache.gobblin.config.ConfigBuilder;
+import org.apache.gobblin.configuration.ConfigurationKeys;
+import org.apache.gobblin.runtime.api.JobSpec;
+import org.apache.gobblin.runtime.api.SpecExecutor;
+import org.apache.gobblin.runtime.spec_executorInstance.InMemorySpecExecutor;
+import org.apache.gobblin.service.modules.flowgraph.Dag;
+import org.apache.gobblin.service.modules.spec.JobExecutionPlan;
+import org.apache.gobblin.service.modules.spec.JobExecutionPlanDagFactory;
+
+public class FlowGraphPathTest {
+
+  /**
+   * A method to create a {@link Dag <JobExecutionPlan>}.
+   * @return a Dag.
+   */
+  public Dag<JobExecutionPlan> buildDag(int numNodes, int startNodeId, boolean 
isForkable) throws URISyntaxException {
+    List<JobExecutionPlan> jobExecutionPlans = new ArrayList<>();
+    Config baseConfig = ConfigBuilder.create().
+        addPrimitive(ConfigurationKeys.FLOW_GROUP_KEY, "group0").
+        addPrimitive(ConfigurationKeys.FLOW_NAME_KEY, "flow0").
+        addPrimitive(ConfigurationKeys.FLOW_EXECUTION_ID_KEY, 
System.currentTimeMillis()).
+        addPrimitive(ConfigurationKeys.JOB_GROUP_KEY, "group0").build();
+    for (int i = startNodeId; i < startNodeId + numNodes; i++) {
+      String suffix = Integer.toString(i);
+      Config jobConfig = baseConfig.withValue(ConfigurationKeys.JOB_NAME_KEY, 
ConfigValueFactory.fromAnyRef("job" + suffix));
+      if (isForkable && (i == startNodeId + numNodes - 1)) {
+        jobConfig = jobConfig.withValue(ConfigurationKeys.JOB_FORK_ON_CONCAT, 
ConfigValueFactory.fromAnyRef(true));
+      }
+      if (i > startNodeId) {
+        jobConfig = jobConfig.withValue(ConfigurationKeys.JOB_DEPENDENCIES, 
ConfigValueFactory.fromAnyRef("job" + (i  - 1)));
+      }
+      JobSpec js = JobSpec.builder("test_job" + 
suffix).withVersion(suffix).withConfig(jobConfig).
+          withTemplate(new URI("job" + suffix)).build();
+      SpecExecutor specExecutor = 
InMemorySpecExecutor.createDummySpecExecutor(new URI("job" + i));
+      JobExecutionPlan jobExecutionPlan = new JobExecutionPlan(js, 
specExecutor);
+      jobExecutionPlans.add(jobExecutionPlan);
+    }
+    return new JobExecutionPlanDagFactory().createDag(jobExecutionPlans);
+  }
+
+  @Test
+  public void testConcatenate() throws URISyntaxException {
+    //Dag1: "job0->job1", Dag2: "job2->job3"
+    Dag<JobExecutionPlan> dag1 = buildDag(2, 0, false);
+    Dag<JobExecutionPlan> dag2 = buildDag(2, 2, false);
+    Dag<JobExecutionPlan> dagNew = FlowGraphPath.concatenate(dag1, dag2);
+
+    //Expected result: "job0"->"job1"->"job2"->"job3"
+    Assert.assertEquals(dagNew.getStartNodes().size(), 1);
+    Assert.assertEquals(dagNew.getEndNodes().size(), 1);
+    Assert.assertEquals(dagNew.getNodes().size(), 4);
+    
Assert.assertEquals(dagNew.getStartNodes().get(0).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job0");
+    
Assert.assertEquals(dagNew.getEndNodes().get(0).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job3");
+    
Assert.assertEquals(dagNew.getEndNodes().get(0).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_DEPENDENCIES),
 "job2");
+
+    //Dag1: "job0", Dag2: "job1->job2", "job0" forkable
+    dag1 = buildDag(1, 0, true);
+    dag2 = buildDag(2, 1, false);
+    dagNew = FlowGraphPath.concatenate(dag1, dag2);
+
+    //Expected result: "job0", "job1" -> "job2"
+    Assert.assertEquals(dagNew.getStartNodes().size(), 2);
+    Assert.assertEquals(dagNew.getEndNodes().size(), 2);
+    Assert.assertEquals(dagNew.getNodes().size(), 3);
+    
Assert.assertEquals(dagNew.getStartNodes().get(0).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job0");
+    
Assert.assertEquals(dagNew.getStartNodes().get(1).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job1");
+    
Assert.assertEquals(dagNew.getEndNodes().get(0).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job2");
+    
Assert.assertEquals(dagNew.getEndNodes().get(0).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_DEPENDENCIES),
 "job1");
+    
Assert.assertEquals(dagNew.getEndNodes().get(1).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job0");
+    
Assert.assertFalse(dagNew.getEndNodes().get(1).getValue().getJobSpec().getConfig().hasPath(ConfigurationKeys.JOB_DEPENDENCIES));
+
+    //Dag1: "job0->job1", Dag2: "job2->job3", "job1" forkable
+    dag1 = buildDag(2, 0, true);
+    dag2 = buildDag(2, 2, false);
+    dagNew = FlowGraphPath.concatenate(dag1, dag2);
+
+    //Expected result: "job0" -> "job1"
+    //                        \-> "job2" -> "job3"
+    Assert.assertEquals(dagNew.getStartNodes().size(), 1);
+    Assert.assertEquals(dagNew.getEndNodes().size(), 2);
+    Assert.assertEquals(dagNew.getNodes().size(), 4);
+    
Assert.assertEquals(dagNew.getStartNodes().get(0).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job0");
+    
Assert.assertEquals(dagNew.getEndNodes().get(0).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job3");
+    
Assert.assertEquals(dagNew.getEndNodes().get(0).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_DEPENDENCIES),
 "job2");
+    
Assert.assertEquals(dagNew.getEndNodes().get(1).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job1");
+    
Assert.assertEquals(dagNew.getEndNodes().get(1).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_DEPENDENCIES),
 "job0");
+
+    //Dag1: "job0", Dag2: "job1"
+    dag1 = buildDag(1, 0, true);
+    dag2 = buildDag(1, 1, false);
+    dagNew = FlowGraphPath.concatenate(dag1, dag2);
+
+    //Expected result: "job0","job1"
+    Assert.assertEquals(dagNew.getStartNodes().size(), 2);
+    Assert.assertEquals(dagNew.getEndNodes().size(), 2);
+    Assert.assertEquals(dagNew.getNodes().size(), 2);
+    
Assert.assertEquals(dagNew.getStartNodes().get(0).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job0");
+    
Assert.assertEquals(dagNew.getStartNodes().get(1).getValue().getJobSpec().getConfig().getString(ConfigurationKeys.JOB_NAME_KEY),
 "job1");
+    
Assert.assertFalse(dagNew.getStartNodes().get(1).getValue().getJobSpec().getConfig().hasPath(ConfigurationKeys.JOB_DEPENDENCIES));
+    
Assert.assertFalse(dagNew.getStartNodes().get(1).getValue().getJobSpec().getConfig().hasPath(ConfigurationKeys.JOB_DEPENDENCIES));
+
+
+  }
+
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-gobblin/blob/95eff37c/gobblin-service/src/test/java/org/apache/gobblin/service/modules/flowgraph/DagTest.java
----------------------------------------------------------------------
diff --git a/gobblin-service/src/test/java/org/apache/gobblin/service/modules/flowgraph/DagTest.java b/gobblin-service/src/test/java/org/apache/gobblin/service/modules/flowgraph/DagTest.java
index 48922d0..3a7ab10 100644
--- a/gobblin-service/src/test/java/org/apache/gobblin/service/modules/flowgraph/DagTest.java
+++ b/gobblin-service/src/test/java/org/apache/gobblin/service/modules/flowgraph/DagTest.java
@@ -23,21 +23,24 @@ import java.util.Set;
 
 import org.testng.Assert;
 import org.testng.annotations.Test;
+import org.testng.collections.Sets;
 
 import com.google.common.collect.Lists;
 
 import lombok.extern.slf4j.Slf4j;
 
+import org.apache.gobblin.service.modules.flowgraph.Dag.DagNode;
+
 
 @Slf4j
 public class DagTest {
   @Test
   public void testInitialize() {
-    Dag.DagNode<String> dagNode1 = new Dag.DagNode<>("val1");
-    Dag.DagNode<String> dagNode2 = new Dag.DagNode<>("val2");
-    Dag.DagNode<String> dagNode3 = new Dag.DagNode<>("val3");
-    Dag.DagNode<String> dagNode4 = new Dag.DagNode<>("val4");
-    Dag.DagNode<String> dagNode5 = new Dag.DagNode<>("val5");
+    DagNode<String> dagNode1 = new DagNode<>("val1");
+    DagNode<String> dagNode2 = new DagNode<>("val2");
+    DagNode<String> dagNode3 = new DagNode<>("val3");
+    DagNode<String> dagNode4 = new DagNode<>("val4");
+    DagNode<String> dagNode5 = new DagNode<>("val5");
 
     dagNode2.addParentNode(dagNode1);
     dagNode3.addParentNode(dagNode1);
@@ -45,7 +48,7 @@ public class DagTest {
     dagNode4.addParentNode(dagNode3);
     dagNode5.addParentNode(dagNode3);
 
-    List<Dag.DagNode<String>> dagNodeList = Lists.newArrayList(dagNode1, 
dagNode2, dagNode3, dagNode4, dagNode5);
+    List<DagNode<String>> dagNodeList = Lists.newArrayList(dagNode1, dagNode2, 
dagNode3, dagNode4, dagNode5);
     Dag<String> dag = new Dag<>(dagNodeList);
     //Test startNodes and endNodes
     Assert.assertEquals(dag.getStartNodes().size(), 1);
@@ -55,10 +58,10 @@ public class DagTest {
     Assert.assertEquals(dag.getEndNodes().get(1).getValue(), "val5");
 
 
-    Dag.DagNode startNode = dag.getStartNodes().get(0);
+    DagNode startNode = dag.getStartNodes().get(0);
     Assert.assertEquals(dag.getChildren(startNode).size(), 2);
     Set<String> childSet = new HashSet<>();
-    for (Dag.DagNode<String> node: dag.getChildren(startNode)) {
+    for (DagNode<String> node: dag.getChildren(startNode)) {
       childSet.add(node.getValue());
     }
     Assert.assertTrue(childSet.contains("val2"));
@@ -70,7 +73,7 @@ public class DagTest {
     Assert.assertEquals(dag.getChildren(dagNode2).size(), 1);
     Assert.assertEquals(dag.getChildren(dagNode2).get(0).getValue(), "val4");
 
-    for (Dag.DagNode<String> node: dag.getChildren(dagNode3)) {
+    for (DagNode<String> node: dag.getChildren(dagNode3)) {
       childSet.add(node.getValue());
     }
     Assert.assertTrue(childSet.contains("val4"));
@@ -83,11 +86,11 @@ public class DagTest {
 
   @Test
   public void testConcatenate() {
-    Dag.DagNode<String> dagNode1 = new Dag.DagNode<>("val1");
-    Dag.DagNode<String> dagNode2 = new Dag.DagNode<>("val2");
-    Dag.DagNode<String> dagNode3 = new Dag.DagNode<>("val3");
-    Dag.DagNode<String> dagNode4 = new Dag.DagNode<>("val4");
-    Dag.DagNode<String> dagNode5 = new Dag.DagNode<>("val5");
+    DagNode<String> dagNode1 = new DagNode<>("val1");
+    DagNode<String> dagNode2 = new DagNode<>("val2");
+    DagNode<String> dagNode3 = new DagNode<>("val3");
+    DagNode<String> dagNode4 = new DagNode<>("val4");
+    DagNode<String> dagNode5 = new DagNode<>("val5");
 
     dagNode2.addParentNode(dagNode1);
     dagNode3.addParentNode(dagNode1);
@@ -95,21 +98,20 @@ public class DagTest {
     dagNode4.addParentNode(dagNode3);
     dagNode5.addParentNode(dagNode3);
 
-    List<Dag.DagNode<String>> dagNodeList = Lists.newArrayList(dagNode1, 
dagNode2, dagNode3, dagNode4, dagNode5);
+    List<DagNode<String>> dagNodeList = Lists.newArrayList(dagNode1, dagNode2, 
dagNode3, dagNode4, dagNode5);
     Dag<String> dag1 = new Dag<>(dagNodeList);
 
-    Dag.DagNode<String> dagNode6 = new Dag.DagNode<>("val6");
-    Dag.DagNode<String> dagNode7 = new Dag.DagNode<>("val7");
-    Dag.DagNode<String> dagNode8 = new Dag.DagNode<>("val8");
+    DagNode<String> dagNode6 = new DagNode<>("val6");
+    DagNode<String> dagNode7 = new DagNode<>("val7");
+    DagNode<String> dagNode8 = new DagNode<>("val8");
     dagNode8.addParentNode(dagNode6);
     dagNode8.addParentNode(dagNode7);
     Dag<String> dag2 = new Dag<>(Lists.newArrayList(dagNode6, dagNode7, 
dagNode8));
 
-    //Concatenate the two dags
     Dag<String> dagNew = dag1.concatenate(dag2);
 
     //Ensure end nodes of first dag are no longer end nodes
-    for (Dag.DagNode<String> dagNode: Lists.newArrayList(dagNode6, dagNode7)) {
+    for (DagNode<String> dagNode : Lists.newArrayList(dagNode6, dagNode7)) {
       Assert.assertEquals(dagNew.getParents(dagNode).size(), 2);
       Set<String> set = new HashSet<>();
       set.add(dagNew.getParents(dagNode).get(0).getValue());
@@ -118,7 +120,7 @@ public class DagTest {
       Assert.assertTrue(set.contains("val5"));
     }
 
-    for (Dag.DagNode<String> dagNode: Lists.newArrayList(dagNode4, dagNode5)) {
+    for (DagNode<String> dagNode : Lists.newArrayList(dagNode4, dagNode5)) {
       Assert.assertEquals(dagNew.getChildren(dagNode).size(), 2);
       Set<String> set = new HashSet<>();
       set.add(dagNew.getChildren(dagNode).get(0).getValue());
@@ -127,8 +129,8 @@ public class DagTest {
       Assert.assertTrue(set.contains("val7"));
     }
 
-    for (Dag.DagNode<String> dagNode: Lists.newArrayList(dagNode6, dagNode7)) {
-      List<Dag.DagNode<String>> nextNodes = dagNew.getChildren(dagNode);
+    for (DagNode<String> dagNode : Lists.newArrayList(dagNode6, dagNode7)) {
+      List<DagNode<String>> nextNodes = dagNew.getChildren(dagNode);
       Assert.assertEquals(nextNodes.size(), 1);
       Assert.assertEquals(nextNodes.get(0).getValue(), "val8");
     }
@@ -142,12 +144,35 @@ public class DagTest {
   }
 
   @Test
+  public void testConcatenateForkNodes() {
+    DagNode<String> dagNode1 = new DagNode<>("val1");
+    DagNode<String> dagNode2 = new DagNode<>("val2");
+    DagNode<String> dagNode3 = new DagNode<>("val3");
+
+    dagNode2.addParentNode(dagNode1);
+    dagNode3.addParentNode(dagNode1);
+
+    Dag<String> dag1 = new Dag<>(Lists.newArrayList(dagNode1, dagNode2, 
dagNode3));
+    DagNode<String> dagNode4 = new DagNode<>("val4");
+    Dag<String> dag2 = new Dag<>(Lists.newArrayList(dagNode4));
+
+    Set<DagNode<String>> forkNodes = Sets.newHashSet();
+    forkNodes.add(dagNode3);
+    Dag<String> dagNew = dag1.concatenate(dag2, forkNodes);
+
+    Assert.assertEquals(dagNew.getEndNodes().size(), 2);
+    Assert.assertEquals(dagNew.getEndNodes().get(0).getValue(), "val4");
+    Assert.assertEquals(dagNew.getEndNodes().get(1).getValue(), "val3");
+    Assert.assertEquals(dagNew.getChildren(dagNode3).size(), 0);
+  }
+
+  @Test
   public void testMerge() {
-    Dag.DagNode<String> dagNode1 = new Dag.DagNode<>("val1");
-    Dag.DagNode<String> dagNode2 = new Dag.DagNode<>("val2");
-    Dag.DagNode<String> dagNode3 = new Dag.DagNode<>("val3");
-    Dag.DagNode<String> dagNode4 = new Dag.DagNode<>("val4");
-    Dag.DagNode<String> dagNode5 = new Dag.DagNode<>("val5");
+    DagNode<String> dagNode1 = new DagNode<>("val1");
+    DagNode<String> dagNode2 = new DagNode<>("val2");
+    DagNode<String> dagNode3 = new DagNode<>("val3");
+    DagNode<String> dagNode4 = new DagNode<>("val4");
+    DagNode<String> dagNode5 = new DagNode<>("val5");
 
     dagNode2.addParentNode(dagNode1);
     dagNode3.addParentNode(dagNode1);
@@ -155,12 +180,12 @@ public class DagTest {
     dagNode4.addParentNode(dagNode3);
     dagNode5.addParentNode(dagNode3);
 
-    List<Dag.DagNode<String>> dagNodeList = Lists.newArrayList(dagNode1, 
dagNode2, dagNode3, dagNode4, dagNode5);
+    List<DagNode<String>> dagNodeList = Lists.newArrayList(dagNode1, dagNode2, 
dagNode3, dagNode4, dagNode5);
     Dag<String> dag1 = new Dag<>(dagNodeList);
 
-    Dag.DagNode<String> dagNode6 = new Dag.DagNode<>("val6");
-    Dag.DagNode<String> dagNode7 = new Dag.DagNode<>("val7");
-    Dag.DagNode<String> dagNode8 = new Dag.DagNode<>("val8");
+    DagNode<String> dagNode6 = new DagNode<>("val6");
+    DagNode<String> dagNode7 = new DagNode<>("val7");
+    DagNode<String> dagNode8 = new DagNode<>("val8");
     dagNode8.addParentNode(dagNode6);
     dagNode8.addParentNode(dagNode7);
     Dag<String> dag2 = new Dag<>(Lists.newArrayList(dagNode6, dagNode7, 
dagNode8));
@@ -170,11 +195,11 @@ public class DagTest {
 
     //Test the startNodes
     Assert.assertEquals(dagNew.getStartNodes().size(), 3);
-    for (Dag.DagNode<String> dagNode: Lists.newArrayList(dagNode1, dagNode6, 
dagNode7)) {
+    for (DagNode<String> dagNode: Lists.newArrayList(dagNode1, dagNode6, 
dagNode7)) {
       Assert.assertTrue(dagNew.getStartNodes().contains(dagNode));
       Assert.assertEquals(dagNew.getParents(dagNode).size(), 0);
       if (dagNode == dagNode1) {
-        List<Dag.DagNode<String>> nextNodes = dagNew.getChildren(dagNode);
+        List<DagNode<String>> nextNodes = dagNew.getChildren(dagNode);
         Assert.assertEquals(nextNodes.size(), 2);
         Assert.assertTrue(nextNodes.contains(dagNode2));
         Assert.assertTrue(nextNodes.contains(dagNode3));
@@ -186,7 +211,7 @@ public class DagTest {
 
     //Test the endNodes
     Assert.assertEquals(dagNew.getEndNodes().size(), 3);
-    for (Dag.DagNode<String> dagNode: Lists.newArrayList(dagNode4, dagNode5, 
dagNode8)) {
+    for (DagNode<String> dagNode: Lists.newArrayList(dagNode4, dagNode5, 
dagNode8)) {
       Assert.assertTrue(dagNew.getEndNodes().contains(dagNode));
       Assert.assertEquals(dagNew.getChildren(dagNode).size(), 0);
       if (dagNode == dagNode8) {

Reply via email to