Merging TaskFactory changes

Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/be6be8cb
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/be6be8cb
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/be6be8cb

Branch: refs/heads/samza-standalone
Commit: be6be8cba1a28a8395ca9c6700eb125093f6c517
Parents: 7d6332b
Author: navina <[email protected]>
Authored: Fri Dec 23 16:59:07 2016 -0800
Committer: navina <[email protected]>
Committed: Fri Dec 23 16:59:07 2016 -0800

----------------------------------------------------------------------
 .../samza/task/AsyncStreamTaskFactory.java      |  28 ++++
 .../apache/samza/task/StreamTaskFactory.java    |  28 ++++
 .../apache/samza/container/RunLoopFactory.java  |  20 ++-
 .../coordinator/JobCoordinatorFactory.java      |   3 +
 .../apache/samza/processor/StreamProcessor.java |  27 +++-
 .../org/apache/samza/task/AsyncRunLoop.java     |  37 ++---
 .../main/java/org/apache/samza/zk/ZkUtils.java  |   2 +-
 .../org/apache/samza/container/RunLoop.scala    |  18 +--
 .../apache/samza/container/SamzaContainer.scala | 137 ++++++++++++-------
 .../apache/samza/container/TaskInstance.scala   |  21 +--
 .../samza/job/local/ThreadJobFactory.scala      |  12 +-
 .../org/apache/samza/task/TestAsyncRunLoop.java |  20 +--
 .../apache/samza/container/TestRunLoop.scala    |  15 +-
 .../samza/container/TestSamzaContainer.scala    |   4 +-
 .../samza/container/TestTaskInstance.scala      |   6 +-
 15 files changed, 251 insertions(+), 127 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-api/src/main/java/org/apache/samza/task/AsyncStreamTaskFactory.java
