This is an automated email from the ASF dual-hosted git repository.

arjun4084346 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/gobblin.git


The following commit(s) were added to refs/heads/master by this push:
     new 125d787b3 [GOBBLIN-2023] add kill dag proc (#3901)
125d787b3 is described below

commit 125d787b32ce75e61f25d525499107f758b9ff42
Author: Arjun Singh Bora <[email protected]>
AuthorDate: Fri Apr 5 10:30:07 2024 -0700

    [GOBBLIN-2023] add kill dag proc (#3901)
    
    * add kill dag proc
    * address review comment
---
 .../spec_executorInstance/MockedSpecExecutor.java  |   3 +-
 .../orchestration/DagManagementTaskStreamImpl.java |   3 +
 .../modules/orchestration/DagProcFactory.java      |   6 +
 .../modules/orchestration/DagTaskVisitor.java      |   2 +
 .../MostlyMySqlDagManagementStateStore.java        |  11 +-
 .../modules/orchestration/proc/KillDagProc.java    | 127 +++++++++++++++++
 .../{DagTaskVisitor.java => task/KillDagTask.java} |  23 ++-
 .../DagManagementDagActionStoreChangeMonitor.java  |   3 +-
 .../orchestration/proc/KillDagProcTest.java        | 154 +++++++++++++++++++++
 9 files changed, 320 insertions(+), 12 deletions(-)

diff --git a/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/spec_executorInstance/MockedSpecExecutor.java b/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/spec_executorInstance/MockedSpecExecutor.java
index 378d04717..77d21c14d 100644
--- a/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/spec_executorInstance/MockedSpecExecutor.java
+++ b/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/spec_executorInstance/MockedSpecExecutor.java
@@ -45,7 +45,8 @@ public class MockedSpecExecutor extends InMemorySpecExecutor {
     when(mockedSpecProducer.addSpec(any())).thenReturn(new 
CompletedFuture(Boolean.TRUE, null));
     when(mockedSpecProducer.serializeAddSpecResponse(any())).thenReturn("");
     when(mockedSpecProducer.deserializeAddSpecResponse(any())).thenReturn(new 
CompletedFuture(Boolean.TRUE, null));
-  }
+    when(mockedSpecProducer.cancelJob(any(), any())).thenReturn(new 
CompletedFuture(Boolean.TRUE, null));
+    }
 
   public static SpecExecutor createDummySpecExecutor(URI uri) {
     Properties properties = new Properties();
diff --git a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagManagementTaskStreamImpl.java b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagManagementTaskStreamImpl.java
index 8ad246aa2..63aca3e8d 100644
--- a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagManagementTaskStreamImpl.java
+++ b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagManagementTaskStreamImpl.java
@@ -40,6 +40,7 @@ import org.apache.gobblin.metrics.MetricContext;
 import org.apache.gobblin.metrics.event.EventSubmitter;
 import org.apache.gobblin.runtime.util.InjectionNames;
 import org.apache.gobblin.service.modules.orchestration.task.DagTask;
+import org.apache.gobblin.service.modules.orchestration.task.KillDagTask;
 import org.apache.gobblin.service.modules.orchestration.task.LaunchDagTask;
 import org.apache.gobblin.service.modules.orchestration.task.ReevaluateDagTask;
 import org.apache.gobblin.util.ConfigUtils;
@@ -164,6 +165,8 @@ public class DagManagementTaskStreamImpl implements DagManagement, DagTaskStream
         return new LaunchDagTask(dagAction, leaseObtainedStatus, 
dagActionStore.get());
       case REEVALUATE:
         return new ReevaluateDagTask(dagAction, leaseObtainedStatus, 
dagActionStore.get());
+      case KILL:
+        return new KillDagTask(dagAction, leaseObtainedStatus, 
dagActionStore.get());
       default:
         throw new UnsupportedOperationException(dagActionType + " not yet 
implemented");
     }
diff --git a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagProcFactory.java b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagProcFactory.java
index ef7aea3d2..cf48b3d1c 100644
--- a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagProcFactory.java
+++ b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagProcFactory.java
@@ -22,9 +22,11 @@ import com.google.inject.Singleton;
 
 import org.apache.gobblin.annotation.Alpha;
 import org.apache.gobblin.service.modules.orchestration.proc.DagProc;
+import org.apache.gobblin.service.modules.orchestration.proc.KillDagProc;
 import org.apache.gobblin.service.modules.orchestration.proc.LaunchDagProc;
 import org.apache.gobblin.service.modules.orchestration.proc.ReevaluateDagProc;
 import org.apache.gobblin.service.modules.orchestration.task.DagTask;
+import org.apache.gobblin.service.modules.orchestration.task.KillDagTask;
 import org.apache.gobblin.service.modules.orchestration.task.LaunchDagTask;
 import org.apache.gobblin.service.modules.orchestration.task.ReevaluateDagTask;
 import 
org.apache.gobblin.service.modules.utils.FlowCompilationValidationHelper;
@@ -56,6 +58,10 @@ public class DagProcFactory implements DagTaskVisitor<DagProc> {
   public ReevaluateDagProc meet(ReevaluateDagTask reEvaluateDagTask) {
     return new ReevaluateDagProc(reEvaluateDagTask);
   }
+
+  public KillDagProc meet(KillDagTask killDagTask) { // builds the DagProc for a KILL dag task; NOTE(review): consider @Override, DagTaskVisitor declares meet(KillDagTask)
+    return new KillDagProc(killDagTask);
+  }
   //todo - overload meet method for other dag tasks
 }
 
diff --git a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagTaskVisitor.java b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagTaskVisitor.java
index e946f9835..bbba21037 100644
--- a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagTaskVisitor.java
+++ b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagTaskVisitor.java
@@ -17,6 +17,7 @@
 
 package org.apache.gobblin.service.modules.orchestration;
 
+import org.apache.gobblin.service.modules.orchestration.task.KillDagTask;
 import org.apache.gobblin.service.modules.orchestration.task.LaunchDagTask;
 import org.apache.gobblin.service.modules.orchestration.task.ReevaluateDagTask;
 
@@ -24,4 +25,5 @@ import org.apache.gobblin.service.modules.orchestration.task.ReevaluateDagTask;
 public interface DagTaskVisitor<T> {
   T meet(LaunchDagTask launchDagTask);
   T meet(ReevaluateDagTask reevaluateDagTask);
+  T meet(KillDagTask killDagTask); // invoked from KillDagTask#host; DagProcFactory implements it by returning a KillDagProc
 }
diff --git a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/MostlyMySqlDagManagementStateStore.java b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/MostlyMySqlDagManagementStateStore.java
index b010a0cb5..ae5b5d79f 100644
--- a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/MostlyMySqlDagManagementStateStore.java
+++ b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/MostlyMySqlDagManagementStateStore.java
@@ -174,14 +174,16 @@ public class MostlyMySqlDagManagementStateStore implements DagManagementStateSto
   }
 
   @Override
-  // todo - updating different mapps here and in addDagNodeState can result in 
inconsistency between the maps
+  // todo - updating different maps here and in addDagNodeState can result in 
inconsistency between the maps
   public synchronized void deleteDagNodeState(DagManager.DagId dagId, 
Dag.DagNode<JobExecutionPlan> dagNode) {
     this.jobToDag.remove(dagNode);
     this.dagNodes.remove(dagNode.getValue().getId());
     this.dagToDeadline.remove(dagId);
-    this.dagToJobs.get(dagId).remove(dagNode);
-    if (this.dagToJobs.get(dagId).isEmpty()) {
-      this.dagToJobs.remove(dagId);
+    if (this.dagToJobs.containsKey(dagId)) { // guard: the dag's entry may already be gone (it is removed once it becomes empty), so get(dagId) could return null
+      this.dagToJobs.get(dagId).remove(dagNode);
+      if (this.dagToJobs.get(dagId).isEmpty()) {
+        this.dagToJobs.remove(dagId);
+      }
     }
   }
 
@@ -211,6 +213,7 @@ public class MostlyMySqlDagManagementStateStore implements DagManagementStateSto
     return this.dagStateStore.existsDag(dagId);
   }
 
+  @Override
   public Pair<Optional<Dag.DagNode<JobExecutionPlan>>, Optional<JobStatus>> 
getDagNodeWithJobStatus(DagNodeId dagNodeId) {
     if (this.dagNodes.containsKey(dagNodeId)) {
       return ImmutablePair.of(Optional.of(this.dagNodes.get(dagNodeId)), 
getJobStatus(dagNodeId));
diff --git a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/proc/KillDagProc.java b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/proc/KillDagProc.java
new file mode 100644
index 000000000..5f1c7bb97
--- /dev/null
+++ b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/proc/KillDagProc.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.gobblin.service.modules.orchestration.proc;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+import java.util.concurrent.Future;
+
+import com.google.common.collect.Maps;
+
+import lombok.extern.slf4j.Slf4j;
+
+import org.apache.gobblin.configuration.ConfigurationKeys;
+import org.apache.gobblin.metrics.event.TimingEvent;
+import org.apache.gobblin.service.modules.flowgraph.Dag;
+import org.apache.gobblin.service.modules.orchestration.DagActionStore;
+import 
org.apache.gobblin.service.modules.orchestration.DagManagementStateStore;
+import org.apache.gobblin.service.modules.orchestration.DagManagerUtils;
+import org.apache.gobblin.service.modules.orchestration.TimingEventUtils;
+import org.apache.gobblin.service.modules.orchestration.task.KillDagTask;
+import org.apache.gobblin.service.modules.spec.JobExecutionPlan;
+
+import static org.apache.gobblin.service.ExecutionStatus.CANCELLED;
+
+
+/**
+ * An implementation for {@link DagProc} that kills all the nodes of a dag.
+ * If the dag action has a job name set, then it kills only that particular job/dagNode.
+ *
+ * <p>{@code initialize} loads the dag from the {@link DagManagementStateStore}. {@code act} then marks the dag
+ * with the {@code FLOW_CANCELLED} flow event, checkpoints it, and asks the spec producer of each targeted node to
+ * cancel its job. When the whole dag is killed, each node's state is also deleted from the store.
+ * A missing dag or dag node is logged as an error and otherwise ignored.
+ */
+@Slf4j
+public class KillDagProc extends DagProc<Optional<Dag<JobExecutionPlan>>> {
+  private final boolean shouldKillSpecificJob; // true when the dag action names a job other than NO_JOB_NAME_DEFAULT
+
+  public KillDagProc(KillDagTask killDagTask) {
+    super(killDagTask);
+    this.shouldKillSpecificJob = 
!getDagNodeId().getJobName().equals(DagActionStore.NO_JOB_NAME_DEFAULT);
+  }
+
+  @Override
+  protected Optional<Dag<JobExecutionPlan>> initialize(DagManagementStateStore 
dagManagementStateStore)
+      throws IOException {
+   return dagManagementStateStore.getDag(getDagId()); // an empty Optional is handled (logged and skipped) in act()
+  }
+
+  @Override
+  protected void act(DagManagementStateStore dagManagementStateStore, 
Optional<Dag<JobExecutionPlan>> dag)
+      throws IOException {
+    log.info("Request to kill dag {} (node: {})", getDagId(), 
shouldKillSpecificJob ? getDagNodeId() : "<<all>>");
+
+    if (!dag.isPresent()) {
+      // todo - add a metric here
+      log.error("Did not find Dag with id {}, it might be already 
cancelled/finished and thus cleaned up from the store.", getDagId());
+      return;
+    }
+
+    dag.get().setFlowEvent(TimingEvent.FlowTimings.FLOW_CANCELLED); // NOTE(review): set on the whole dag even when only one job is being killed - confirm intended
+    dag.get().setMessage("Flow killed by request");
+
+    dagManagementStateStore.checkpointDag(dag.get()); // persist the flow event and message before any job is cancelled
+
+    if (this.shouldKillSpecificJob) {
+      Optional<Dag.DagNode<JobExecutionPlan>> dagNodeToCancel = 
dagManagementStateStore.getDagNodeWithJobStatus(this.dagNodeId).getLeft(); // only the dag node (left) is used; the job status is ignored
+      if (dagNodeToCancel.isPresent()) {
+        cancelDagNode(dagNodeToCancel.get());
+      } else {
+        // todo - add a metric here
+        log.error("Did not find Dag node with id {}, it might be already 
cancelled/finished and thus cleaned up from the store.", getDagNodeId());
+      }
+    } else {
+      List<Dag.DagNode<JobExecutionPlan>> dagNodesToCancel = 
dag.get().getNodes();
+      log.info("Found {} DagNodes to cancel (DagId {}).", 
dagNodesToCancel.size(), getDagId());
+
+      for (Dag.DagNode<JobExecutionPlan> dagNodeToCancel : dagNodesToCancel) {
+        cancelDagNode(dagNodeToCancel);
+        // todo - why was it not being cleaned up in DagManager?
+        dagManagementStateStore.deleteDagNodeState(getDagId(), 
dagNodeToCancel); // node state is removed only in this kill-all branch, not for a single-job kill
+      }
+    }
+  }
+
+  private void cancelDagNode(Dag.DagNode<JobExecutionPlan> dagNodeToCancel) 
throws IOException {
+    Properties props = new Properties(); // extra args for SpecProducer.cancelJob: the flow execution id and, when a job future exists, its serialized form
+    if 
(dagNodeToCancel.getValue().getJobSpec().getConfig().hasPath(ConfigurationKeys.FLOW_EXECUTION_ID_KEY))
 {
+      props.setProperty(ConfigurationKeys.FLOW_EXECUTION_ID_KEY,
+          
dagNodeToCancel.getValue().getJobSpec().getConfig().getString(ConfigurationKeys.FLOW_EXECUTION_ID_KEY));
+    }
+
+    try {
+      if (dagNodeToCancel.getValue().getJobFuture().isPresent()) {
+        Future future = dagNodeToCancel.getValue().getJobFuture().get(); // NOTE(review): raw Future type
+        String serializedFuture = 
DagManagerUtils.getSpecProducer(dagNodeToCancel).serializeAddSpecResponse(future);
+        props.put(ConfigurationKeys.SPEC_PRODUCER_SERIALIZED_FUTURE, 
serializedFuture);
+        sendCancellationEvent(dagNodeToCancel.getValue()); // NOTE(review): event and CANCELLED status are recorded before cancelJob below is invoked
+      } else {
+        log.warn("No Job future when canceling DAG node (hence, not sending 
cancellation event) - {}",
+            dagNodeToCancel.getValue().getJobSpec().getUri());
+      }
+      
DagManagerUtils.getSpecProducer(dagNodeToCancel).cancelJob(dagNodeToCancel.getValue().getJobSpec().getUri(),
 props).get(); // blocks until the spec producer's cancel future completes
+    } catch (Exception e) {
+      throw new IOException(e); // NOTE(review): also wraps InterruptedException from get() without restoring the thread's interrupt flag
+    }
+  }
+
+  private void sendCancellationEvent(JobExecutionPlan jobExecutionPlan) { // emits a JOB_CANCEL timing event for the job, then marks its execution status CANCELLED
+    Map<String, String> jobMetadata = 
TimingEventUtils.getJobMetadata(Maps.newHashMap(), jobExecutionPlan);
+    
eventSubmitter.getTimingEvent(TimingEvent.LauncherTimings.JOB_CANCEL).stop(jobMetadata);
+    jobExecutionPlan.setExecutionStatus(CANCELLED);
+  }
+}
diff --git a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagTaskVisitor.java b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/task/KillDagTask.java
similarity index 53%
copy from gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagTaskVisitor.java
copy to gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/task/KillDagTask.java
index e946f9835..9533a0e69 100644
--- a/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/DagTaskVisitor.java
+++ b/gobblin-service/src/main/java/org/apache/gobblin/service/modules/orchestration/task/KillDagTask.java
@@ -15,13 +15,24 @@
  * limitations under the License.
  */
 
-package org.apache.gobblin.service.modules.orchestration;
+package org.apache.gobblin.service.modules.orchestration.task;
 
-import org.apache.gobblin.service.modules.orchestration.task.LaunchDagTask;
-import org.apache.gobblin.service.modules.orchestration.task.ReevaluateDagTask;
+import org.apache.gobblin.service.modules.orchestration.DagActionStore;
+import org.apache.gobblin.service.modules.orchestration.DagTaskVisitor;
+import org.apache.gobblin.service.modules.orchestration.LeaseAttemptStatus;
 
 
-public interface DagTaskVisitor<T> {
-  T meet(LaunchDagTask launchDagTask);
-  T meet(ReevaluateDagTask reevaluateDagTask);
+/**
+ * A {@link DagTask} responsible for killing running jobs.
+ *
+ * <p>Created for {@code KILL} dag actions. {@link #host} hands this task to a {@link DagTaskVisitor}, whose
+ * {@code meet(KillDagTask)} overload decides how to handle it (e.g. by producing a kill {@code DagProc}).
+ */
+
+public class KillDagTask extends DagTask {
+  public KillDagTask(DagActionStore.DagAction dagAction, 
LeaseAttemptStatus.LeaseObtainedStatus leaseObtainedStatus,
+      DagActionStore dagActionStore) {
+    super(dagAction, leaseObtainedStatus, dagActionStore);
+  }
+
+  public <T> T host(DagTaskVisitor<T> visitor) { // NOTE(review): presumably implements an abstract DagTask#host - consider @Override
+    return visitor.meet(this);
+  }
 }
diff --git a/gobblin-service/src/main/java/org/apache/gobblin/service/monitoring/DagManagementDagActionStoreChangeMonitor.java b/gobblin-service/src/main/java/org/apache/gobblin/service/monitoring/DagManagementDagActionStoreChangeMonitor.java
index 413fbcdc5..6582f03f1 100644
--- a/gobblin-service/src/main/java/org/apache/gobblin/service/monitoring/DagManagementDagActionStoreChangeMonitor.java
+++ b/gobblin-service/src/main/java/org/apache/gobblin/service/monitoring/DagManagementDagActionStoreChangeMonitor.java
@@ -62,10 +62,11 @@ public class DagManagementDagActionStoreChangeMonitor extends DagActionStoreChan
       switch (dagAction.getDagActionType()) {
         case LAUNCH :
         case REEVALUATE :
+        case KILL :
           dagManagement.addDagAction(dagAction);
           break;
         default:
-          log.warn("Received unsupported dagAction {}. Expected to be a 
REEVALUATE or LAUNCH", dagAction.getDagActionType());
+          log.warn("Received unsupported dagAction {}. Expected to be a KILL, 
REEVALUATE or LAUNCH", dagAction.getDagActionType());
           this.unexpectedErrors.mark();
       }
     } catch (IOException e) {
diff --git a/gobblin-service/src/test/java/org/apache/gobblin/service/modules/orchestration/proc/KillDagProcTest.java b/gobblin-service/src/test/java/org/apache/gobblin/service/modules/orchestration/proc/KillDagProcTest.java
new file mode 100644
index 000000000..be7019ba3
--- /dev/null
+++ b/gobblin-service/src/test/java/org/apache/gobblin/service/modules/orchestration/proc/KillDagProcTest.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.gobblin.service.modules.orchestration.proc;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.mockito.Mockito;
+import org.testng.Assert;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import com.typesafe.config.ConfigFactory;
+import com.typesafe.config.ConfigValueFactory;
+
+import org.apache.gobblin.configuration.ConfigurationKeys;
+import org.apache.gobblin.metastore.testing.TestMetastoreDatabaseFactory;
+import org.apache.gobblin.runtime.api.FlowSpec;
+import org.apache.gobblin.runtime.api.Spec;
+import org.apache.gobblin.runtime.api.SpecProducer;
+import org.apache.gobblin.service.ExecutionStatus;
+import org.apache.gobblin.service.modules.flowgraph.Dag;
+import org.apache.gobblin.service.modules.orchestration.DagActionStore;
+import org.apache.gobblin.service.modules.orchestration.DagManager;
+import org.apache.gobblin.service.modules.orchestration.DagManagerTest;
+import org.apache.gobblin.service.modules.orchestration.DagManagerUtils;
+import 
org.apache.gobblin.service.modules.orchestration.MostlyMySqlDagManagementStateStore;
+import 
org.apache.gobblin.service.modules.orchestration.MostlyMySqlDagManagementStateStoreTest;
+import org.apache.gobblin.service.modules.orchestration.MysqlDagActionStore;
+import org.apache.gobblin.service.modules.orchestration.task.KillDagTask;
+import org.apache.gobblin.service.modules.orchestration.task.LaunchDagTask;
+import org.apache.gobblin.service.modules.spec.JobExecutionPlan;
+import 
org.apache.gobblin.service.modules.utils.FlowCompilationValidationHelper;
+import org.apache.gobblin.service.monitoring.JobStatus;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+
+
+public class KillDagProcTest { // launches a dag via LaunchDagProc, kills all of it (killDag) or one node (killDagNode) via KillDagProc, then counts cancelJob calls on the nodes' spec producers
+  private MostlyMySqlDagManagementStateStore dagManagementStateStore; // Mockito spy created once in @BeforeClass, so stubs added by one test remain in place for later tests
+
+  @BeforeClass
+  public void setUp() throws Exception {
+    this.dagManagementStateStore = 
spy(MostlyMySqlDagManagementStateStoreTest.getDummyDMSS(TestMetastoreDatabaseFactory.get()));
+    
doReturn(FlowSpec.builder().build()).when(this.dagManagementStateStore).getFlowSpec(any());
+    doNothing().when(this.dagManagementStateStore).tryAcquireQuota(any());
+    doNothing().when(this.dagManagementStateStore).addDagNodeState(any(), 
any()); // the real addDagNodeState is never invoked by these tests
+  }
+
+  @Test
+  public void killDag() throws IOException, URISyntaxException, 
InterruptedException {
+    long flowExecutionId = System.currentTimeMillis();
+    Dag<JobExecutionPlan> dag = DagManagerTest.buildDag("1", flowExecutionId, 
DagManager.FailureOption.FINISH_ALL_POSSIBLE.name(),
+        5, "user5", 
ConfigFactory.empty().withValue(ConfigurationKeys.FLOW_GROUP_KEY, 
ConfigValueFactory.fromAnyRef("fg")));
+    FlowCompilationValidationHelper flowCompilationValidationHelper = 
mock(FlowCompilationValidationHelper.class);
+    doReturn(Optional.of(dag)).when(dagManagementStateStore).getDag(any());
+    
doReturn(com.google.common.base.Optional.of(dag)).when(flowCompilationValidationHelper).createExecutionPlanIfValid(any());
+
+    LaunchDagProc launchDagProc = new LaunchDagProc(new LaunchDagTask(new 
DagActionStore.DagAction("fg", "flow1",
+        String.valueOf(flowExecutionId), 
MysqlDagActionStore.NO_JOB_NAME_DEFAULT, DagActionStore.DagActionType.LAUNCH),
+        null, mock(DagActionStore.class)), flowCompilationValidationHelper);
+    launchDagProc.process(this.dagManagementStateStore);
+
+    List<SpecProducer<Spec>> specProducers = dag.getNodes().stream().map(n -> {
+      try {
+        return DagManagerUtils.getSpecProducer(n);
+      } catch (ExecutionException | InterruptedException e) {
+        throw new RuntimeException(e);
+      }
+    }).collect(Collectors.toList());
+
+    KillDagProc killDagProc = new KillDagProc(new KillDagTask(new 
DagActionStore.DagAction("fg", "flow1",
+       String.valueOf(flowExecutionId), 
MysqlDagActionStore.NO_JOB_NAME_DEFAULT, DagActionStore.DagActionType.KILL),
+        null, mock(DagActionStore.class)));
+    killDagProc.process(this.dagManagementStateStore);
+
+    long cancelJobCount = specProducers.stream()
+        .mapToLong(p -> Mockito.mockingDetails(p)
+            .getInvocations()
+            .stream()
+            .filter(a -> a.getMethod().getName().equals("cancelJob"))
+            .count())
+        .sum();
+    // kill dag proc tries to cancel all the dag nodes
+    Assert.assertEquals(cancelJobCount, 5); // NOTE(review): presumably 5 == the job count passed to buildDag above - confirm
+  }
+
+  @Test
+  public void killDagNode() throws IOException, URISyntaxException, 
InterruptedException {
+    long flowExecutionId = System.currentTimeMillis();
+    Dag<JobExecutionPlan> dag = DagManagerTest.buildDag("2", flowExecutionId, 
DagManager.FailureOption.FINISH_ALL_POSSIBLE.name(),
+        5, "user5", 
ConfigFactory.empty().withValue(ConfigurationKeys.FLOW_GROUP_KEY, 
ConfigValueFactory.fromAnyRef("fg")));
+    FlowCompilationValidationHelper flowCompilationValidationHelper = 
mock(FlowCompilationValidationHelper.class);
+    JobStatus
+        jobStatus = 
JobStatus.builder().flowName("job0").flowGroup("fg").jobGroup("fg").jobName("job0").flowExecutionId(flowExecutionId).
+        message("Test 
message").eventName(ExecutionStatus.COMPLETE.name()).startTime(flowExecutionId).shouldRetry(false).orchestratedTime(flowExecutionId).build();
+
+    doReturn(Optional.of(dag)).when(dagManagementStateStore).getDag(any());
+    doReturn(new ImmutablePair<>(Optional.of(dag.getStartNodes().get(0)), 
Optional.of(jobStatus))).when(dagManagementStateStore).getDagNodeWithJobStatus(any()); // NOTE(review): returns the first start node for any id, so this test cannot tell whether the node named job2 is the one cancelled
+    
doReturn(com.google.common.base.Optional.of(dag)).when(flowCompilationValidationHelper).createExecutionPlanIfValid(any());
+
+    LaunchDagProc launchDagProc = new LaunchDagProc(new LaunchDagTask(new 
DagActionStore.DagAction("fg", "flow2",
+        String.valueOf(flowExecutionId), 
MysqlDagActionStore.NO_JOB_NAME_DEFAULT, DagActionStore.DagActionType.LAUNCH),
+        null, mock(DagActionStore.class)), flowCompilationValidationHelper);
+    launchDagProc.process(this.dagManagementStateStore);
+
+    List<SpecProducer<Spec>> specProducers = dag.getNodes().stream().map(n -> {
+      try {
+        return DagManagerUtils.getSpecProducer(n);
+      } catch (ExecutionException | InterruptedException e) {
+        throw new RuntimeException(e);
+      }
+    }).collect(Collectors.toList());
+
+    KillDagProc killDagProc = new KillDagProc(new KillDagTask(new 
DagActionStore.DagAction("fg", "flow2",
+        String.valueOf(flowExecutionId), "job2", 
DagActionStore.DagActionType.KILL),
+        null, mock(DagActionStore.class)));
+    killDagProc.process(this.dagManagementStateStore);
+
+    long cancelJobCount = specProducers.stream()
+        .mapToLong(p -> Mockito.mockingDetails(p)
+            .getInvocations()
+            .stream()
+            .filter(a -> a.getMethod().getName().equals("cancelJob"))
+            .count())
+        .sum();
+    // kill dag proc tries to cancel only the exact dag node that was provided
+    Assert.assertEquals(cancelJobCount, 1); // counts calls across all producers; it does not check which node was cancelled
+  }
+}

Reply via email to