bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432168140



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new 
SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new 
SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new 
IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new 
IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new 
IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new 
IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new 
IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new 
IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = 
IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = 
IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, 
SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", 
new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = 
JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new 
scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, 
EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = 
Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = 
TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch 
processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, 
MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) 
{
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator 
coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator 
coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new 
MetricsRegistryMap());
+    
when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new 
MetricsRegistryMap());
+    
when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task0.taskName()).thenReturn(taskName1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope10), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics taskMetrics = new TaskInstanceMetrics("test", new 
MetricsRegistryMap());
+    
when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(taskMetrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) 
callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, 
task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, 
task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+
+    InOrder inOrderOffsetManager = inOrder(offsetManager);
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), 
eq(envelope00.getOffset()));
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), 
eq(envelope01.getOffset()));
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = mock(RunLoopTask.class);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, 
ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        Thread.sleep(windowMs);
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, 
task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, 
task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, 
ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        
coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing 
and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), 
any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), 
any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, 
task0ProcessedMessagesLatch);
-    
task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, 
task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, 
ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+        
coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        
coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, 
() -> 0L, false);
     //have a null message in between to make sure task0 finishes processing 
and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), 
any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), 
any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, 
task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, 
task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, 
ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, 
ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope10), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, 
task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, 
task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope10)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope10), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() throws Exception {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, 
task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, 
task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, 
TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        return null;
+      }).when(task0).endOfStream(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 
maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, 
maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    
when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    
when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through 
run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());

Review comment:
      Good point. I think this assertion (the process-count check) should be 
moved into the `Answer` for the mocked `endOfStream` call: by the time 
`endOfStream` is invoked, the task should already have seen all messages. If 
we check later instead, the test could pass despite incorrect behavior, e.g. 
if `RunLoop` were to touch the process metric after `endOfStream` is called.
   
   What do you think?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to