bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432165558



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new 
SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new 
SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new 
IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new 
IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new 
IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new 
IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new 
IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new 
IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = 
IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = 
IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, 
SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", 
new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = 
JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new 
scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, 
EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = 
Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = 
TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch 
processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, 
MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) 
{
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator 
coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator 
coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, 
offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, 
offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new 
MetricsRegistryMap());
+    
when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new 
MetricsRegistryMap());
+    
when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task0.taskName()).thenReturn(taskName1);

Review comment:
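       Good catch: the line above stubs `task0.taskName()` a second time 
(with `taskName1`) instead of stubbing `task1.taskName()`. I should use 
`getMockRunLoopTask` here instead of creating the mock inline.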




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to