This is an automated email from the ASF dual-hosted git repository.

suvasude pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/gobblin.git


The following commit(s) were added to refs/heads/master by this push:
     new f3f8a36  [GOBBLIN-1384] Fix task cancellation to ensure task commit is invoked only after task completes[]
f3f8a36 is described below

commit f3f8a36559c7118f55c69f70f35294ba4a90a03d
Author: suvasude <[email protected]>
AuthorDate: Thu Feb 18 14:04:27 2021 -0800

    [GOBBLIN-1384] Fix task cancellation to ensure task commit is invoked only after task completes[]
    
    Closes #3224 from sv2000/taskCancelCommit
---
 .../main/java/org/apache/gobblin/runtime/Task.java |  14 +-
 .../runtime/GobblinMultiTaskAttemptTest.java       |   9 +-
 .../java/org/apache/gobblin/runtime/TaskTest.java  | 240 +++++++++++++++++----
 3 files changed, 205 insertions(+), 58 deletions(-)

diff --git a/gobblin-runtime/src/main/java/org/apache/gobblin/runtime/Task.java b/gobblin-runtime/src/main/java/org/apache/gobblin/runtime/Task.java
index e1fbc30..6b0e822 100644
--- a/gobblin-runtime/src/main/java/org/apache/gobblin/runtime/Task.java
+++ b/gobblin-runtime/src/main/java/org/apache/gobblin/runtime/Task.java
@@ -397,13 +397,9 @@ public class Task implements TaskIFace {
       failTask(t);
     } finally {
       synchronized (this) {
-        if (this.taskFuture == null || !this.taskFuture.isCancelled()) {
-          this.taskStateTracker.onTaskRunCompletion(this);
-          completeShutdown();
-          this.taskFuture = null;
-        } else {
-          LOG.info("will not decrease count down latch as this task is cancelled");
-        }
+        this.taskStateTracker.onTaskRunCompletion(this);
+        completeShutdown();
+        this.taskFuture = null;
       }
     }
   }
@@ -564,7 +560,7 @@ public class Task implements TaskIFace {
     this.lastRecordPulledTimestampMillis = System.currentTimeMillis();
   }
 
