This is an automated email from the ASF dual-hosted git repository.
suvasude pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/gobblin.git
The following commit(s) were added to refs/heads/master by this push:
new f3f8a36 [GOBBLIN-1384] Fix task cancellation to ensure task commit is
invoked only after task completes[]
f3f8a36 is described below
commit f3f8a36559c7118f55c69f70f35294ba4a90a03d
Author: suvasude <[email protected]>
AuthorDate: Thu Feb 18 14:04:27 2021 -0800
[GOBBLIN-1384] Fix task cancellation to ensure task commit is invoked only
after task completes[]
Closes #3224 from sv2000/taskCancelCommit
---
.../main/java/org/apache/gobblin/runtime/Task.java | 14 +-
.../runtime/GobblinMultiTaskAttemptTest.java | 9 +-
.../java/org/apache/gobblin/runtime/TaskTest.java | 240 +++++++++++++++++----
3 files changed, 205 insertions(+), 58 deletions(-)
diff --git a/gobblin-runtime/src/main/java/org/apache/gobblin/runtime/Task.java b/gobblin-runtime/src/main/java/org/apache/gobblin/runtime/Task.java
index e1fbc30..6b0e822 100644
--- a/gobblin-runtime/src/main/java/org/apache/gobblin/runtime/Task.java
+++ b/gobblin-runtime/src/main/java/org/apache/gobblin/runtime/Task.java
@@ -397,13 +397,9 @@ public class Task implements TaskIFace {
failTask(t);
} finally {
synchronized (this) {
- if (this.taskFuture == null || !this.taskFuture.isCancelled()) {
- this.taskStateTracker.onTaskRunCompletion(this);
- completeShutdown();
- this.taskFuture = null;
- } else {
- LOG.info("will not decrease count down latch as this task is cancelled");
- }
+ this.taskStateTracker.onTaskRunCompletion(this);
+ completeShutdown();
+ this.taskFuture = null;
}
}
}
@@ -564,7 +560,7 @@ public class Task implements TaskIFace {
this.lastRecordPulledTimestampMillis = System.currentTimeMillis();
}
- private void failTask(Throwable t) {
+ protected void failTask(Throwable t) {
LOG.error(String.format("Task %s failed", this.taskId), t);
this.taskState.setWorkingState(WorkUnitState.WorkingState.FAILED);
this.taskState.setProp(ConfigurationKeys.TASK_FAILURE_EXCEPTION_KEY,
Throwables.getStackTraceAsString(t));
@@ -1039,8 +1035,6 @@ public class Task implements TaskIFace {
public synchronized boolean cancel() {
LOG.info("Calling task cancel with interrupt flag: {}",
this.shouldInterruptTaskOnCancel);
if (this.taskFuture != null &&
this.taskFuture.cancel(this.shouldInterruptTaskOnCancel)) {
- this.taskStateTracker.onTaskRunCompletion(this);
- this.completeShutdown();
return true;
} else {
return false;
diff --git a/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/GobblinMultiTaskAttemptTest.java b/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/GobblinMultiTaskAttemptTest.java
index b658801..79ed7f4 100644
--- a/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/GobblinMultiTaskAttemptTest.java
+++ b/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/GobblinMultiTaskAttemptTest.java
@@ -17,7 +17,6 @@
package org.apache.gobblin.runtime;
-import java.io.IOException;
import java.util.List;
import java.util.Properties;
@@ -95,7 +94,7 @@ public class GobblinMultiTaskAttemptTest {
@Test
public void testRunWithTaskStatsTrackerNotScheduledFailure()
throws Exception {
- TaskStateTracker stateTracker = new FailingTestStateTracker(new Properties(), log);
+ TaskStateTracker stateTracker = new DummyTestStateTracker(new Properties(), log);
// Preparing Instance of TaskAttempt with designed failure on task creation
WorkUnit tmpWU = new WorkUnit();
// Put necessary attributes in workunit
@@ -122,8 +121,8 @@ public class GobblinMultiTaskAttemptTest {
Assert.fail();
}
- public static class FailingTestStateTracker extends AbstractTaskStateTracker {
- public FailingTestStateTracker(Properties properties, Logger logger) {
+ public static class DummyTestStateTracker extends AbstractTaskStateTracker {
+ public DummyTestStateTracker(Properties properties, Logger logger) {
super(properties, logger);
}
@@ -134,7 +133,7 @@ public class GobblinMultiTaskAttemptTest {
@Override
public void onTaskRunCompletion(Task task) {
-
+ task.markTaskCompletion();
}
@Override
diff --git a/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/TaskTest.java b/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/TaskTest.java
index 60f368c..8d91f8f 100644
--- a/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/TaskTest.java
+++ b/gobblin-runtime/src/test/java/org/apache/gobblin/runtime/TaskTest.java
@@ -17,13 +17,6 @@
package org.apache.gobblin.runtime;
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyInt;
-import static org.mockito.Mockito.doNothing;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.when;
-
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@@ -34,9 +27,11 @@ import java.util.Properties;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
-import org.apache.gobblin.runtime.util.TaskMetrics;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.testng.Assert;
@@ -59,13 +54,21 @@ import org.apache.gobblin.qualitychecker.row.RowLevelPolicyCheckResults;
import org.apache.gobblin.qualitychecker.row.RowLevelPolicyChecker;
import org.apache.gobblin.qualitychecker.task.TaskLevelPolicyCheckResults;
import org.apache.gobblin.qualitychecker.task.TaskLevelPolicyChecker;
-import org.apache.gobblin.source.extractor.DataRecordException;
+import org.apache.gobblin.runtime.util.TaskMetrics;
import org.apache.gobblin.source.extractor.Extractor;
import org.apache.gobblin.source.workunit.Extract;
import org.apache.gobblin.source.workunit.WorkUnit;
+import org.apache.gobblin.testing.AssertWithBackoff;
import org.apache.gobblin.writer.DataWriter;
import org.apache.gobblin.writer.DataWriterBuilder;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
/**
* Integration tests for {@link Task}.
@@ -292,8 +295,6 @@ public class TaskTest {
}
-
-
private ArrayList<ArrayList<Object>> runTaskAndGetResults(TaskState taskState, int numRecords, int numForks,
ForkOperator mockForkOperator)
throws Exception {
@@ -399,6 +400,139 @@ public class TaskTest {
}
/**
+ * A test that calls {@link Task#cancel()} while {@link Task#run()} is executing. Ensures that the countdown latch
+ * is decremented and TaskState is set to FAILED.
+ * @throws Exception
+ */
+ @Test
+ public void testTaskCancelBeforeCompletion()
+ throws Exception {
+ // Create a TaskState
+ TaskState taskState = getEmptyTestTaskState("testCancelBeforeCompletion");
+
+ int numRecords = -1;
+ int numForks = 1;
+ ForkOperator mockForkOperator = new RoundRobinForkOperator(numForks);
+
+ ArrayList<ArrayList<Object>> recordCollectors = new ArrayList<>(numForks);
+ for (int i=0; i < numForks; ++i) {
+ recordCollectors.add(new ArrayList<>());
+ }
+
+ TaskContext mockTaskContext = getMockTaskContext(taskState,
+ new StringExtractor(numRecords), recordCollectors, mockForkOperator);
+
+ // Create a dummy TaskStateTracker
+ TaskStateTracker dummyTaskStateTracker = new GobblinMultiTaskAttemptTest.DummyTestStateTracker(new Properties(), log);
+
+ // Create a TaskExecutor - a real TaskExecutor must be created so a Fork is run in a separate thread
+ TaskExecutor taskExecutor = new TaskExecutor(new Properties());
+
+ CountUpAndDownLatch countDownLatch = new CountUpAndDownLatch(0);
+ // Create the Task
+ Task task = new DelayedFailureTask(mockTaskContext, dummyTaskStateTracker, taskExecutor, Optional.of(countDownLatch));
+ //Increment the countDownLatch to signal a new task creation.
+ countDownLatch.countUp();
+
+ ExecutorService executorService = Executors.newSingleThreadExecutor();
+ Future taskFuture = executorService.submit(new Thread(() -> task.run()));
+ task.setTaskFuture(taskFuture);
+
+ //Wait for task to enter RUNNING state
+ AssertWithBackoff.create().maxSleepMs(10).timeoutMs(1000).backoffFactor(1)
+ .assertTrue(input -> task.getWorkingState() == WorkUnitState.WorkingState.RUNNING,
+ "Waiting for task to enter RUNNING state");
+
+ Assert.assertEquals(countDownLatch.getCount(), 1);
+
+ task.shutdown();
+
+ //Ensure task is still RUNNING, since shutdown() is a NO-OP and the extractor should continue.
+ Assert.assertEquals(countDownLatch.getCount(), 1);
+ Assert.assertEquals(taskState.getWorkingState(), WorkUnitState.WorkingState.RUNNING);
+
+ //Call task cancel
+ task.cancel();
+
+ //Ensure task is still RUNNING immediately after cancel() due to the delay introduced in task failure handling.
+ Assert.assertEquals(countDownLatch.getCount(), 1);
+ Assert.assertEquals(taskState.getWorkingState(), WorkUnitState.WorkingState.RUNNING);
+
+ //Ensure countDownLatch is eventually counted down to 0
+ AssertWithBackoff.create().maxSleepMs(100).timeoutMs(5000).backoffFactor(1)
+ .assertTrue(input -> countDownLatch.getCount() == 0, "Waiting for the task to complete.");
+
+ //Ensure the TaskState is set to FAILED
+ Assert.assertEquals(taskState.getWorkingState(), WorkUnitState.WorkingState.FAILED);
+ }
+
+ /**
+ * A test that calls {@link Task#cancel()} after {@link Task#run()} is completed. In this case the cancel() method should
+ * be a NO-OP and should leave the task state unchanged.
+ * @throws Exception
+ */
+ @Test
+ public void testTaskCancelAfterCompletion()
+ throws Exception {
+ // Create a TaskState
+ TaskState taskState = getEmptyTestTaskState("testCancelAfterCompletion");
+
+ int numRecords = -1;
+ int numForks = 1;
+ ForkOperator mockForkOperator = new RoundRobinForkOperator(numForks);
+
+ ArrayList<ArrayList<Object>> recordCollectors = new ArrayList<>(numForks);
+ for (int i=0; i < numForks; ++i) {
+ recordCollectors.add(new ArrayList<>());
+ }
+
+ TaskContext mockTaskContext = getMockTaskContext(taskState,
+ new StringExtractor(numRecords, false), recordCollectors, mockForkOperator);
+
+ // Create a dummy TaskStateTracker
+ TaskStateTracker dummyTaskStateTracker = new GobblinMultiTaskAttemptTest.DummyTestStateTracker(new Properties(), log);
+
+ // Create a TaskExecutor - a real TaskExecutor must be created so a Fork is run in a separate thread
+ TaskExecutor taskExecutor = new TaskExecutor(new Properties());
+
+ CountUpAndDownLatch countDownLatch = new CountUpAndDownLatch(0);
+ // Create the Task
+ Task task = new Task(mockTaskContext, dummyTaskStateTracker, taskExecutor, Optional.of(countDownLatch));
+ //Increment the countDownLatch to signal a new task creation.
+ countDownLatch.countUp();
+
+ ExecutorService executorService = Executors.newSingleThreadExecutor();
+ Future taskFuture = executorService.submit(new Thread(() -> task.run()));
+ task.setTaskFuture(taskFuture);
+
+ //Wait for task to enter RUNNING state
+ AssertWithBackoff.create().maxSleepMs(10).timeoutMs(1000).backoffFactor(1)
+ .assertTrue(input -> task.getWorkingState() == WorkUnitState.WorkingState.RUNNING,
+ "Waiting for task to enter RUNNING state");
+
+ Assert.assertEquals(countDownLatch.getCount(), 1);
+
+ task.shutdown();
+
+ //Ensure countDownLatch is counted down to 0 i.e. task is done.
+ AssertWithBackoff.create().maxSleepMs(100).timeoutMs(5000).backoffFactor(1)
+ .assertTrue(input -> countDownLatch.getCount() == 0, "Waiting for the task to complete.");
+
+ //Ensure the TaskState is RUNNING
+ Assert.assertEquals(taskState.getWorkingState(), WorkUnitState.WorkingState.RUNNING);
+
+ //Call task cancel
+ task.cancel();
+
+ //Ensure the TaskState is unchanged on cancel()
+ Assert.assertEquals(taskState.getWorkingState(), WorkUnitState.WorkingState.RUNNING);
+
+ //Ensure task state is successful on commit()
+ task.commit();
+ Assert.assertEquals(taskState.getWorkingState(), WorkUnitState.WorkingState.SUCCESSFUL);
+ }
+
+ /**
* An implementation of {@link Extractor} that throws an {@link IOException} during the invocation of
* {@link #readRecord(Object)}.
*/
@@ -407,12 +541,12 @@ public class TaskTest {
private final AtomicBoolean HAS_FAILED = new AtomicBoolean();
@Override
- public Object getSchema() throws IOException {
+ public Object getSchema() {
return null;
}
@Override
- public Object readRecord(@Deprecated Object reuse) throws DataRecordException, IOException {
+ public Object readRecord(@Deprecated Object reuse) throws IOException {
if (!HAS_FAILED.get()) {
HAS_FAILED.set(true);
throw new IOException("Injected failure");
@@ -431,31 +565,36 @@ public class TaskTest {
}
@Override
- public void close() throws IOException {
+ public void close() {
// Do nothing
}
}
-
private static class StringExtractor implements Extractor<Object, String> {
-
+ //Num records to extract. If set to -1, it is treated as an unbounded extractor.
private final int _numRecords;
private int _currentRecord;
+ private boolean _shouldIgnoreShutdown = true;
+ private AtomicBoolean _shutdownRequested = new AtomicBoolean(false);
+
public StringExtractor(int numRecords) {
+ this(numRecords, true);
+ }
+
+ public StringExtractor(int numRecords, boolean shouldIgnoreShutdown) {
_numRecords = numRecords;
_currentRecord = -1;
+ _shouldIgnoreShutdown = shouldIgnoreShutdown;
}
@Override
- public Object getSchema()
- throws IOException {
+ public Object getSchema() {
return "";
}
@Override
- public String readRecord(@Deprecated String reuse)
- throws DataRecordException, IOException {
- if (_currentRecord < _numRecords-1) {
+ public String readRecord(@Deprecated String reuse) {
+ if (!_shutdownRequested.get() && (_numRecords == -1 || _currentRecord < _numRecords-1)) {
_currentRecord++;
return "" + _currentRecord;
} else {
@@ -474,9 +613,14 @@ public class TaskTest {
}
@Override
- public void close()
- throws IOException {
+ public void close() {
+ }
+ @Override
+ public void shutdown() {
+ if (!this._shouldIgnoreShutdown) {
+ this._shutdownRequested.set(true);
+ }
}
}
@@ -499,8 +643,7 @@ public class TaskTest {
}
@Override
- public void init(WorkUnitState workUnitState)
- throws Exception {
+ public void init(WorkUnitState workUnitState) {
}
@Override
@@ -522,8 +665,7 @@ public class TaskTest {
}
@Override
- public void close()
- throws IOException {
+ public void close() {
}
}
@@ -551,8 +693,7 @@ public class TaskTest {
}
@Override
- public void init(WorkUnitState workUnitState)
- throws Exception {
+ public void init(WorkUnitState workUnitState) {
}
@Override
@@ -585,8 +726,7 @@ public class TaskTest {
}
@Override
- public void close()
- throws IOException {
+ public void close() {
}
}
@@ -600,24 +740,20 @@ public class TaskTest {
}
@Override
- public DataWriter build()
- throws IOException {
+ public DataWriter build() {
return new DataWriter() {
@Override
- public void write(Object record)
- throws IOException {
+ public void write(Object record) {
_recordSink.add(record);
}
@Override
- public void commit()
- throws IOException {
+ public void commit() {
}
@Override
- public void cleanup()
- throws IOException {
+ public void cleanup() {
}
@@ -627,17 +763,35 @@ public class TaskTest {
}
@Override
- public long bytesWritten()
- throws IOException {
+ public long bytesWritten() {
return -1;
}
@Override
- public void close()
- throws IOException {
+ public void close() {
}
};
}
}
+
+ /**
+ * An extension of {@link Task} that introduces a fixed delay on encountering an exception.
+ */
+ private static class DelayedFailureTask extends Task {
+ public DelayedFailureTask(TaskContext context, TaskStateTracker taskStateTracker, TaskExecutor taskExecutor,
+ Optional<CountDownLatch> countDownLatch) {
+ super(context, taskStateTracker, taskExecutor, countDownLatch);
+ }
+
+ @Override
+ protected void failTask(Throwable t) {
+ try {
+ Thread.sleep(1000);
+ super.failTask(t);
+ } catch (InterruptedException e) {
+ log.error("Encountered exception: {}", e);
+ }
+ }
+ }
}