cameronlee314 commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432179780



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new 
SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new 
SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new 
IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new 
IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new 
IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new 
IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new 
IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new 
IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = 
IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = 
IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, 
SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", 
new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = 
JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new 
scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, 
EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = 
Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = 
TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch 
processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, 
MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) 
{
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator 
coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator 
coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new 
MetricsRegistryMap());
+    
when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new 
MetricsRegistryMap());
+    
when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task0.taskName()).thenReturn(taskName1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope10), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics taskMetrics = new TaskInstanceMetrics("test", new 
MetricsRegistryMap());
+    
when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(taskMetrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) 
callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, 
task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, 
task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+
+    InOrder inOrderOffsetManager = inOrder(offsetManager);
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), 
eq(envelope00.getOffset()));
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), 
eq(envelope01.getOffset()));
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = mock(RunLoopTask.class);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, 
ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        Thread.sleep(windowMs);
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, 
task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, 
task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, 
ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        
coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing 
and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), 
any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), 
any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, 
task0ProcessedMessagesLatch);
-    
task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, 
task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, 
ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+        
coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        
coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, 
() -> 0L, false);
     //have a null message in between to make sure task0 finishes processing 
and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), 
any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), 
any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, 
task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, 
task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, 
ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, 
ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope10), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, 
task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, 
task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope10)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope10), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() throws Exception {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, 
task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, 
task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        return null;
+      }).when(task0).endOfStream(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through 
run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    verify(task0).endOfStream(any());
   }
 
   @Test
-  public void testEndOfStreamCommitBehavior() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamCommitBehavior() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    //explicitly configure to disable commits inside process or window calls 
and invoke commit from end of stream
-    TestTask task0 = new TestTask(true, false, false, 
task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, false, false, 
task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(0, 
ReadableCoordinator.class);
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        return null;
+      }).when(task0).endOfStream(any());
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(ssp0EndOfStream).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), 
any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), 
any(Checkpoint.class));
+    inOrder.verify(task0).endOfStream(any());
+    inOrder.verify(task0).commit();
   }
 
   @Test