-  private void failTask(Throwable t) {
+  protected void failTask(Throwable t) {
     LOG.error(String.format("Task %s failed", this.taskId), t);
     this.taskState.setWorkingState(WorkUnitState.WorkingState.FAILED);
     this.taskState.setProp(ConfigurationKeys.TASK_FAILURE_EXCEPTION_KEY, Throwables.getStackTraceAsString(t));
@@ -1039,8 +1035,6 @@ public class Task implements TaskIFace {
   public synchronized boolean cancel() {
     LOG.info("Calling task cancel with interrupt flag: {}", this.shouldInterruptTaskOnCancel);
     if (this.taskFuture != null && this.taskFuture.cancel(this.shouldInterruptTaskOnCancel)) {
-      this.taskStateTracker.onTaskRunCompletion(this);
-      this.completeShutdown();
       return true;
     } else {
       return false;
diff --git a/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/GobblinMultiTaskAttemptTest.java b/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/GobblinMultiTaskAttemptTest.java
index b658801..79ed7f4 100644
--- a/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/GobblinMultiTaskAttemptTest.java
+++ b/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/GobblinMultiTaskAttemptTest.java
@@ -17,7 +17,6 @@
 
 package org.apache.gobblin.runtime;
 
-import java.io.IOException;
 import java.util.List;
 import java.util.Properties;
 
@@ -95,7 +94,7 @@ public class GobblinMultiTaskAttemptTest {
   @Test
   public void testRunWithTaskStatsTrackerNotScheduledFailure()
       throws Exception {
-    TaskStateTracker stateTracker = new FailingTestStateTracker(new Properties(), log);
+    TaskStateTracker stateTracker = new DummyTestStateTracker(new Properties(), log);
     // Preparing Instance of TaskAttempt with designed failure on task creation
     WorkUnit tmpWU = new WorkUnit();
     // Put necessary attributes in workunit
@@ -122,8 +121,8 @@ public class GobblinMultiTaskAttemptTest {
     Assert.fail();
   }
 
-  public static class FailingTestStateTracker extends AbstractTaskStateTracker {
-    public FailingTestStateTracker(Properties properties, Logger logger) {
+  public static class DummyTestStateTracker extends AbstractTaskStateTracker {
+    public DummyTestStateTracker(Properties properties, Logger logger) {
       super(properties, logger);
     }
 
@@ -134,7 +133,7 @@ public class GobblinMultiTaskAttemptTest {
 
     @Override
     public void onTaskRunCompletion(Task task) {
-
+      task.markTaskCompletion();
     }
 
     @Override
diff --git a/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/TaskTest.java b/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/TaskTest.java
index 60f368c..8d91f8f 100644
--- a/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/TaskTest.java
+++ b/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/TaskTest.java
@@ -17,13 +17,6 @@
 
 package org.apache.gobblin.runtime;
 
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyInt;
-import static org.mockito.Mockito.doNothing;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.when;
-
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -34,9 +27,11 @@ import java.util.Properties;
 import java.util.Random;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicBoolean;
 
-import org.apache.gobblin.runtime.util.TaskMetrics;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.testng.Assert;
@@ -59,13 +54,21 @@ import org.apache.gobblin.qualitychecker.row.RowLevelPolicyCheckResults;
 import org.apache.gobblin.qualitychecker.row.RowLevelPolicyChecker;
 import org.apache.gobblin.qualitychecker.task.TaskLevelPolicyCheckResults;
 import org.apache.gobblin.qualitychecker.task.TaskLevelPolicyChecker;
-import org.apache.gobblin.source.extractor.DataRecordException;
+import org.apache.gobblin.runtime.util.TaskMetrics;
 import org.apache.gobblin.source.extractor.Extractor;
 import org.apache.gobblin.source.workunit.Extract;
 import org.apache.gobblin.source.workunit.WorkUnit;
+import org.apache.gobblin.testing.AssertWithBackoff;
 import org.apache.gobblin.writer.DataWriter;
 import org.apache.gobblin.writer.DataWriterBuilder;
 
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
 
 /**
  * Integration tests for {@link Task}.
@@ -292,8 +295,6 @@ public class TaskTest {
 
   }
 
-
-
   private ArrayList<ArrayList<Object>> runTaskAndGetResults(TaskState 
taskState, int numRecords, int numForks,
       ForkOperator mockForkOperator)
       throws Exception {
@@ -399,6 +400,139 @@ public class TaskTest {
   }
 
   /**
+   * A test that calls {@link Task#cancel()} while {@link Task#run()} is 
executing. Ensures that the countdown latch
+   * is decremented and TaskState is set to FAILED.
+   * @throws Exception
+   */
+  @Test
+  public void testTaskCancelBeforeCompletion()
+      throws Exception {
+    // Create a TaskState
+    TaskState taskState = getEmptyTestTaskState("testCancelBeforeCompletion");
+
+    int numRecords = -1;
+    int numForks = 1;
+    ForkOperator mockForkOperator = new RoundRobinForkOperator(numForks);
+
+    ArrayList<ArrayList<Object>> recordCollectors = new ArrayList<>(numForks);
+    for (int i=0; i < numForks; ++i) {
+      recordCollectors.add(new ArrayList<>());
+    }
+
+    TaskContext mockTaskContext = getMockTaskContext(taskState,
+        new StringExtractor(numRecords), recordCollectors, mockForkOperator);
+
+    // Create a dummy TaskStateTracker
+    TaskStateTracker dummyTaskStateTracker = new 
GobblinMultiTaskAttemptTest.DummyTestStateTracker(new Properties(), log);
+
+    // Create a TaskExecutor - a real TaskExecutor must be created so a Fork 
is run in a separate thread
+    TaskExecutor taskExecutor = new TaskExecutor(new Properties());
+
+    CountUpAndDownLatch countDownLatch = new CountUpAndDownLatch(0);
+    // Create the Task
+    Task task = new DelayedFailureTask(mockTaskContext, dummyTaskStateTracker, 
taskExecutor, Optional.of(countDownLatch));
+    //Increment the countDownLatch to signal a new task creation.
+    countDownLatch.countUp();
+
+    ExecutorService executorService = Executors.newSingleThreadExecutor();
+    Future taskFuture = executorService.submit(new Thread(() -> task.run()));
+    task.setTaskFuture(taskFuture);
+
+    //Wait for task to enter RUNNING state
+    AssertWithBackoff.create().maxSleepMs(10).timeoutMs(1000).backoffFactor(1)
+        .assertTrue(input -> task.getWorkingState() == 
WorkUnitState.WorkingState.RUNNING,
+            "Waiting for task to enter RUNNING state");
+
+    Assert.assertEquals(countDownLatch.getCount(), 1);
+
+    task.shutdown();
+
+    //Ensure task is still RUNNING, since shutdown() is a NO-OP and the 
extractor should continue.
+    Assert.assertEquals(countDownLatch.getCount(), 1);
+    Assert.assertEquals(taskState.getWorkingState(), 
WorkUnitState.WorkingState.RUNNING);
+
+    //Call task cancel
+    task.cancel();
+
+    //Ensure task is still RUNNING immediately after cancel() due to the delay 
introduced in task failure handling.
+    Assert.assertEquals(countDownLatch.getCount(), 1);
+    Assert.assertEquals(taskState.getWorkingState(), 
WorkUnitState.WorkingState.RUNNING);
+
+    //Ensure countDownLatch is eventually counted down to 0
+    AssertWithBackoff.create().maxSleepMs(100).timeoutMs(5000).backoffFactor(1)
+        .assertTrue(input -> countDownLatch.getCount() == 0, "Waiting for the 
task to complete.");
+
+    //Ensure the TaskState is set to FAILED
+    Assert.assertEquals(taskState.getWorkingState(), 
WorkUnitState.WorkingState.FAILED);
+  }
+
+  /**
+   * A test that calls {@link Task#cancel()} after {@link Task#run()} is 
completed. In this case the cancel() method should
+   * be a NO-OP and should leave the task state unchanged.
+   * @throws Exception
+   */
+  @Test
+  public void testTaskCancelAfterCompletion()
+      throws Exception {
+    // Create a TaskState
+    TaskState taskState = getEmptyTestTaskState("testCancelAfterCompletion");
+
+    int numRecords = -1;
+    int numForks = 1;
+    ForkOperator mockForkOperator = new RoundRobinForkOperator(numForks);
+
+    ArrayList<ArrayList<Object>> recordCollectors = new ArrayList<>(numForks);
+    for (int i=0; i < numForks; ++i) {
+      recordCollectors.add(new ArrayList<>());
+    }
+
+    TaskContext mockTaskContext = getMockTaskContext(taskState,
+        new StringExtractor(numRecords, false), recordCollectors, 
mockForkOperator);
+
+    // Create a dummy TaskStateTracker
+    TaskStateTracker dummyTaskStateTracker = new 
GobblinMultiTaskAttemptTest.DummyTestStateTracker(new Properties(), log);
+
+    // Create a TaskExecutor - a real TaskExecutor must be created so a Fork 
is run in a separate thread
+    TaskExecutor taskExecutor = new TaskExecutor(new Properties());
+
+    CountUpAndDownLatch countDownLatch = new CountUpAndDownLatch(0);
+    // Create the Task
+    Task task = new Task(mockTaskContext, dummyTaskStateTracker, taskExecutor, 
Optional.of(countDownLatch));
+    //Increment the countDownLatch to signal a new task creation.
+    countDownLatch.countUp();
+
+    ExecutorService executorService = Executors.newSingleThreadExecutor();
+    Future taskFuture = executorService.submit(new Thread(() -> task.run()));
+    task.setTaskFuture(taskFuture);
+
+    //Wait for task to enter RUNNING state
+    AssertWithBackoff.create().maxSleepMs(10).timeoutMs(1000).backoffFactor(1)
+        .assertTrue(input -> task.getWorkingState() == 
WorkUnitState.WorkingState.RUNNING,
+            "Waiting for task to enter RUNNING state");
+
+    Assert.assertEquals(countDownLatch.getCount(), 1);
+
+    task.shutdown();
+
+    //Ensure countDownLatch is counted down to 0 i.e. task is done.
+    AssertWithBackoff.create().maxSleepMs(100).timeoutMs(5000).backoffFactor(1)
+        .assertTrue(input -> countDownLatch.getCount() == 0, "Waiting for the 
task to complete.");
+
+    //Ensure the TaskState is RUNNING
+    Assert.assertEquals(taskState.getWorkingState(), 
WorkUnitState.WorkingState.RUNNING);
+
+    //Call task cancel
+    task.cancel();
+
+    //Ensure the TaskState is unchanged on cancel()
+    Assert.assertEquals(taskState.getWorkingState(), 
WorkUnitState.WorkingState.RUNNING);
+
+    //Ensure task state is successful on commit()
+    task.commit();
+    Assert.assertEquals(taskState.getWorkingState(), 
WorkUnitState.WorkingState.SUCCESSFUL);
+  }
+
+  /**
    * An implementation of {@link Extractor} that throws an {@link IOException} 
during the invocation of
    * {@link #readRecord(Object)}.
    */
@@ -407,12 +541,12 @@ public class TaskTest {
     private final AtomicBoolean HAS_FAILED = new AtomicBoolean();
 
     @Override
-    public Object getSchema() throws IOException {
+    public Object getSchema() {
       return null;
     }
 
     @Override
-    public Object readRecord(@Deprecated Object reuse) throws 
DataRecordException, IOException {
+    public Object readRecord(@Deprecated Object reuse) throws IOException {
       if (!HAS_FAILED.get()) {
         HAS_FAILED.set(true);
         throw new IOException("Injected failure");
@@ -431,31 +565,36 @@ public class TaskTest {
     }
 
     @Override
-    public void close() throws IOException {
+    public void close() {
       // Do nothing
     }
   }
 
-
   private static class StringExtractor implements Extractor<Object, String> {
-
+    //Num records to extract. If set to -1, it is treated as an unbounded 
extractor.
     private final int _numRecords;
     private int _currentRecord;
+    private boolean _shouldIgnoreShutdown = true;
+    private AtomicBoolean _shutdownRequested = new AtomicBoolean(false);
+
     public StringExtractor(int numRecords) {
+      this(numRecords, true);
+    }
+
+    public StringExtractor(int numRecords, boolean shouldIgnoreShutdown) {
       _numRecords = numRecords;
       _currentRecord = -1;
+      _shouldIgnoreShutdown = shouldIgnoreShutdown;
     }
 
     @Override
-    public Object getSchema()
-        throws IOException {
+    public Object getSchema() {
       return "";
     }
 
     @Override
-    public String readRecord(@Deprecated String reuse)
-        throws DataRecordException, IOException {
-      if (_currentRecord < _numRecords-1) {
+    public String readRecord(@Deprecated String reuse) {
+      if (!_shutdownRequested.get() && (_numRecords == -1 || _currentRecord < 
_numRecords-1)) {
         _currentRecord++;
         return "" + _currentRecord;
       } else {
@@ -474,9 +613,14 @@ public class TaskTest {
     }
 
     @Override
-    public void close()
-        throws IOException {
+    public void close() {
+    }
 
+    @Override
+    public void shutdown() {
+      if (!this._shouldIgnoreShutdown) {
+        this._shutdownRequested.set(true);
+      }
     }
   }
 
@@ -499,8 +643,7 @@ public class TaskTest {
     }
 
     @Override
-    public void init(WorkUnitState workUnitState)
-        throws Exception {
+    public void init(WorkUnitState workUnitState) {
     }
 
     @Override
@@ -522,8 +665,7 @@ public class TaskTest {
     }
 
     @Override
-    public void close()
-        throws IOException {
+    public void close() {
 
     }
   }
@@ -551,8 +693,7 @@ public class TaskTest {
     }
 
     @Override
-    public void init(WorkUnitState workUnitState)
-        throws Exception {
+    public void init(WorkUnitState workUnitState) {
     }
 
     @Override
@@ -585,8 +726,7 @@ public class TaskTest {
     }
 
     @Override
-    public void close()
-        throws IOException {
+    public void close() {
 
     }
   }
@@ -600,24 +740,20 @@ public class TaskTest {
     }
 
     @Override
-    public DataWriter build()
-        throws IOException {
+    public DataWriter build() {
       return new DataWriter() {
         @Override
-        public void write(Object record)
-            throws IOException {
+        public void write(Object record) {
           _recordSink.add(record);
         }
 
         @Override
-        public void commit()
-            throws IOException {
+        public void commit() {
 
         }
 
         @Override
-        public void cleanup()
-            throws IOException {
+        public void cleanup() {
 
         }
 
@@ -627,17 +763,35 @@ public class TaskTest {
         }
 
         @Override
-        public long bytesWritten()
-            throws IOException {
+        public long bytesWritten() {
           return -1;
         }
 
         @Override
-        public void close()
-            throws IOException {
+        public void close() {
 
         }
       };
     }
   }
+
+  /**
+   * An extension of {@link Task} that introduces a fixed delay on 
encountering an exception.
+   */
+  private static class DelayedFailureTask extends Task {
+    public DelayedFailureTask(TaskContext context, TaskStateTracker 
taskStateTracker, TaskExecutor taskExecutor,
+        Optional<CountDownLatch> countDownLatch) {
+      super(context, taskStateTracker, taskExecutor, countDownLatch);
+    }
+
+    @Override
+    protected void failTask(Throwable t) {
+      try {
+        Thread.sleep(1000);
+        super.failTask(t);
+      } catch (InterruptedException e) {
+        log.error("Encountered exception: {}", e);
+      }
+    }
+  }
 }

Reply via email to