----------------------------------------------------------------------
diff --git 
a/samza-api/src/main/java/org/apache/samza/task/AsyncStreamTaskFactory.java 
b/samza-api/src/main/java/org/apache/samza/task/AsyncStreamTaskFactory.java
new file mode 100644
index 0000000..e5ce9c4
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/task/AsyncStreamTaskFactory.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+/**
+ * Build {@link AsyncStreamTask} instances.
+ * Implementations should return a new instance for each {@link 
#createInstance()} invocation.
+ */
+public interface AsyncStreamTaskFactory {
+  AsyncStreamTask createInstance();
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-api/src/main/java/org/apache/samza/task/StreamTaskFactory.java
----------------------------------------------------------------------
diff --git 
a/samza-api/src/main/java/org/apache/samza/task/StreamTaskFactory.java 
b/samza-api/src/main/java/org/apache/samza/task/StreamTaskFactory.java
new file mode 100644
index 0000000..ec53bc0
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/task/StreamTaskFactory.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+/**
+ * Build {@link StreamTask} instances.
+ * Implementations should return a new instance for each {@link 
#createInstance()} invocation.
+ */
+public interface StreamTaskFactory {
+  StreamTask createInstance();
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java 
b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
index 23a68cb..32ab47a 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
@@ -19,20 +19,19 @@
 
 package org.apache.samza.container;
 
-import java.util.concurrent.ExecutorService;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.TaskConfig;
-import org.apache.samza.util.HighResolutionClock;
 import org.apache.samza.system.SystemConsumers;
 import org.apache.samza.task.AsyncRunLoop;
-import org.apache.samza.task.AsyncStreamTask;
-import org.apache.samza.task.StreamTask;
+import org.apache.samza.util.HighResolutionClock;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.collection.JavaConversions;
 import scala.runtime.AbstractFunction0;
 import scala.runtime.AbstractFunction1;
 
+import java.util.concurrent.ExecutorService;
+
 import static org.apache.samza.util.Util.asScalaClock;
 
 /**
@@ -46,7 +45,7 @@ public class RunLoopFactory {
   private static final long DEFAULT_COMMIT_MS = 60000L;
   private static final long DEFAULT_CALLBACK_TIMEOUT_MS = -1L;
 
-  public static Runnable 
createRunLoop(scala.collection.immutable.Map<TaskName, TaskInstance<?>> 
taskInstances,
+  public static Runnable 
createRunLoop(scala.collection.immutable.Map<TaskName, TaskInstance> 
taskInstances,
       SystemConsumers consumerMultiplexer,
       ExecutorService threadPool,
       long maxThrottlingDelayMs,
@@ -62,9 +61,9 @@ public class RunLoopFactory {
 
     log.info("Got commit milliseconds: " + taskCommitMs);
 
-    int asyncTaskCount = taskInstances.values().count(new 
AbstractFunction1<TaskInstance<?>, Object>() {
+    int asyncTaskCount = taskInstances.values().count(new 
AbstractFunction1<TaskInstance, Object>() {
       @Override
-      public Boolean apply(TaskInstance<?> t) {
+      public Boolean apply(TaskInstance t) {
         return t.isAsyncTask();
       }
     });
@@ -77,9 +76,8 @@ public class RunLoopFactory {
     if (asyncTaskCount == 0) {
       log.info("Run loop in single thread mode.");
 
-      scala.collection.immutable.Map<TaskName, TaskInstance<StreamTask>> 
streamTaskInstances = (scala.collection.immutable.Map) taskInstances;
       return new RunLoop(
-        streamTaskInstances,
+        taskInstances,
         consumerMultiplexer,
         containerMetrics,
         maxThrottlingDelayMs,
@@ -95,12 +93,10 @@ public class RunLoopFactory {
 
       log.info("Got callback timeout: " + callbackTimeout);
 
-      scala.collection.immutable.Map<TaskName, TaskInstance<AsyncStreamTask>> 
asyncStreamTaskInstances = (scala.collection.immutable.Map) taskInstances;
-
       log.info("Run loop in asynchronous mode.");
 
       return new AsyncRunLoop(
-        JavaConversions.mapAsJavaMap(asyncStreamTaskInstances),
+        JavaConversions.mapAsJavaMap(taskInstances),
         threadPool,
         consumerMultiplexer,
         taskMaxConcurrency,

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorFactory.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorFactory.java
 
b/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorFactory.java
index af2aaa7..e12e49f 100644
--- 
a/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorFactory.java
+++ 
b/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorFactory.java
@@ -23,7 +23,10 @@ import org.apache.samza.processor.SamzaContainerController;
 
 public interface JobCoordinatorFactory {
   /**
+   * @param processorId Unique identifier for the processor
    * @param config Configs relevant for the JobCoordinator TODO: Separate JC 
related configs into a "JobCoordinatorConfig"
+   * @param containerController Controller interface for starting and stopping 
container. In future, it may simply
+   *                            pause the container and add/remove tasks
    * @return An instance of IJobCoordinator
    */
   JobCoordinator getJobCoordinator(int processorId, Config config, 
SamzaContainerController containerController);

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java 
b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
index 0f34400..61795e1 100644
--- a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
+++ b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
@@ -25,6 +25,8 @@ import org.apache.samza.config.TaskConfigJava;
 import org.apache.samza.coordinator.JobCoordinator;
 import org.apache.samza.coordinator.JobCoordinatorFactory;
 import org.apache.samza.metrics.MetricsReporter;
+import org.apache.samza.task.AsyncStreamTaskFactory;
+import org.apache.samza.task.StreamTaskFactory;
 import org.apache.samza.util.Util;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -34,11 +36,11 @@ import java.util.Map;
 
 /**
  * StreamProcessor can be embedded in any application or executed in a 
distributed environment (aka cluster) as
- * independent processes <br />
+ * independent processes
  * <p>
  * <b>Usage Example:</b>
  * <pre>
- * StreamProcessor processor = new StreamProcessor(1, config); <br />
+ * StreamProcessor processor = new StreamProcessor(1, config);
  * processor.start();
  * try {
  *  boolean status = processor.awaitStart(TIMEOUT_MS);    // Optional - 
blocking call
@@ -74,7 +76,7 @@ public class StreamProcessor {
    * JobCoordinator controls how the various StreamProcessor instances 
belonging to a job coordinate. It is also
    * responsible generating and updating JobModel.
    * When StreamProcessor starts, it starts the JobCoordinator and brings up a 
SamzaContainer based on the JobModel.
-   * SamzaContainer is executed using an ExecutorService. <br />
+   * SamzaContainer is executed using an ExecutorService.
    * <p>
    * <b>Note:</b> Lifecycle of the ExecutorService is fully managed by the 
StreamProcessor, and NOT exposed to the user
    *
@@ -82,6 +84,25 @@ public class StreamProcessor {
    *                               "containerId" in Samza
    * @param config                 Instance of config object - contains all 
configuration required for processing
    * @param customMetricsReporters Map of custom MetricReporter instances that 
are to be injected in the Samza job
+   * @param asyncStreamTaskFactory The {@link AsyncStreamTaskFactory} to be 
used for creating task instances.
+   */
+  public StreamProcessor(int processorId, Config config, Map<String, 
MetricsReporter> customMetricsReporters,
+                         AsyncStreamTaskFactory asyncStreamTaskFactory) {
+    this(processorId, config, customMetricsReporters, (Object) 
asyncStreamTaskFactory);
+  }
+
+  /**
+   * Same as {@link #StreamProcessor(int, Config, Map, 
AsyncStreamTaskFactory)}, except task instances are created
+   * using the provided {@link StreamTaskFactory}.
+   */
+  public StreamProcessor(int processorId, Config config, Map<String, 
MetricsReporter> customMetricsReporters,
+                         StreamTaskFactory streamTaskFactory) {
+    this(processorId, config, customMetricsReporters, (Object) 
streamTaskFactory);
+  }
+
+  /**
+   * Same as {@link #StreamProcessor(int, Config, Map, 
AsyncStreamTaskFactory)}, except task instances are created
+   * using the "task.class" configuration instead of a task factory.
    */
   public StreamProcessor(int processorId, Config config, Map<String, 
MetricsReporter> customMetricsReporters) {
     this(processorId, config, customMetricsReporters, (Object) null);

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java 
b/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java
index 16f1563..ab5514c 100644
--- a/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java
+++ b/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java
@@ -75,7 +75,7 @@ public class AsyncRunLoop implements Runnable, Throttleable {
   private volatile Throwable throwable = null;
   private final HighResolutionClock clock;
 
-  public AsyncRunLoop(Map<TaskName, TaskInstance<AsyncStreamTask>> 
taskInstances,
+  public AsyncRunLoop(Map<TaskName, TaskInstance> taskInstances,
       ExecutorService threadPool,
       SystemConsumers consumerMultiplexer,
       int maxConcurrency,
@@ -100,7 +100,7 @@ public class AsyncRunLoop implements Runnable, Throttleable 
{
     this.workerTimer = Executors.newSingleThreadScheduledExecutor();
     this.clock = clock;
     Map<TaskName, AsyncTaskWorker> workers = new HashMap<>();
-    for (TaskInstance<AsyncStreamTask> task : taskInstances.values()) {
+    for (TaskInstance task : taskInstances.values()) {
       workers.put(task.taskName(), new AsyncTaskWorker(task));
     }
     // Partions and tasks assigned to the container will not change during the 
run loop life time
@@ -112,14 +112,12 @@ public class AsyncRunLoop implements Runnable, 
Throttleable {
    * Returns mapping of the SystemStreamPartition to the AsyncTaskWorkers to 
efficiently route the envelopes
    */
   private static Map<SystemStreamPartition, List<AsyncTaskWorker>> 
getSspToAsyncTaskWorkerMap(
-      Map<TaskName, TaskInstance<AsyncStreamTask>> taskInstances, 
Map<TaskName, AsyncTaskWorker> taskWorkers) {
+      Map<TaskName, TaskInstance> taskInstances, Map<TaskName, 
AsyncTaskWorker> taskWorkers) {
     Map<SystemStreamPartition, List<AsyncTaskWorker>> sspToWorkerMap = new 
HashMap<>();
-    for (TaskInstance<AsyncStreamTask> task : taskInstances.values()) {
+    for (TaskInstance task : taskInstances.values()) {
       Set<SystemStreamPartition> ssps = 
JavaConversions.setAsJavaSet(task.systemStreamPartitions());
       for (SystemStreamPartition ssp : ssps) {
-        if (sspToWorkerMap.get(ssp) == null) {
-          sspToWorkerMap.put(ssp, new ArrayList<AsyncTaskWorker>());
-        }
+        sspToWorkerMap.putIfAbsent(ssp, new ArrayList<>());
         sspToWorkerMap.get(ssp).add(taskWorkers.get(task.taskName()));
       }
     }
@@ -202,7 +200,8 @@ public class AsyncRunLoop implements Runnable, Throttleable 
{
   private IncomingMessageEnvelope chooseEnvelope() {
     IncomingMessageEnvelope envelope = consumerMultiplexer.choose(false);
     if (envelope != null) {
-      log.trace("Choose envelope ssp {} offset {} for processing", 
envelope.getSystemStreamPartition(), envelope.getOffset());
+      log.trace("Choose envelope ssp {} offset {} for processing",
+          envelope.getSystemStreamPartition(), envelope.getOffset());
       containerMetrics.envelopes().inc();
     } else {
       log.trace("No envelope is available");
@@ -310,12 +309,12 @@ public class AsyncRunLoop implements Runnable, 
Throttleable {
    * will run the task asynchronously. It runs window and commit in the 
provided thread pool.
    */
   private class AsyncTaskWorker implements TaskCallbackListener {
-    private final TaskInstance<AsyncStreamTask> task;
+    private final TaskInstance task;
     private final TaskCallbackManager callbackManager;
     private volatile AsyncTaskState state;
 
 
-    AsyncTaskWorker(TaskInstance<AsyncStreamTask> task) {
+    AsyncTaskWorker(TaskInstance task) {
       this.task = task;
       this.callbackManager = new TaskCallbackManager(this, callbackTimer, 
callbackTimeoutMs, maxConcurrency, clock);
       Set<SystemStreamPartition> sspSet = getWorkingSSPSet(task);
@@ -352,12 +351,14 @@ public class AsyncRunLoop implements Runnable, 
Throttleable {
      * @param task
      * @return a Set of SSPs such that all SSPs are not at end of stream.
      */
-    private Set<SystemStreamPartition> 
getWorkingSSPSet(TaskInstance<AsyncStreamTask> task) {
+    private Set<SystemStreamPartition> getWorkingSSPSet(TaskInstance task) {
 
       Set<SystemStreamPartition> allPartitions = new 
HashSet<>(JavaConversions.setAsJavaSet(task.systemStreamPartitions()));
 
       // filter only those SSPs that are not at end of stream.
-      Set<SystemStreamPartition> workingSSPSet = 
allPartitions.stream().filter(ssp -> 
!consumerMultiplexer.isEndOfStream(ssp)).collect(Collectors.toSet());
+      Set<SystemStreamPartition> workingSSPSet = allPartitions.stream()
+          .filter(ssp -> !consumerMultiplexer.isEndOfStream(ssp))
+          .collect(Collectors.toSet());
       return workingSSPSet;
     }
 
@@ -512,15 +513,18 @@ public class AsyncRunLoop implements Runnable, 
Throttleable {
             state.doneProcess();
             TaskCallbackImpl callbackImpl = (TaskCallbackImpl) callback;
             containerMetrics.processNs().update(clock.nanoTime() - 
callbackImpl.timeCreatedNs);
-            log.trace("Got callback complete for task {}, ssp {}", 
callbackImpl.taskName, callbackImpl.envelope.getSystemStreamPartition());
+            log.trace("Got callback complete for task {}, ssp {}",
+                callbackImpl.taskName, 
callbackImpl.envelope.getSystemStreamPartition());
 
             TaskCallbackImpl callbackToUpdate = 
callbackManager.updateCallback(callbackImpl);
             if (callbackToUpdate != null) {
               IncomingMessageEnvelope envelope = callbackToUpdate.envelope;
-              log.trace("Update offset for ssp {}, offset {}", 
envelope.getSystemStreamPartition(), envelope.getOffset());
+              log.trace("Update offset for ssp {}, offset {}",
+                  envelope.getSystemStreamPartition(), envelope.getOffset());
 
               // update offset
-              task.offsetManager().update(task.taskName(), 
envelope.getSystemStreamPartition(), envelope.getOffset());
+              task.offsetManager().update(task.taskName(),
+                  envelope.getSystemStreamPartition(), envelope.getOffset());
 
               // update coordinator
               coordinatorRequests.update(callbackToUpdate.coordinator);
@@ -697,7 +701,8 @@ public class AsyncRunLoop implements Runnable, Throttleable 
{
       PendingEnvelope pendingEnvelope = pendingEnvelopeQueue.remove();
       int queueSize = pendingEnvelopeQueue.size();
       taskMetrics.pendingMessages().set(queueSize);
-      log.trace("fetch envelope ssp {} offset {} to process.", 
pendingEnvelope.envelope.getSystemStreamPartition(), 
pendingEnvelope.envelope.getOffset());
+      log.trace("fetch envelope ssp {} offset {} to process.",
+          pendingEnvelope.envelope.getSystemStreamPartition(), 
pendingEnvelope.envelope.getOffset());
       log.debug("Task {} pending envelopes count is {} after fetching.", 
taskName, queueSize);
 
       if (pendingEnvelope.markProcessed()) {

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java 
b/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java
index 6655468..0be2b04 100644
--- a/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java
+++ b/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java
@@ -157,7 +157,7 @@ public class ZkUtils {
 
   /**
    * subscribe for changes of JobModel version
-   * @param dataListener
+   * @param dataListener listener to invoke when the JobModel version data changes
    */
   public void subscribeToJobModelVersionChange(IZkDataListener dataListener) {
     LOG.info("pid=" + processorId + " subscribing for jm version change at:" + 
keyBuilder.getJobModelVersionPath());

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala 
b/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala
index 7df7d88..b1ab1e0 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala
@@ -22,7 +22,6 @@ package org.apache.samza.container
 import org.apache.samza.task.CoordinatorRequests
 import org.apache.samza.system.{IncomingMessageEnvelope, SystemConsumers, 
SystemStreamPartition}
 import org.apache.samza.task.ReadableCoordinator
-import org.apache.samza.task.StreamTask
 import org.apache.samza.util.{Logging, Throttleable, ThrottlingExecutor, 
TimerUtils}
 
 import scala.collection.JavaConversions._
@@ -37,7 +36,7 @@ import scala.collection.JavaConversions._
  * be done when.
  */
 class RunLoop (
-  val taskInstances: Map[TaskName, TaskInstance[StreamTask]],
+  val taskInstances: Map[TaskName, TaskInstance],
   val consumerMultiplexer: SystemConsumers,
   val metrics: SamzaContainerMetrics,
   val maxThrottlingDelayMs: Long,
@@ -57,13 +56,16 @@ class RunLoop (
   // Keep a mapping of SystemStreamPartition to TaskInstance to efficiently 
route them.
   val systemStreamPartitionToTaskInstances = 
getSystemStreamPartitionToTaskInstancesMapping
 
-  def getSystemStreamPartitionToTaskInstancesMapping: 
Map[SystemStreamPartition, List[TaskInstance[StreamTask]]] = {
-    // We could just pass in the SystemStreamPartitionMap during construction, 
but it's safer and cleaner to derive the information directly
-    def getSystemStreamPartitionToTaskInstance(taskInstance: 
TaskInstance[StreamTask]) = taskInstance.systemStreamPartitions.map(_ -> 
taskInstance).toMap
+  def getSystemStreamPartitionToTaskInstancesMapping: 
Map[SystemStreamPartition, List[TaskInstance]] = {
+    // We could just pass in the SystemStreamPartitionMap during construction,
+    // but it's safer and cleaner to derive the information directly
+    def getSystemStreamPartitionToTaskInstance(taskInstance: TaskInstance) =
+      taskInstance.systemStreamPartitions.map(_ -> taskInstance).toMap
 
-    taskInstances.values.map { getSystemStreamPartitionToTaskInstance 
}.flatten.groupBy(_._1).map {
-      case (ssp, ssp2taskInstance) => ssp -> ssp2taskInstance.map(_._2).toList
-    }
+    taskInstances.values
+      .flatMap(getSystemStreamPartitionToTaskInstance)
+      .groupBy(_._1)
+      .map { case (ssp, ssp2taskInstance) => ssp -> 
ssp2taskInstance.map(_._2).toList }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala 
b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index f4d605f..3adaeda 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -22,9 +22,7 @@ package org.apache.samza.container
 import java.io.File
 import java.nio.file.Path
 import java.util
-import java.util.concurrent.ExecutorService
-import java.util.concurrent.Executors
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{CountDownLatch, ExecutorService, Executors, 
TimeUnit}
 import java.lang.Thread.UncaughtExceptionHandler
 import java.net.{URL, UnknownHostException}
 import org.apache.samza.SamzaException
@@ -69,7 +67,9 @@ import 
org.apache.samza.system.chooser.RoundRobinChooserFactory
 import org.apache.samza.task.AsyncRunLoop
 import org.apache.samza.task.AsyncStreamTask
 import org.apache.samza.task.AsyncStreamTaskAdapter
+import org.apache.samza.task.AsyncStreamTaskFactory
 import org.apache.samza.task.StreamTask
+import org.apache.samza.task.StreamTaskFactory
 import org.apache.samza.task.TaskInstanceCollector
 import org.apache.samza.util.HighResolutionClock
 import org.apache.samza.util.ExponentialSleepStrategy
@@ -165,7 +165,8 @@ object SamzaContainer extends Logging {
     maxChangeLogStreamPartitions: Int,
     localityManager: LocalityManager,
     jmxServer: JmxServer,
-    customReporters: Map[String, MetricsReporter] = Map[String, 
MetricsReporter]()) = {
+    customReporters: Map[String, MetricsReporter] = Map[String, 
MetricsReporter](),
+    taskFactory: Object = null) = {
     val containerName = getSamzaContainerName(containerId)
     val containerPID = Util.getContainerPID
 
@@ -234,12 +235,6 @@ object SamzaContainer extends Logging {
 
     info("Got input stream metadata: %s" format inputStreamMetadata)
 
-    val taskClassName = config
-      .getTaskClass
-      .getOrElse(throw new SamzaException("No task class defined in 
configuration."))
-
-    info("Got stream task class: %s" format taskClassName)
-
     val consumers = inputSystems
       .map(systemName => {
         val systemFactory = systemFactories(systemName)
@@ -248,7 +243,7 @@ object SamzaContainer extends Logging {
           (systemName, systemFactory.getConsumer(systemName, config, 
samzaContainerMetrics.registry))
         } catch {
           case e: Exception =>
-            error("Failed to create a consumer for %s, so skipping." 
format(systemName), e)
+            error("Failed to create a consumer for %s, so skipping." format 
systemName, e)
             (systemName, null)
         }
       })
@@ -257,11 +252,6 @@ object SamzaContainer extends Logging {
 
     info("Got system consumers: %s" format consumers.keys)
 
-    val isAsyncTask = 
classOf[AsyncStreamTask].isAssignableFrom(Class.forName(taskClassName))
-    if (isAsyncTask) {
-      info("%s is AsyncStreamTask" format taskClassName)
-    }
-
     val producers = systemFactories
       .map {
         case (systemName, systemFactory) =>
@@ -269,12 +259,11 @@ object SamzaContainer extends Logging {
             (systemName, systemFactory.getProducer(systemName, config, 
samzaContainerMetrics.registry))
           } catch {
             case e: Exception =>
-              error("Failed to create a producer for %s, so skipping." 
format(systemName), e)
+              error("Failed to create a producer for %s, so skipping." format 
systemName, e)
               (systemName, null)
           }
       }
       .filter(_._2 != null)
-      .toMap
 
     info("Got system producers: %s" format producers.keys)
 
@@ -292,44 +281,48 @@ object SamzaContainer extends Logging {
     info("Got serdes: %s" format serdes.keys)
 
     /*
-     * A Helper function to build a Map[String, Serde] (systemName -> Serde) 
for systems defined in the config. This is useful to build both key and message 
serde maps.
+     * A Helper function to build a Map[String, Serde] (systemName -> Serde) 
for systems defined
+     * in the config. This is useful to build both key and message serde maps.
      */
     val buildSystemSerdeMap = (getSerdeName: (String) => Option[String]) => {
       systemNames
         .filter(systemName => getSerdeName(systemName).isDefined)
               .map(systemName => {
           val serdeName = getSerdeName(systemName).get
-          val serde = serdes.getOrElse(serdeName, throw new 
SamzaException("buildSystemSerdeMap: No class defined for serde: %s." format 
serdeName))
+          val serde = serdes.getOrElse(serdeName,
+            throw new SamzaException("buildSystemSerdeMap: No class defined 
for serde: %s." format serdeName))
           (systemName, serde)
         }).toMap
     }
 
     /*
-     * A Helper function to build a Map[SystemStream, Serde] for streams 
defined in the config. This is useful to build both key and message serde maps.
+     * A Helper function to build a Map[SystemStream, Serde] for streams 
defined in the config.
+     * This is useful to build both key and message serde maps.
      */
     val buildSystemStreamSerdeMap = (getSerdeName: (SystemStream) => 
Option[String]) => {
       (serdeStreams ++ inputSystemStreamPartitions)
         .filter(systemStream => getSerdeName(systemStream).isDefined)
         .map(systemStream => {
           val serdeName = getSerdeName(systemStream).get
-          val serde = serdes.getOrElse(serdeName, throw new 
SamzaException("buildSystemStreamSerdeMap: No class defined for serde: %s." 
format serdeName))
+          val serde = serdes.getOrElse(serdeName,
+            throw new SamzaException("buildSystemStreamSerdeMap: No class 
defined for serde: %s." format serdeName))
           (systemStream, serde)
         }).toMap
     }
 
-    val systemKeySerdes = buildSystemSerdeMap((systemName: String) => 
config.getSystemKeySerde(systemName))
+    val systemKeySerdes = buildSystemSerdeMap(systemName => 
config.getSystemKeySerde(systemName))
 
     debug("Got system key serdes: %s" format systemKeySerdes)
 
-    val systemMessageSerdes = buildSystemSerdeMap((systemName: String) => 
config.getSystemMsgSerde(systemName))
+    val systemMessageSerdes = buildSystemSerdeMap(systemName => 
config.getSystemMsgSerde(systemName))
 
     debug("Got system message serdes: %s" format systemMessageSerdes)
 
-    val systemStreamKeySerdes = buildSystemStreamSerdeMap((systemStream: 
SystemStream) => config.getStreamKeySerde(systemStream))
+    val systemStreamKeySerdes = buildSystemStreamSerdeMap(systemStream => 
config.getStreamKeySerde(systemStream))
 
     debug("Got system stream key serdes: %s" format systemStreamKeySerdes)
 
-    val systemStreamMessageSerdes = buildSystemStreamSerdeMap((systemStream: 
SystemStream) => config.getStreamMsgSerde(systemStream))
+    val systemStreamMessageSerdes = buildSystemStreamSerdeMap(systemStream => 
config.getStreamMsgSerde(systemStream))
 
     debug("Got system stream message serdes: %s" format 
systemStreamMessageSerdes)
 
@@ -378,13 +371,10 @@ object SamzaContainer extends Logging {
 
     val coordinatorSystemProducer = new 
CoordinatorStreamSystemFactory().getCoordinatorStreamSystemProducer(config, 
samzaContainerMetrics.registry)
     val localityManager = new LocalityManager(coordinatorSystemProducer)
-    val checkpointManager = config.getCheckpointManagerFactory() match {
-      case Some(checkpointFactoryClassName) if 
(!checkpointFactoryClassName.isEmpty) =>
-        Util
-          .getObj[CheckpointManagerFactory](checkpointFactoryClassName)
-          .getCheckpointManager(config, samzaContainerMetrics.registry)
-      case _ => null
-    }
+    val checkpointManager = config.getCheckpointManagerFactory()
+      .filterNot(_.isEmpty)
+      
.map(Util.getObj[CheckpointManagerFactory](_).getCheckpointManager(config, 
samzaContainerMetrics.registry))
+      .orNull
     info("Got checkpoint manager: %s" format checkpointManager)
 
     // create a map of consumers with callbacks to pass to the OffsetManager
@@ -442,8 +432,26 @@ object SamzaContainer extends Logging {
     val singleThreadMode = config.getSingleThreadMode
     info("Got single thread mode: " + singleThreadMode)
 
+    val taskClassName = config.getTaskClass.orNull
+    info("Got task class name: %s" format taskClassName)
+
+    if (taskClassName == null && taskFactory == null) {
+      throw new SamzaException("Either the task class name or the task factory 
instance is required.")
+    }
+
+    val isAsyncTask: Boolean =
+      if (taskFactory != null) {
+        taskFactory.isInstanceOf[AsyncStreamTaskFactory]
+      } else {
+        classOf[AsyncStreamTask].isAssignableFrom(Class.forName(taskClassName))
+      }
+
+    if (isAsyncTask) {
+      info("Got an AsyncStreamTask implementation.")
+    }
+
     if(singleThreadMode && isAsyncTask) {
-      throw new SamzaException("AsyncStreamTask %s cannot run on single thread 
mode." format taskClassName)
+      throw new SamzaException("AsyncStreamTask cannot run on single thread 
mode.")
     }
 
     val threadPoolSize = config.getThreadPoolSize
@@ -470,12 +478,23 @@ object SamzaContainer extends Logging {
     val storeWatchPaths = new util.HashSet[Path]()
     storeWatchPaths.add(defaultStoreBaseDir.toPath)
 
-    val taskInstances: Map[TaskName, TaskInstance[_]] = 
containerModel.getTasks.values.map(taskModel => {
+    val taskInstances: Map[TaskName, TaskInstance] = 
containerModel.getTasks.values.map(taskModel => {
       debug("Setting up task instance: %s" format taskModel)
 
       val taskName = taskModel.getTaskName
 
-      val taskObj = Class.forName(taskClassName).newInstance
+      val taskObj = if (taskFactory != null) {
+        debug("Using task factory to create task instance")
+        taskFactory match {
+          case tf: AsyncStreamTaskFactory => tf.createInstance()
+          case tf: StreamTaskFactory => tf.createInstance()
+          case _ =>
+            throw new SamzaException("taskFactory must be an instance of 
StreamTaskFactory or AsyncStreamTaskFactory")
+        }
+      } else {
+        debug("Using task class name: %s to create instance" format 
taskClassName)
+        Class.forName(taskClassName).newInstance
+      }
 
       val task = if (!singleThreadMode && !isAsyncTask)
         // Wrap the StreamTask into a AsyncStreamTask with the build-in thread 
pool
@@ -491,18 +510,21 @@ object SamzaContainer extends Logging {
         .map {
           case (storeName, changeLogSystemStream) =>
             val systemConsumer = systemFactories
-              .getOrElse(changeLogSystemStream.getSystem, throw new 
SamzaException("Changelog system %s for store %s does not exist in the config." 
format (changeLogSystemStream, storeName)))
+              .getOrElse(changeLogSystemStream.getSystem,
+                throw new SamzaException("Changelog system %s for store %s 
does not " +
+                  "exist in the config." format (changeLogSystemStream, 
storeName)))
               .getConsumer(changeLogSystemStream.getSystem, config, 
taskInstanceMetrics.registry)
             samzaContainerMetrics.addStoreRestorationGauge(taskName, storeName)
             (storeName, systemConsumer)
-        }.toMap
+        }
 
       info("Got store consumers: %s" format storeConsumers)
 
       var loggedStorageBaseDir: File = null
       if(System.getenv(ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR) != null) {
         val jobNameAndId = Util.getJobNameAndId(config)
-        loggedStorageBaseDir = new 
File(System.getenv(ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR) + 
File.separator + jobNameAndId._1 + "-" + jobNameAndId._2)
+        loggedStorageBaseDir = new 
File(System.getenv(ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR)
+          + File.separator + jobNameAndId._1 + "-" + jobNameAndId._2)
       } else {
         warn("No override was provided for logged store base directory. This 
disables local state re-use on " +
           "application restart. If you want to enable this feature, set 
LOGGED_STORE_BASE_DIR as an environment " +
@@ -523,11 +545,13 @@ object SamzaContainer extends Logging {
               null
             }
             val keySerde = config.getStorageKeySerde(storeName) match {
-              case Some(keySerde) => serdes.getOrElse(keySerde, throw new 
SamzaException("StorageKeySerde: No class defined for serde: %s." format 
keySerde))
+              case Some(keySerde) => serdes.getOrElse(keySerde,
+                throw new SamzaException("StorageKeySerde: No class defined 
for serde: %s." format keySerde))
               case _ => null
             }
             val msgSerde = config.getStorageMsgSerde(storeName) match {
-              case Some(msgSerde) => serdes.getOrElse(msgSerde, throw new 
SamzaException("StorageMsgSerde: No class defined for serde: %s." format 
msgSerde))
+              case Some(msgSerde) => serdes.getOrElse(msgSerde,
+                throw new SamzaException("StorageMsgSerde: No class defined 
for serde: %s." format msgSerde))
               case _ => null
             }
             val storeBaseDir = if(changeLogSystemStreamPartition != null) {
@@ -568,7 +592,7 @@ object SamzaContainer extends Logging {
 
       info("Retrieved SystemStreamPartitions " + systemStreamPartitions + " 
for " + taskName)
 
-      def createTaskInstance[T] (task: T ): TaskInstance[T] = new 
TaskInstance[T](
+      def createTaskInstance(task: Any): TaskInstance = new TaskInstance(
           task = task,
           taskName = taskName,
           config = config,
@@ -659,7 +683,7 @@ object SamzaContainer extends Logging {
 
 class SamzaContainer(
   containerContext: SamzaContainerContext,
-  taskInstances: Map[TaskName, TaskInstance[_]],
+  taskInstances: Map[TaskName, TaskInstance],
   runLoop: Runnable,
   consumerMultiplexer: SystemConsumers,
   producerMultiplexer: SystemProducers,
@@ -675,6 +699,17 @@ class SamzaContainer(
   taskThreadPool: ExecutorService = null) extends Runnable with Logging {
 
   val shutdownMs = containerContext.config.getShutdownMs.getOrElse(5000L)
+  private val runLoopStartLatch: CountDownLatch = new CountDownLatch(1)
+
+  def awaitStart(timeoutMs: Long): Boolean = {
+    try {
+      runLoopStartLatch.await(timeoutMs, TimeUnit.MILLISECONDS)
+    } catch {
+      case ie: InterruptedException =>
+        error("Interrupted while waiting for runloop to start!", ie)
+        throw ie
+    }
+  }
 
   def run {
     try {
@@ -691,8 +726,9 @@ class SamzaContainer(
       startConsumers
       startSecurityManger
 
-      info("Entering run loop.")
       addShutdownHook
+      runLoopStartLatch.countDown()
+      info("Entering run loop.")
       runLoop.run
     } catch {
       case e: Exception =>
@@ -716,6 +752,13 @@ class SamzaContainer(
     }
   }
 
+  def shutdown() = {
+    runLoop match {
+      case runLoop: RunLoop => runLoop.shutdown
+      case asyncRunLoop: AsyncRunLoop => asyncRunLoop.shutdown()
+    }
+  }
+
   def startDiskSpaceMonitor: Unit = {
     if (diskSpaceMonitor != null) {
       info("Starting disk space monitor")
@@ -773,9 +816,11 @@ class SamzaContainer(
         localityManager.writeContainerToHostMapping(containerContext.id, 
hostInet.getHostName, jmxUrl, jmxTunnelingUrl)
       } catch {
         case uhe: UnknownHostException =>
-          warn("Received UnknownHostException when persisting locality info 
for container %d: %s" format (containerContext.id, uhe.getMessage))  //No-op
+          warn("Received UnknownHostException when persisting locality info 
for container %d: " +
+            "%s" format (containerContext.id, uhe.getMessage))  //No-op
         case unknownException: Throwable =>
-          warn("Received an exception when persisting locality info for 
container %d: %s" format (containerContext.id, unknownException.getMessage))
+          warn("Received an exception when persisting locality info for 
container %d: " +
+            "%s" format (containerContext.id, unknownException.getMessage))
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala 
b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index 26a8f5f..e07fcf4 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -43,8 +43,8 @@ import org.apache.samza.util.Logging
 
 import scala.collection.JavaConversions._
 
-class TaskInstance[T](
-  task: T,
+class TaskInstance(
+  task: Any,
   val taskName: TaskName,
   config: Config,
   val metrics: TaskInstanceMetrics,
@@ -84,7 +84,8 @@ class TaskInstance[T](
   // store the (ssp -> if this ssp is catched up) mapping. "catched up"
   // means the same ssp in other taskInstances have the same offset as
   // the one here.
-  var ssp2catchedupMapping: 
scala.collection.mutable.Map[SystemStreamPartition, Boolean] = 
scala.collection.mutable.Map[SystemStreamPartition, Boolean]()
+  var ssp2catchedupMapping: 
scala.collection.mutable.Map[SystemStreamPartition, Boolean] =
+    scala.collection.mutable.Map[SystemStreamPartition, Boolean]()
   systemStreamPartitions.foreach(ssp2catchedupMapping += _ -> false)
 
   def registerMetrics {
@@ -140,7 +141,8 @@ class TaskInstance[T](
     })
   }
 
-  def process(envelope: IncomingMessageEnvelope, coordinator: 
ReadableCoordinator, callbackFactory: TaskCallbackFactory = null) {
+  def process(envelope: IncomingMessageEnvelope, coordinator: 
ReadableCoordinator,
+    callbackFactory: TaskCallbackFactory = null) {
     metrics.processes.inc
 
     if (!ssp2catchedupMapping.getOrElse(envelope.getSystemStreamPartition,
@@ -151,7 +153,8 @@ class TaskInstance[T](
     if (ssp2catchedupMapping(envelope.getSystemStreamPartition)) {
       metrics.messagesActuallyProcessed.inc
 
-      trace("Processing incoming message envelope for taskName and SSP: %s, 
%s" format (taskName, envelope.getSystemStreamPartition))
+      trace("Processing incoming message envelope for taskName and SSP: %s, %s"
+        format (taskName, envelope.getSystemStreamPartition))
 
       if (isAsyncTask) {
         exceptionHandler.maybeHandle {
@@ -163,7 +166,8 @@ class TaskInstance[T](
          task.asInstanceOf[StreamTask].process(envelope, collector, 
coordinator)
         }
 
-        trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" 
format (taskName, envelope.getSystemStreamPartition, envelope.getOffset))
+        trace("Updating offset map for taskName, SSP and offset: %s, %s, %s"
+          format (taskName, envelope.getSystemStreamPartition, 
envelope.getOffset))
 
         offsetManager.update(taskName, envelope.getSystemStreamPartition, 
envelope.getOffset)
       }
@@ -173,7 +177,7 @@ class TaskInstance[T](
   def endOfStream(coordinator: ReadableCoordinator): Unit = {
     if (isEndOfStreamListenerTask) {
       exceptionHandler.maybeHandle {
-        task.asInstanceOf[EndOfStreamListenerTask].onEndOfStream(collector, 
coordinator);
+        task.asInstanceOf[EndOfStreamListenerTask].onEndOfStream(collector, 
coordinator)
       }
     }
   }
@@ -230,7 +234,8 @@ class TaskInstance[T](
 
   override def toString() = "TaskInstance for class %s and taskName %s." 
format (task.getClass.getName, taskName)
 
-  def toDetailedString() = "TaskInstance [taskName = %s, windowable=%s, 
closable=%s endofstreamlistener=%s]" format (taskName, isWindowableTask, 
isClosableTask, isEndOfStreamListenerTask)
+  def toDetailedString() = "TaskInstance [taskName = %s, windowable=%s, 
closable=%s endofstreamlistener=%s]" format
+    (taskName, isWindowableTask, isClosableTask, isEndOfStreamListenerTask)
 
   /**
    * From the envelope, check if this SSP has catched up with the starting 
offset of the SSP

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala 
b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
index 9ccf6fc..0d5adee 100644
--- 
a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
+++ 
b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
@@ -49,22 +49,14 @@ class ThreadJobFactory extends StreamJobFactory with 
Logging {
 
     try {
       coordinator.start
-      new ThreadJob(new Runnable {
-        override def run(): Unit = {
-          val jmxServer = new JmxServer
-          try {
+      new ThreadJob(
             SamzaContainer(
               containerModel.getContainerId,
               containerModel,
               config,
               jobModel.maxChangeLogStreamPartitions,
               null,
-              new JmxServer)
-          } finally {
-            jmxServer.stop
-          }
-        }
-      })
+              new JmxServer))
     } finally {
       coordinator.stop
     }

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java 
b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
index 58975d3..798977b 100644
--- a/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
+++ b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
@@ -60,7 +60,7 @@ import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 public class TestAsyncRunLoop {
-  Map<TaskName, TaskInstance<AsyncStreamTask>> tasks;
+  Map<TaskName, TaskInstance> tasks;
   ExecutorService executor;
   SystemConsumers consumerMultiplexer;
   SamzaContainerMetrics containerMetrics;
@@ -87,8 +87,8 @@ public class TestAsyncRunLoop {
 
   TestTask task0;
   TestTask task1;
-  TaskInstance<AsyncStreamTask> t0;
-  TaskInstance<AsyncStreamTask> t1;
+  TaskInstance t0;
+  TaskInstance t1;
 
   AsyncRunLoop createRunLoop() {
     return new AsyncRunLoop(tasks,
@@ -103,15 +103,15 @@ public class TestAsyncRunLoop {
         () -> 0L);
   }
 
-  TaskInstance<AsyncStreamTask> createTaskInstance(AsyncStreamTask task, 
TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, 
SystemConsumers consumers) {
+  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, 
SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
     TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", 
new MetricsRegistryMap());
     scala.collection.immutable.Set<SystemStreamPartition> sspSet = 
JavaConversions.asScalaSet(Collections.singleton(ssp)).toSet();
-    return new TaskInstance<AsyncStreamTask>(task, taskName, 
mock(Config.class), taskInstanceMetrics,
+    return new TaskInstance(task, taskName, mock(Config.class), 
taskInstanceMetrics,
         null, consumers, mock(TaskInstanceCollector.class), 
mock(SamzaContainerContext.class),
         manager, null, null, sspSet, new 
TaskInstanceExceptionHandler(taskInstanceMetrics, new 
scala.collection.immutable.HashSet<String>()));
   }
 
-  TaskInstance<AsyncStreamTask> createTaskInstance(AsyncStreamTask task, 
TaskName taskName, SystemStreamPartition ssp) {
+  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, 
SystemStreamPartition ssp) {
     return createTaskInstance(task, taskName, ssp, offsetManager, 
consumerMultiplexer);
   }
 
@@ -466,7 +466,7 @@ public class TestAsyncRunLoop {
     sspMap.put(ssp2, messageList);
 
     SystemConsumer mockConsumer = mock(SystemConsumer.class);
-    when(mockConsumer.poll((Set<SystemStreamPartition>) anyObject(), 
anyLong())).thenReturn(sspMap);
+    when(mockConsumer.poll(anyObject(), anyLong())).thenReturn(sspMap);
 
     HashMap<String, SystemConsumer> systemConsumerMap = new HashMap<>();
     systemConsumerMap.put("system1", mockConsumer);
@@ -485,9 +485,9 @@ public class TestAsyncRunLoop {
     when(offsetManager.getStartingOffset(taskName1, 
ssp1)).thenReturn(Option.apply(IncomingMessageEnvelope.END_OF_STREAM_OFFSET));
     when(offsetManager.getStartingOffset(taskName2, 
ssp2)).thenReturn(Option.apply("1"));
 
-    TaskInstance<AsyncStreamTask> taskInstance1 = 
createTaskInstance(mockStreamTask1, taskName1, ssp1, offsetManager, consumers);
-    TaskInstance<AsyncStreamTask> taskInstance2 = 
createTaskInstance(mockStreamTask2, taskName2, ssp2, offsetManager, consumers);
-    Map<TaskName, TaskInstance<AsyncStreamTask>> tasks = new HashMap<>();
+    TaskInstance taskInstance1 = createTaskInstance(mockStreamTask1, 
taskName1, ssp1, offsetManager, consumers);
+    TaskInstance taskInstance2 = createTaskInstance(mockStreamTask2, 
taskName2, ssp2, offsetManager, consumers);
+    Map<TaskName, TaskInstance> tasks = new HashMap<>();
     tasks.put(taskName1, taskInstance1);
     tasks.put(taskName2, taskInstance2);
 

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala 
b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
index d83c7e2..0304fc8 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
@@ -31,7 +31,6 @@ import org.apache.samza.system.SystemConsumers
 import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.task.TaskCoordinator.RequestScope
 import org.apache.samza.task.ReadableCoordinator
-import org.apache.samza.task.StreamTask
 import org.apache.samza.util.Clock
 import org.junit.Assert._
 import org.junit.Test
@@ -55,12 +54,12 @@ class TestRunLoop extends AssertionsForJUnit with 
MockitoSugar with ScalaTestMat
   val envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0")
   val envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1")
 
-  def getMockTaskInstances: Map[TaskName, TaskInstance[StreamTask]] = {
-    val ti0 = mock[TaskInstance[StreamTask]]
+  def getMockTaskInstances: Map[TaskName, TaskInstance] = {
+    val ti0 = mock[TaskInstance]
     when(ti0.systemStreamPartitions).thenReturn(Set(ssp0))
     when(ti0.taskName).thenReturn(taskName0)
 
-    val ti1 = mock[TaskInstance[StreamTask]]
+    val ti1 = mock[TaskInstance]
     when(ti1.systemStreamPartitions).thenReturn(Set(ssp1))
     when(ti1.taskName).thenReturn(taskName1)
 
@@ -183,7 +182,7 @@ class TestRunLoop extends AssertionsForJUnit with 
MockitoSugar with ScalaTestMat
   def anyObject[T] = Matchers.anyObject.asInstanceOf[T]
 
   // Stub out TaskInstance.process. Mockito really doesn't make this easy. :(
-  def stubProcess(taskInstance: TaskInstance[StreamTask], process: 
(IncomingMessageEnvelope, ReadableCoordinator) => Unit) {
+  def stubProcess(taskInstance: TaskInstance, process: 
(IncomingMessageEnvelope, ReadableCoordinator) => Unit) {
     when(taskInstance.process(anyObject, anyObject, anyObject)).thenAnswer(new 
Answer[Unit]() {
       override def answer(invocation: InvocationOnMock) {
         val envelope = 
invocation.getArguments()(0).asInstanceOf[IncomingMessageEnvelope]
@@ -276,9 +275,9 @@ class TestRunLoop extends AssertionsForJUnit with 
MockitoSugar with ScalaTestMat
 
   @Test
   def testGetSystemStreamPartitionToTaskInstancesMapping {
-    val ti0 = mock[TaskInstance[StreamTask]]
-    val ti1 = mock[TaskInstance[StreamTask]]
-    val ti2 = mock[TaskInstance[StreamTask]]
+    val ti0 = mock[TaskInstance]
+    val ti1 = mock[TaskInstance]
+    val ti2 = mock[TaskInstance]
     when(ti0.systemStreamPartitions).thenReturn(Set(ssp0))
     when(ti1.systemStreamPartitions).thenReturn(Set(ssp1))
     when(ti2.systemStreamPartitions).thenReturn(Set(ssp1))

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala 
b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
index 5895037..07055a0 100644
--- 
a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
+++ 
b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
@@ -180,7 +180,7 @@ class TestSamzaContainer extends AssertionsForJUnit with 
MockitoSugar {
       new SerdeManager)
     val collector = new TaskInstanceCollector(producerMultiplexer)
     val containerContext = new SamzaContainerContext(0, config, 
Set[TaskName](taskName))
-    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
+    val taskInstance: TaskInstance = new TaskInstance(
       task,
       taskName,
       config,
@@ -262,7 +262,7 @@ class TestSamzaContainer extends AssertionsForJUnit with 
MockitoSugar {
       }
     })
 
-    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
+    val taskInstance: TaskInstance = new TaskInstance(
       task,
       taskName,
       config,

http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala 
b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
index 3c83529..7e35525 100644
--- 
a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
+++ 
b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
@@ -71,7 +71,7 @@ class TestTaskInstance {
     val taskName = new TaskName("taskName")
     val collector = new TaskInstanceCollector(producerMultiplexer)
     val containerContext = new SamzaContainerContext(0, config, 
Set[TaskName](taskName))
-    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
+    val taskInstance: TaskInstance = new TaskInstance(
       task,
       taskName,
       config,
@@ -169,7 +169,7 @@ class TestTaskInstance {
 
     val registry = new MetricsRegistryMap
     val taskMetrics = new TaskInstanceMetrics(registry = registry)
-    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
+    val taskInstance = new TaskInstance(
       task,
       taskName,
       config,
@@ -226,7 +226,7 @@ class TestTaskInstance {
 
     val registry = new MetricsRegistryMap
     val taskMetrics = new TaskInstanceMetrics(registry = registry)
-    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
+    val taskInstance = new TaskInstance(
       task,
       taskName,
       config,

Reply via email to