-  public void testEndOfStreamOffsetManagement() throws Exception {
-    //explicitly configure to disable commits inside process or window calls 
and invoke commit from end of stream
-    TestTask mockStreamTask1 = new TestTask(true, false, false, null);
-    TestTask mockStreamTask2 = new TestTask(true, false, false, null);
-
-    Partition p1 = new Partition(1);
-    Partition p2 = new Partition(2);
-    SystemStreamPartition ssp1 = new SystemStreamPartition("system1", 
"stream1", p1);
-    SystemStreamPartition ssp2 = new SystemStreamPartition("system1", 
"stream2", p2);
-    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp2, "1", 
"key1", "message1");
-    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp2, "2", 
"key1", "message1");
-    IncomingMessageEnvelope envelope3 = 
IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp2);
-
-    Map<SystemStreamPartition, List<IncomingMessageEnvelope>> sspMap = new 
HashMap<>();
-    List<IncomingMessageEnvelope> messageList = new ArrayList<>();
-    messageList.add(envelope1);
-    messageList.add(envelope2);
-    messageList.add(envelope3);
-    sspMap.put(ssp2, messageList);
-
-    SystemConsumer mockConsumer = mock(SystemConsumer.class);
-    when(mockConsumer.poll(anyObject(), anyLong())).thenReturn(sspMap);
-
-    SystemAdmins systemAdmins = Mockito.mock(SystemAdmins.class);
-    
Mockito.when(systemAdmins.getSystemAdmin("system1")).thenReturn(Mockito.mock(SystemAdmin.class));
-    
Mockito.when(systemAdmins.getSystemAdmin("testSystem")).thenReturn(Mockito.mock(SystemAdmin.class));
-
-    HashMap<String, SystemConsumer> systemConsumerMap = new HashMap<>();
-    systemConsumerMap.put("system1", mockConsumer);
-
-    SystemConsumers consumers = 
TestSystemConsumers.getSystemConsumers(systemConsumerMap, systemAdmins);
-
-    TaskName taskName1 = new TaskName("task1");
-    TaskName taskName2 = new TaskName("task2");
-
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    when(offsetManager.getLastProcessedOffset(taskName1, 
ssp1)).thenReturn(Option.apply("3"));
-    when(offsetManager.getLastProcessedOffset(taskName2, 
ssp2)).thenReturn(Option.apply("0"));
-    when(offsetManager.getStartingOffset(taskName1, 
ssp1)).thenReturn(Option.apply(IncomingMessageEnvelope.END_OF_STREAM_OFFSET));
-    when(offsetManager.getStartingOffset(taskName2, 
ssp2)).thenReturn(Option.apply("1"));
-    when(offsetManager.getStartpoint(anyObject(), 
anyObject())).thenReturn(Option.empty());
-
-    TaskInstance taskInstance1 = createTaskInstance(mockStreamTask1, 
taskName1, ssp1, offsetManager, consumers);
-    TaskInstance taskInstance2 = createTaskInstance(mockStreamTask2, 
taskName2, ssp2, offsetManager, consumers);
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName1, taskInstance1);
-    tasks.put(taskName2, taskInstance2);
-
-    taskInstance1.registerConsumers();
-    taskInstance2.registerConsumers();
-    consumers.start();
-
-    int maxMessagesInFlight = 1;
-    RunLoop runLoop = new RunLoop(tasks, executor, consumers, 
maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-  }
-
-  //@Test
-  public void testCommitBehaviourWhenAsyncCommitIsEnabled() throws 
InterruptedException {
+  public void testCommitWithMessageInFlightWhenAsyncCommitIsEnabled() {
+    int maxMessagesInFlight = 2;
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(2);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    int maxMessagesInFlight = 3;
-    TestTask task0 = new TestTask(true, true, false, null, 
maxMessagesInFlight);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, false, null, 
maxMessagesInFlight);
-
-    IncomingMessageEnvelope firstMsg = new IncomingMessageEnvelope(ssp0, "0", 
"key0", "value0");
-    IncomingMessageEnvelope secondMsg = new IncomingMessageEnvelope(ssp0, "1", 
"key1", "value1");
-    IncomingMessageEnvelope thirdMsg = new IncomingMessageEnvelope(ssp0, "2", 
"key0", "value0");
-
-    final CountDownLatch firstMsgCompletionLatch = new CountDownLatch(1);
-    final CountDownLatch secondMsgCompletionLatch = new CountDownLatch(1);
-    task0.callbackHandler = callback -> {
-      IncomingMessageEnvelope envelope = ((TaskCallbackImpl) 
callback).getEnvelope();
-      try {
-        if (envelope.equals(firstMsg)) {
-          firstMsgCompletionLatch.await();
-        } else if (envelope.equals(secondMsg)) {
-          firstMsgCompletionLatch.countDown();
-          secondMsgCompletionLatch.await();
-        } else if (envelope.equals(thirdMsg)) {
-          secondMsgCompletionLatch.countDown();
-          // OffsetManager.update with firstMsg offset, task.commit has 
happened when second message callback has not completed.
-          verify(offsetManager).update(eq(taskName0), 
eq(firstMsg.getSystemStreamPartition()), eq(firstMsg.getOffset()));
-        }
-      } catch (Exception e) {
-        e.printStackTrace();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer));
-    
when(consumerMultiplexer.choose(false)).thenReturn(firstMsg).thenReturn(secondMsg).thenReturn(thirdMsg).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
-
-    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-
-    firstMsgCompletionLatch.await();
-    secondMsgCompletionLatch.await();
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, 
ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    CountDownLatch secondMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, 
ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            // let the first message proceed to ask for a commit
+            firstMessageBarrier.countDown();
+            // block this message until commit is executed
+            secondMessageBarrier.await();
+            coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    verify(offsetManager, atLeastOnce()).buildCheckpoint(eq(taskName0));
-    verify(offsetManager, atLeastOnce()).writeCheckpoint(eq(taskName0), 
any(Checkpoint.class));
-    assertEquals(3, task0.processed);
-    assertEquals(3, task0.committed);
-    assertEquals(1, task1.processed);
-    assertEquals(0, task1.committed);
-  }
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().asyncCallbackCompleted().getCount());
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-  @Test
-  public void testProcessBehaviourWhenAsyncCommitIsEnabled() throws 
InterruptedException {
-    int maxMessagesInFlight = 2;
+        secondMessageBarrier.countDown();
+        return null;
+      }).when(task0).commit();
 
-    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    TestTask task0 = new TestTask(true, true, false, null, 
maxMessagesInFlight);
-    CountDownLatch commitLatch = new CountDownLatch(1);
-    task0.commitHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope3)) {
-        try {
-          commitLatch.await();
-        } catch (InterruptedException e) {
-          e.printStackTrace();
-        }
-      }
-    };
-
-    task0.callbackHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope0)) {
-        // Both the process call has gone through when the first commit is in 
progress.
-        assertEquals(2, containerMetrics.processes().getCount());
-        assertEquals(0, containerMetrics.commits().getCount());
-        commitLatch.countDown();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer));
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope3).thenReturn(envelope0).thenReturn(ssp0EndOfStream).thenReturn(null);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics,
-                                            () -> 0L, true);
-
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, 
() -> 0L, true);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
     runLoop.run();
 
-    commitLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+    inOrder.verify(task0).commit();

Review comment:
      Oh, I see. I mixed up `executor` and `taskExecutor`. 
I thought the thread pool was running `RunLoopTask.process` and 
`RunLoopTask.commit`, but that's not the case.
  My comment doesn't apply, then.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to