Merging TaskFactory changes
Project: http://git-wip-us.apache.org/repos/asf/samza/repo Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/be6be8cb Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/be6be8cb Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/be6be8cb Branch: refs/heads/samza-standalone Commit: be6be8cba1a28a8395ca9c6700eb125093f6c517 Parents: 7d6332b Author: navina <[email protected]> Authored: Fri Dec 23 16:59:07 2016 -0800 Committer: navina <[email protected]> Committed: Fri Dec 23 16:59:07 2016 -0800 ---------------------------------------------------------------------- .../samza/task/AsyncStreamTaskFactory.java | 28 ++++ .../apache/samza/task/StreamTaskFactory.java | 28 ++++ .../apache/samza/container/RunLoopFactory.java | 20 ++- .../coordinator/JobCoordinatorFactory.java | 3 + .../apache/samza/processor/StreamProcessor.java | 27 +++- .../org/apache/samza/task/AsyncRunLoop.java | 37 ++--- .../main/java/org/apache/samza/zk/ZkUtils.java | 2 +- .../org/apache/samza/container/RunLoop.scala | 18 +-- .../apache/samza/container/SamzaContainer.scala | 137 ++++++++++++------- .../apache/samza/container/TaskInstance.scala | 21 +-- .../samza/job/local/ThreadJobFactory.scala | 12 +- .../org/apache/samza/task/TestAsyncRunLoop.java | 20 +-- .../apache/samza/container/TestRunLoop.scala | 15 +- .../samza/container/TestSamzaContainer.scala | 4 +- .../samza/container/TestTaskInstance.scala | 6 +- 15 files changed, 251 insertions(+), 127 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-api/src/main/java/org/apache/samza/task/AsyncStreamTaskFactory.java ---------------------------------------------------------------------- diff --git a/samza-api/src/main/java/org/apache/samza/task/AsyncStreamTaskFactory.java b/samza-api/src/main/java/org/apache/samza/task/AsyncStreamTaskFactory.java new file mode 100644 index 0000000..e5ce9c4 --- /dev/null +++ b/samza-api/src/main/java/org/apache/samza/task/AsyncStreamTaskFactory.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.task; + +/** + * Build {@link AsyncStreamTask} instances. + * Implementations should return a new instance for each {@link #createInstance()} invocation. + */ +public interface AsyncStreamTaskFactory { + AsyncStreamTask createInstance(); +} http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-api/src/main/java/org/apache/samza/task/StreamTaskFactory.java ---------------------------------------------------------------------- diff --git a/samza-api/src/main/java/org/apache/samza/task/StreamTaskFactory.java b/samza-api/src/main/java/org/apache/samza/task/StreamTaskFactory.java new file mode 100644 index 0000000..ec53bc0 --- /dev/null +++ b/samza-api/src/main/java/org/apache/samza/task/StreamTaskFactory.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.task; + +/** + * Build {@link StreamTask} instances. + * Implementations should return a new instance for each {@link #createInstance()} invocation. + */ +public interface StreamTaskFactory { + StreamTask createInstance(); +} http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java index 23a68cb..32ab47a 100644 --- a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java +++ b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java @@ -19,20 +19,19 @@ package org.apache.samza.container; -import java.util.concurrent.ExecutorService; import org.apache.samza.SamzaException; import org.apache.samza.config.TaskConfig; -import org.apache.samza.util.HighResolutionClock; import org.apache.samza.system.SystemConsumers; import org.apache.samza.task.AsyncRunLoop; -import org.apache.samza.task.AsyncStreamTask; -import org.apache.samza.task.StreamTask; +import org.apache.samza.util.HighResolutionClock; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.collection.JavaConversions; import scala.runtime.AbstractFunction0; import scala.runtime.AbstractFunction1; +import java.util.concurrent.ExecutorService; + import static org.apache.samza.util.Util.asScalaClock; /** @@ -46,7 +45,7 @@ public class RunLoopFactory { private static final long DEFAULT_COMMIT_MS = 60000L; private static final long DEFAULT_CALLBACK_TIMEOUT_MS = -1L; - public static Runnable createRunLoop(scala.collection.immutable.Map<TaskName, TaskInstance<?>> taskInstances, + public static Runnable createRunLoop(scala.collection.immutable.Map<TaskName, TaskInstance> taskInstances, SystemConsumers consumerMultiplexer, ExecutorService threadPool, long maxThrottlingDelayMs, @@ -62,9 +61,9 @@ public class RunLoopFactory { log.info("Got commit milliseconds: " + taskCommitMs); - int asyncTaskCount = taskInstances.values().count(new AbstractFunction1<TaskInstance<?>, Object>() { + int asyncTaskCount = taskInstances.values().count(new AbstractFunction1<TaskInstance, Object>() { @Override - public Boolean apply(TaskInstance<?> t) { + public Boolean apply(TaskInstance t) { return t.isAsyncTask(); } }); @@ -77,9 +76,8 @@ public class RunLoopFactory { if (asyncTaskCount == 0) { log.info("Run loop in single thread mode."); - scala.collection.immutable.Map<TaskName, TaskInstance<StreamTask>> streamTaskInstances = (scala.collection.immutable.Map) taskInstances; return new RunLoop( - streamTaskInstances, + taskInstances, consumerMultiplexer, containerMetrics, maxThrottlingDelayMs, @@ -95,12 +93,10 @@ public class RunLoopFactory { log.info("Got callback timeout: " + callbackTimeout); - scala.collection.immutable.Map<TaskName, TaskInstance<AsyncStreamTask>> asyncStreamTaskInstances = (scala.collection.immutable.Map) taskInstances; - log.info("Run loop in asynchronous mode."); return new AsyncRunLoop( - JavaConversions.mapAsJavaMap(asyncStreamTaskInstances), + JavaConversions.mapAsJavaMap(taskInstances), threadPool, consumerMultiplexer, taskMaxConcurrency, http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorFactory.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorFactory.java b/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorFactory.java index af2aaa7..e12e49f 100644 --- a/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorFactory.java +++ b/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorFactory.java @@ -23,7 +23,10 @@ import org.apache.samza.processor.SamzaContainerController; public interface JobCoordinatorFactory { /** + * @param processorId Unique identifier for the processor * @param config Configs relevant for the JobCoordinator TODO: Separate JC related configs into a "JobCoordinatorConfig" + * @param containerController Controller interface for starting and stopping container. In future, it may simply + * pause the container and add/remove tasks * @return An instance of IJobCoordinator */ JobCoordinator getJobCoordinator(int processorId, Config config, SamzaContainerController containerController); http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java index 0f34400..61795e1 100644 --- a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java +++ b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java @@ -25,6 +25,8 @@ import org.apache.samza.config.TaskConfigJava; import org.apache.samza.coordinator.JobCoordinator; import org.apache.samza.coordinator.JobCoordinatorFactory; import org.apache.samza.metrics.MetricsReporter; +import org.apache.samza.task.AsyncStreamTaskFactory; +import org.apache.samza.task.StreamTaskFactory; import org.apache.samza.util.Util; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,11 +36,11 @@ import java.util.Map; /** * StreamProcessor can be embedded in any application or executed in a distributed environment (aka cluster) as - * independent processes <br /> + * independent processes * <p> * <b>Usage Example:</b> * <pre> - * StreamProcessor processor = new StreamProcessor(1, config); <br /> + * StreamProcessor processor = new StreamProcessor(1, config); * processor.start(); * try { * boolean status = processor.awaitStart(TIMEOUT_MS); // Optional - blocking call @@ -74,7 +76,7 @@ public class StreamProcessor { * JobCoordinator controls how the various StreamProcessor instances belonging to a job coordinate. It is also * responsible generating and updating JobModel. * When StreamProcessor starts, it starts the JobCoordinator and brings up a SamzaContainer based on the JobModel. - * SamzaContainer is executed using an ExecutorService. <br /> + * SamzaContainer is executed using an ExecutorService. * <p> * <b>Note:</b> Lifecycle of the ExecutorService is fully managed by the StreamProcessor, and NOT exposed to the user * @@ -82,6 +84,25 @@ public class StreamProcessor { * "containerId" in Samza * @param config Instance of config object - contains all configuration required for processing * @param customMetricsReporters Map of custom MetricReporter instances that are to be injected in the Samza job + * @param asyncStreamTaskFactory The {@link AsyncStreamTaskFactory} to be used for creating task instances. + */ + public StreamProcessor(int processorId, Config config, Map<String, MetricsReporter> customMetricsReporters, + AsyncStreamTaskFactory asyncStreamTaskFactory) { + this(processorId, config, customMetricsReporters, (Object) asyncStreamTaskFactory); + } + + /** + * Same as {@link #StreamProcessor(int, Config, Map, AsyncStreamTaskFactory)}, except task instances are created + * using the provided {@link StreamTaskFactory}. + */ + public StreamProcessor(int processorId, Config config, Map<String, MetricsReporter> customMetricsReporters, + StreamTaskFactory streamTaskFactory) { + this(processorId, config, customMetricsReporters, (Object) streamTaskFactory); + } + + /** + * Same as {@link #StreamProcessor(int, Config, Map, AsyncStreamTaskFactory)}, except task instances are created + * using the "task.class" configuration instead of a task factory. */ public StreamProcessor(int processorId, Config config, Map<String, MetricsReporter> customMetricsReporters) { this(processorId, config, customMetricsReporters, (Object) null); http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java b/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java index 16f1563..ab5514c 100644 --- a/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java +++ b/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java @@ -75,7 +75,7 @@ public class AsyncRunLoop implements Runnable, Throttleable { private volatile Throwable throwable = null; private final HighResolutionClock clock; - public AsyncRunLoop(Map<TaskName, TaskInstance<AsyncStreamTask>> taskInstances, + public AsyncRunLoop(Map<TaskName, TaskInstance> taskInstances, ExecutorService threadPool, SystemConsumers consumerMultiplexer, int maxConcurrency, @@ -100,7 +100,7 @@ public class AsyncRunLoop implements Runnable, Throttleable { this.workerTimer = Executors.newSingleThreadScheduledExecutor(); this.clock = clock; Map<TaskName, AsyncTaskWorker> workers = new HashMap<>(); - for (TaskInstance<AsyncStreamTask> task : taskInstances.values()) { + for (TaskInstance task : taskInstances.values()) { workers.put(task.taskName(), new AsyncTaskWorker(task)); } // Partions and tasks assigned to the container will not change during the run loop life time @@ -112,14 +112,12 @@ public class AsyncRunLoop implements Runnable, Throttleable { * Returns mapping of the SystemStreamPartition to the AsyncTaskWorkers to efficiently route the envelopes */ private static Map<SystemStreamPartition, List<AsyncTaskWorker>> getSspToAsyncTaskWorkerMap( - Map<TaskName, TaskInstance<AsyncStreamTask>> taskInstances, Map<TaskName, AsyncTaskWorker> taskWorkers) { + Map<TaskName, TaskInstance> taskInstances, Map<TaskName, AsyncTaskWorker> taskWorkers) { Map<SystemStreamPartition, List<AsyncTaskWorker>> sspToWorkerMap = new HashMap<>(); - for (TaskInstance<AsyncStreamTask> task : taskInstances.values()) { + for (TaskInstance task : taskInstances.values()) { Set<SystemStreamPartition> ssps = JavaConversions.setAsJavaSet(task.systemStreamPartitions()); for (SystemStreamPartition ssp : ssps) { - if (sspToWorkerMap.get(ssp) == null) { - sspToWorkerMap.put(ssp, new ArrayList<AsyncTaskWorker>()); - } + sspToWorkerMap.putIfAbsent(ssp, new ArrayList<>()); sspToWorkerMap.get(ssp).add(taskWorkers.get(task.taskName())); } } @@ -202,7 +200,8 @@ public class AsyncRunLoop implements Runnable, Throttleable { private IncomingMessageEnvelope chooseEnvelope() { IncomingMessageEnvelope envelope = consumerMultiplexer.choose(false); if (envelope != null) { - log.trace("Choose envelope ssp {} offset {} for processing", envelope.getSystemStreamPartition(), envelope.getOffset()); + log.trace("Choose envelope ssp {} offset {} for processing", + envelope.getSystemStreamPartition(), envelope.getOffset()); containerMetrics.envelopes().inc(); } else { log.trace("No envelope is available"); @@ -310,12 +309,12 @@ public class AsyncRunLoop implements Runnable, Throttleable { * will run the task asynchronously. It runs window and commit in the provided thread pool. */ private class AsyncTaskWorker implements TaskCallbackListener { - private final TaskInstance<AsyncStreamTask> task; + private final TaskInstance task; private final TaskCallbackManager callbackManager; private volatile AsyncTaskState state; - AsyncTaskWorker(TaskInstance<AsyncStreamTask> task) { + AsyncTaskWorker(TaskInstance task) { this.task = task; this.callbackManager = new TaskCallbackManager(this, callbackTimer, callbackTimeoutMs, maxConcurrency, clock); Set<SystemStreamPartition> sspSet = getWorkingSSPSet(task); @@ -352,12 +351,14 @@ public class AsyncRunLoop implements Runnable, Throttleable { * @param task * @return a Set of SSPs such that all SSPs are not at end of stream. */ - private Set<SystemStreamPartition> getWorkingSSPSet(TaskInstance<AsyncStreamTask> task) { + private Set<SystemStreamPartition> getWorkingSSPSet(TaskInstance task) { Set<SystemStreamPartition> allPartitions = new HashSet<>(JavaConversions.setAsJavaSet(task.systemStreamPartitions())); // filter only those SSPs that are not at end of stream. - Set<SystemStreamPartition> workingSSPSet = allPartitions.stream().filter(ssp -> !consumerMultiplexer.isEndOfStream(ssp)).collect(Collectors.toSet()); + Set<SystemStreamPartition> workingSSPSet = allPartitions.stream() + .filter(ssp -> !consumerMultiplexer.isEndOfStream(ssp)) + .collect(Collectors.toSet()); return workingSSPSet; } @@ -512,15 +513,18 @@ public class AsyncRunLoop implements Runnable, Throttleable { state.doneProcess(); TaskCallbackImpl callbackImpl = (TaskCallbackImpl) callback; containerMetrics.processNs().update(clock.nanoTime() - callbackImpl.timeCreatedNs); - log.trace("Got callback complete for task {}, ssp {}", callbackImpl.taskName, callbackImpl.envelope.getSystemStreamPartition()); + log.trace("Got callback complete for task {}, ssp {}", + callbackImpl.taskName, callbackImpl.envelope.getSystemStreamPartition()); TaskCallbackImpl callbackToUpdate = callbackManager.updateCallback(callbackImpl); if (callbackToUpdate != null) { IncomingMessageEnvelope envelope = callbackToUpdate.envelope; - log.trace("Update offset for ssp {}, offset {}", envelope.getSystemStreamPartition(), envelope.getOffset()); + log.trace("Update offset for ssp {}, offset {}", + envelope.getSystemStreamPartition(), envelope.getOffset()); // update offset - task.offsetManager().update(task.taskName(), envelope.getSystemStreamPartition(), envelope.getOffset()); + task.offsetManager().update(task.taskName(), + envelope.getSystemStreamPartition(), envelope.getOffset()); // update coordinator coordinatorRequests.update(callbackToUpdate.coordinator); @@ -697,7 +701,8 @@ public class AsyncRunLoop implements Runnable, Throttleable { PendingEnvelope pendingEnvelope = pendingEnvelopeQueue.remove(); int queueSize = pendingEnvelopeQueue.size(); taskMetrics.pendingMessages().set(queueSize); - log.trace("fetch envelope ssp {} offset {} to process.", pendingEnvelope.envelope.getSystemStreamPartition(), pendingEnvelope.envelope.getOffset()); + log.trace("fetch envelope ssp {} offset {} to process.", + pendingEnvelope.envelope.getSystemStreamPartition(), pendingEnvelope.envelope.getOffset()); log.debug("Task {} pending envelopes count is {} after fetching.", taskName, queueSize); if (pendingEnvelope.markProcessed()) { http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java b/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java index 6655468..0be2b04 100644 --- a/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java +++ b/samza-core/src/main/java/org/apache/samza/zk/ZkUtils.java @@ -157,7 +157,7 @@ public class ZkUtils { /** * subscribe for changes of JobModel version - * @param dataListener + * @param dataListener describe this */ public void subscribeToJobModelVersionChange(IZkDataListener dataListener) { LOG.info("pid=" + processorId + " subscribing for jm version change at:" + keyBuilder.getJobModelVersionPath()); http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala b/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala index 7df7d88..b1ab1e0 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala @@ -22,7 +22,6 @@ package org.apache.samza.container import org.apache.samza.task.CoordinatorRequests import org.apache.samza.system.{IncomingMessageEnvelope, SystemConsumers, SystemStreamPartition} import org.apache.samza.task.ReadableCoordinator -import org.apache.samza.task.StreamTask import org.apache.samza.util.{Logging, Throttleable, ThrottlingExecutor, TimerUtils} import scala.collection.JavaConversions._ @@ -37,7 +36,7 @@ import scala.collection.JavaConversions._ * be done when. */ class RunLoop ( - val taskInstances: Map[TaskName, TaskInstance[StreamTask]], + val taskInstances: Map[TaskName, TaskInstance], val consumerMultiplexer: SystemConsumers, val metrics: SamzaContainerMetrics, val maxThrottlingDelayMs: Long, @@ -57,13 +56,16 @@ class RunLoop ( // Keep a mapping of SystemStreamPartition to TaskInstance to efficiently route them. val systemStreamPartitionToTaskInstances = getSystemStreamPartitionToTaskInstancesMapping - def getSystemStreamPartitionToTaskInstancesMapping: Map[SystemStreamPartition, List[TaskInstance[StreamTask]]] = { - // We could just pass in the SystemStreamPartitionMap during construction, but it's safer and cleaner to derive the information directly - def getSystemStreamPartitionToTaskInstance(taskInstance: TaskInstance[StreamTask]) = taskInstance.systemStreamPartitions.map(_ -> taskInstance).toMap + def getSystemStreamPartitionToTaskInstancesMapping: Map[SystemStreamPartition, List[TaskInstance]] = { + // We could just pass in the SystemStreamPartitionMap during construction, + // but it's safer and cleaner to derive the information directly + def getSystemStreamPartitionToTaskInstance(taskInstance: TaskInstance) = + taskInstance.systemStreamPartitions.map(_ -> taskInstance).toMap - taskInstances.values.map { getSystemStreamPartitionToTaskInstance }.flatten.groupBy(_._1).map { - case (ssp, ssp2taskInstance) => ssp -> ssp2taskInstance.map(_._2).toList - } + taskInstances.values + .flatMap(getSystemStreamPartitionToTaskInstance) + .groupBy(_._1) + .map { case (ssp, ssp2taskInstance) => ssp -> ssp2taskInstance.map(_._2).toList } } /** http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala index f4d605f..3adaeda 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala @@ -22,9 +22,7 @@ package org.apache.samza.container import java.io.File import java.nio.file.Path import java.util -import java.util.concurrent.ExecutorService -import java.util.concurrent.Executors -import java.util.concurrent.TimeUnit +import java.util.concurrent.{CountDownLatch, ExecutorService, Executors, TimeUnit} import java.lang.Thread.UncaughtExceptionHandler import java.net.{URL, UnknownHostException} import org.apache.samza.SamzaException @@ -69,7 +67,9 @@ import org.apache.samza.system.chooser.RoundRobinChooserFactory import org.apache.samza.task.AsyncRunLoop import org.apache.samza.task.AsyncStreamTask import org.apache.samza.task.AsyncStreamTaskAdapter +import org.apache.samza.task.AsyncStreamTaskFactory import org.apache.samza.task.StreamTask +import org.apache.samza.task.StreamTaskFactory import org.apache.samza.task.TaskInstanceCollector import org.apache.samza.util.HighResolutionClock import org.apache.samza.util.ExponentialSleepStrategy @@ -165,7 +165,8 @@ object SamzaContainer extends Logging { maxChangeLogStreamPartitions: Int, localityManager: LocalityManager, jmxServer: JmxServer, - customReporters: Map[String, MetricsReporter] = Map[String, MetricsReporter]()) = { + customReporters: Map[String, MetricsReporter] = Map[String, MetricsReporter](), + taskFactory: Object = null) = { val containerName = getSamzaContainerName(containerId) val containerPID = Util.getContainerPID @@ -234,12 +235,6 @@ object SamzaContainer extends Logging { info("Got input stream metadata: %s" format inputStreamMetadata) - val taskClassName = config - .getTaskClass - .getOrElse(throw new SamzaException("No task class defined in configuration.")) - - info("Got stream task class: %s" format taskClassName) - val consumers = inputSystems .map(systemName => { val systemFactory = systemFactories(systemName) @@ -248,7 +243,7 @@ object SamzaContainer extends Logging { (systemName, systemFactory.getConsumer(systemName, config, samzaContainerMetrics.registry)) } catch { case e: Exception => - error("Failed to create a consumer for %s, so skipping." format(systemName), e) + error("Failed to create a consumer for %s, so skipping." format systemName, e) (systemName, null) } }) @@ -257,11 +252,6 @@ object SamzaContainer extends Logging { info("Got system consumers: %s" format consumers.keys) - val isAsyncTask = classOf[AsyncStreamTask].isAssignableFrom(Class.forName(taskClassName)) - if (isAsyncTask) { - info("%s is AsyncStreamTask" format taskClassName) - } - val producers = systemFactories .map { case (systemName, systemFactory) => @@ -269,12 +259,11 @@ object SamzaContainer extends Logging { (systemName, systemFactory.getProducer(systemName, config, samzaContainerMetrics.registry)) } catch { case e: Exception => - error("Failed to create a producer for %s, so skipping." format(systemName), e) + error("Failed to create a producer for %s, so skipping." format systemName, e) (systemName, null) } } .filter(_._2 != null) - .toMap info("Got system producers: %s" format producers.keys) @@ -292,44 +281,48 @@ object SamzaContainer extends Logging { info("Got serdes: %s" format serdes.keys) /* - * A Helper function to build a Map[String, Serde] (systemName -> Serde) for systems defined in the config. This is useful to build both key and message serde maps. + * A Helper function to build a Map[String, Serde] (systemName -> Serde) for systems defined + * in the config. This is useful to build both key and message serde maps. */ val buildSystemSerdeMap = (getSerdeName: (String) => Option[String]) => { systemNames .filter(systemName => getSerdeName(systemName).isDefined) .map(systemName => { val serdeName = getSerdeName(systemName).get - val serde = serdes.getOrElse(serdeName, throw new SamzaException("buildSystemSerdeMap: No class defined for serde: %s." format serdeName)) + val serde = serdes.getOrElse(serdeName, + throw new SamzaException("buildSystemSerdeMap: No class defined for serde: %s." format serdeName)) (systemName, serde) }).toMap } /* - * A Helper function to build a Map[SystemStream, Serde] for streams defined in the config. This is useful to build both key and message serde maps. + * A Helper function to build a Map[SystemStream, Serde] for streams defined in the config. + * This is useful to build both key and message serde maps. */ val buildSystemStreamSerdeMap = (getSerdeName: (SystemStream) => Option[String]) => { (serdeStreams ++ inputSystemStreamPartitions) .filter(systemStream => getSerdeName(systemStream).isDefined) .map(systemStream => { val serdeName = getSerdeName(systemStream).get - val serde = serdes.getOrElse(serdeName, throw new SamzaException("buildSystemStreamSerdeMap: No class defined for serde: %s." format serdeName)) + val serde = serdes.getOrElse(serdeName, + throw new SamzaException("buildSystemStreamSerdeMap: No class defined for serde: %s." format serdeName)) (systemStream, serde) }).toMap } - val systemKeySerdes = buildSystemSerdeMap((systemName: String) => config.getSystemKeySerde(systemName)) + val systemKeySerdes = buildSystemSerdeMap(systemName => config.getSystemKeySerde(systemName)) debug("Got system key serdes: %s" format systemKeySerdes) - val systemMessageSerdes = buildSystemSerdeMap((systemName: String) => config.getSystemMsgSerde(systemName)) + val systemMessageSerdes = buildSystemSerdeMap(systemName => config.getSystemMsgSerde(systemName)) debug("Got system message serdes: %s" format systemMessageSerdes) - val systemStreamKeySerdes = buildSystemStreamSerdeMap((systemStream: SystemStream) => config.getStreamKeySerde(systemStream)) + val systemStreamKeySerdes = buildSystemStreamSerdeMap(systemStream => config.getStreamKeySerde(systemStream)) debug("Got system stream key serdes: %s" format systemStreamKeySerdes) - val systemStreamMessageSerdes = buildSystemStreamSerdeMap((systemStream: SystemStream) => config.getStreamMsgSerde(systemStream)) + val systemStreamMessageSerdes = buildSystemStreamSerdeMap(systemStream => config.getStreamMsgSerde(systemStream)) debug("Got system stream message serdes: %s" format systemStreamMessageSerdes) @@ -378,13 +371,10 @@ object SamzaContainer extends Logging { val coordinatorSystemProducer = new CoordinatorStreamSystemFactory().getCoordinatorStreamSystemProducer(config, samzaContainerMetrics.registry) val localityManager = new LocalityManager(coordinatorSystemProducer) - val checkpointManager = config.getCheckpointManagerFactory() match { - case Some(checkpointFactoryClassName) if (!checkpointFactoryClassName.isEmpty) => - Util - .getObj[CheckpointManagerFactory](checkpointFactoryClassName) - .getCheckpointManager(config, samzaContainerMetrics.registry) - case _ => null - } + val checkpointManager = config.getCheckpointManagerFactory() + .filterNot(_.isEmpty) + .map(Util.getObj[CheckpointManagerFactory](_).getCheckpointManager(config, samzaContainerMetrics.registry)) + .orNull info("Got checkpoint manager: %s" format checkpointManager) // create a map of consumers with callbacks to pass to the OffsetManager @@ -442,8 +432,26 @@ object SamzaContainer extends Logging { val singleThreadMode = config.getSingleThreadMode info("Got single thread mode: " + singleThreadMode) + val taskClassName = config.getTaskClass.orNull + info("Got task class name: %s" format taskClassName) + + if (taskClassName == null && taskFactory == null) { + throw new SamzaException("Either the task class name or the task factory instance is required.") + } + + val isAsyncTask: Boolean = + if (taskFactory != null) { + taskFactory.isInstanceOf[AsyncStreamTaskFactory] + } else { + classOf[AsyncStreamTask].isAssignableFrom(Class.forName(taskClassName)) + } + + if (isAsyncTask) { + info("Got an AsyncStreamTask implementation.") + } + if(singleThreadMode && isAsyncTask) { - throw new SamzaException("AsyncStreamTask %s cannot run on single thread mode." format taskClassName) + throw new SamzaException("AsyncStreamTask cannot run on single thread mode.") } val threadPoolSize = config.getThreadPoolSize @@ -470,12 +478,23 @@ object SamzaContainer extends Logging { val storeWatchPaths = new util.HashSet[Path]() storeWatchPaths.add(defaultStoreBaseDir.toPath) - val taskInstances: Map[TaskName, TaskInstance[_]] = containerModel.getTasks.values.map(taskModel => { + val taskInstances: Map[TaskName, TaskInstance] = containerModel.getTasks.values.map(taskModel => { debug("Setting up task instance: %s" format taskModel) val taskName = taskModel.getTaskName - val taskObj = Class.forName(taskClassName).newInstance + val taskObj = if (taskFactory != null) { + debug("Using task factory to create task instance") + taskFactory match { + case tf: AsyncStreamTaskFactory => tf.createInstance() + case tf: StreamTaskFactory => tf.createInstance() + case _ => + throw new SamzaException("taskFactory must be an instance of StreamTaskFactory or AsyncStreamTaskFactory") + } + } else { + debug("Using task class name: %s to create instance" format taskClassName) + Class.forName(taskClassName).newInstance + } val task = if (!singleThreadMode && !isAsyncTask) // Wrap the StreamTask into a AsyncStreamTask with the build-in thread pool @@ -491,18 +510,21 @@ object SamzaContainer extends Logging { .map { case (storeName, changeLogSystemStream) => val systemConsumer = systemFactories - .getOrElse(changeLogSystemStream.getSystem, throw new SamzaException("Changelog system %s for store %s does not exist in the config." format (changeLogSystemStream, storeName))) + .getOrElse(changeLogSystemStream.getSystem, + throw new SamzaException("Changelog system %s for store %s does not " + + "exist in the config." format (changeLogSystemStream, storeName))) .getConsumer(changeLogSystemStream.getSystem, config, taskInstanceMetrics.registry) samzaContainerMetrics.addStoreRestorationGauge(taskName, storeName) (storeName, systemConsumer) - }.toMap + } info("Got store consumers: %s" format storeConsumers) var loggedStorageBaseDir: File = null if(System.getenv(ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR) != null) { val jobNameAndId = Util.getJobNameAndId(config) - loggedStorageBaseDir = new File(System.getenv(ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR) + File.separator + jobNameAndId._1 + "-" + jobNameAndId._2) + loggedStorageBaseDir = new File(System.getenv(ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR) + + File.separator + jobNameAndId._1 + "-" + jobNameAndId._2) } else { warn("No override was provided for logged store base directory. This disables local state re-use on " + "application restart. If you want to enable this feature, set LOGGED_STORE_BASE_DIR as an environment " + @@ -523,11 +545,13 @@ object SamzaContainer extends Logging { null } val keySerde = config.getStorageKeySerde(storeName) match { - case Some(keySerde) => serdes.getOrElse(keySerde, throw new SamzaException("StorageKeySerde: No class defined for serde: %s." format keySerde)) + case Some(keySerde) => serdes.getOrElse(keySerde, + throw new SamzaException("StorageKeySerde: No class defined for serde: %s." format keySerde)) case _ => null } val msgSerde = config.getStorageMsgSerde(storeName) match { - case Some(msgSerde) => serdes.getOrElse(msgSerde, throw new SamzaException("StorageMsgSerde: No class defined for serde: %s." format msgSerde)) + case Some(msgSerde) => serdes.getOrElse(msgSerde, + throw new SamzaException("StorageMsgSerde: No class defined for serde: %s." format msgSerde)) case _ => null } val storeBaseDir = if(changeLogSystemStreamPartition != null) { @@ -568,7 +592,7 @@ object SamzaContainer extends Logging { info("Retrieved SystemStreamPartitions " + systemStreamPartitions + " for " + taskName) - def createTaskInstance[T] (task: T ): TaskInstance[T] = new TaskInstance[T]( + def createTaskInstance(task: Any): TaskInstance = new TaskInstance( task = task, taskName = taskName, config = config, @@ -659,7 +683,7 @@ object SamzaContainer extends Logging { class SamzaContainer( containerContext: SamzaContainerContext, - taskInstances: Map[TaskName, TaskInstance[_]], + taskInstances: Map[TaskName, TaskInstance], runLoop: Runnable, consumerMultiplexer: SystemConsumers, producerMultiplexer: SystemProducers, @@ -675,6 +699,17 @@ class SamzaContainer( taskThreadPool: ExecutorService = null) extends Runnable with Logging { val shutdownMs = containerContext.config.getShutdownMs.getOrElse(5000L) + private val runLoopStartLatch: CountDownLatch = new CountDownLatch(1) + + def awaitStart(timeoutMs: Long): Boolean = { + try { + runLoopStartLatch.await(timeoutMs, TimeUnit.MILLISECONDS) + } catch { + case ie: InterruptedException => + error("Interrupted while waiting for runloop to start!", ie) + throw ie + } + } def run { try { @@ -691,8 +726,9 @@ class SamzaContainer( startConsumers startSecurityManger - info("Entering run loop.") addShutdownHook + runLoopStartLatch.countDown() + info("Entering run loop.") runLoop.run } catch { case e: Exception => @@ -716,6 +752,13 @@ class SamzaContainer( } } + def shutdown() = { + runLoop match { + case runLoop: RunLoop => runLoop.shutdown + case asyncRunLoop: AsyncRunLoop => asyncRunLoop.shutdown() + } + } + def startDiskSpaceMonitor: Unit = { if (diskSpaceMonitor != null) { info("Starting disk space monitor") @@ -773,9 +816,11 @@ class SamzaContainer( localityManager.writeContainerToHostMapping(containerContext.id, hostInet.getHostName, jmxUrl, jmxTunnelingUrl) } catch { case uhe: UnknownHostException => - warn("Received UnknownHostException when persisting locality info for container %d: %s" format (containerContext.id, uhe.getMessage)) //No-op + warn("Received UnknownHostException when persisting locality info for container %d: " + + "%s" format (containerContext.id, uhe.getMessage)) //No-op case unknownException: Throwable => - warn("Received an exception when persisting locality info for container %d: %s" format (containerContext.id, unknownException.getMessage)) + warn("Received an exception when persisting locality info for container %d: " + + "%s" format (containerContext.id, unknownException.getMessage)) } } } http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala index 26a8f5f..e07fcf4 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala @@ -43,8 +43,8 @@ import org.apache.samza.util.Logging import scala.collection.JavaConversions._ -class TaskInstance[T]( - task: T, +class TaskInstance( + task: Any, val taskName: TaskName, config: Config, val metrics: TaskInstanceMetrics, @@ -84,7 +84,8 @@ class TaskInstance[T]( // store the (ssp -> if this ssp is catched up) mapping. "catched up" // means the same ssp in other taskInstances have the same offset as // the one here. - var ssp2catchedupMapping: scala.collection.mutable.Map[SystemStreamPartition, Boolean] = scala.collection.mutable.Map[SystemStreamPartition, Boolean]() + var ssp2catchedupMapping: scala.collection.mutable.Map[SystemStreamPartition, Boolean] = + scala.collection.mutable.Map[SystemStreamPartition, Boolean]() systemStreamPartitions.foreach(ssp2catchedupMapping += _ -> false) def registerMetrics { @@ -140,7 +141,8 @@ class TaskInstance[T]( }) } - def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator, callbackFactory: TaskCallbackFactory = null) { + def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator, + callbackFactory: TaskCallbackFactory = null) { metrics.processes.inc if (!ssp2catchedupMapping.getOrElse(envelope.getSystemStreamPartition, @@ -151,7 +153,8 @@ class TaskInstance[T]( if (ssp2catchedupMapping(envelope.getSystemStreamPartition)) { metrics.messagesActuallyProcessed.inc - trace("Processing incoming message envelope for taskName and SSP: %s, %s" format (taskName, envelope.getSystemStreamPartition)) + trace("Processing incoming message envelope for taskName and SSP: %s, %s" + format (taskName, envelope.getSystemStreamPartition)) if (isAsyncTask) { exceptionHandler.maybeHandle { @@ -163,7 +166,8 @@ class TaskInstance[T]( task.asInstanceOf[StreamTask].process(envelope, collector, coordinator) } - trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" format (taskName, envelope.getSystemStreamPartition, envelope.getOffset)) + trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" + format (taskName, envelope.getSystemStreamPartition, envelope.getOffset)) offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset) } @@ -173,7 +177,7 @@ class TaskInstance[T]( def endOfStream(coordinator: ReadableCoordinator): Unit = { if (isEndOfStreamListenerTask) { exceptionHandler.maybeHandle { - task.asInstanceOf[EndOfStreamListenerTask].onEndOfStream(collector, coordinator); + task.asInstanceOf[EndOfStreamListenerTask].onEndOfStream(collector, coordinator) } } } @@ -230,7 +234,8 @@ class TaskInstance[T]( override def toString() = "TaskInstance for class %s and taskName %s." format (task.getClass.getName, taskName) - def toDetailedString() = "TaskInstance [taskName = %s, windowable=%s, closable=%s endofstreamlistener=%s]" format (taskName, isWindowableTask, isClosableTask, isEndOfStreamListenerTask) + def toDetailedString() = "TaskInstance [taskName = %s, windowable=%s, closable=%s endofstreamlistener=%s]" format + (taskName, isWindowableTask, isClosableTask, isEndOfStreamListenerTask) /** * From the envelope, check if this SSP has catched up with the starting offset of the SSP http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala index 9ccf6fc..0d5adee 100644 --- a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala +++ b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala @@ -49,22 +49,14 @@ class ThreadJobFactory extends StreamJobFactory with Logging { try { coordinator.start - new ThreadJob(new Runnable { - override def run(): Unit = { - val jmxServer = new JmxServer - try { + new ThreadJob( SamzaContainer( containerModel.getContainerId, containerModel, config, jobModel.maxChangeLogStreamPartitions, null, - new JmxServer) - } finally { - jmxServer.stop - } - } - }) + new JmxServer)) } finally { coordinator.stop } http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java index 58975d3..798977b 100644 --- a/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java +++ b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java @@ -60,7 +60,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class TestAsyncRunLoop { - Map<TaskName, TaskInstance<AsyncStreamTask>> tasks; + Map<TaskName, TaskInstance> tasks; ExecutorService executor; SystemConsumers consumerMultiplexer; SamzaContainerMetrics containerMetrics; @@ -87,8 +87,8 @@ public class TestAsyncRunLoop { TestTask task0; TestTask task1; - TaskInstance<AsyncStreamTask> t0; - TaskInstance<AsyncStreamTask> t1; + TaskInstance t0; + TaskInstance t1; AsyncRunLoop createRunLoop() { return new AsyncRunLoop(tasks, @@ -103,15 +103,15 @@ public class TestAsyncRunLoop { () -> 0L); } - TaskInstance<AsyncStreamTask> createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) { + TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) { TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap()); scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConversions.asScalaSet(Collections.singleton(ssp)).toSet(); - return new TaskInstance<AsyncStreamTask>(task, taskName, mock(Config.class), taskInstanceMetrics, + return new TaskInstance(task, taskName, mock(Config.class), taskInstanceMetrics, null, consumers, mock(TaskInstanceCollector.class), mock(SamzaContainerContext.class), manager, null, null, sspSet, new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>())); } - TaskInstance<AsyncStreamTask> createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp) { + TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp) { return createTaskInstance(task, taskName, ssp, offsetManager, consumerMultiplexer); } @@ -466,7 +466,7 @@ public class TestAsyncRunLoop { sspMap.put(ssp2, messageList); SystemConsumer mockConsumer = mock(SystemConsumer.class); - when(mockConsumer.poll((Set<SystemStreamPartition>) anyObject(), anyLong())).thenReturn(sspMap); + when(mockConsumer.poll(anyObject(), anyLong())).thenReturn(sspMap); HashMap<String, SystemConsumer> systemConsumerMap = new HashMap<>(); systemConsumerMap.put("system1", mockConsumer); @@ -485,9 +485,9 @@ public class TestAsyncRunLoop { when(offsetManager.getStartingOffset(taskName1, ssp1)).thenReturn(Option.apply(IncomingMessageEnvelope.END_OF_STREAM_OFFSET)); when(offsetManager.getStartingOffset(taskName2, ssp2)).thenReturn(Option.apply("1")); - TaskInstance<AsyncStreamTask> taskInstance1 = createTaskInstance(mockStreamTask1, taskName1, ssp1, offsetManager, consumers); - TaskInstance<AsyncStreamTask> taskInstance2 = createTaskInstance(mockStreamTask2, taskName2, ssp2, offsetManager, consumers); - Map<TaskName, TaskInstance<AsyncStreamTask>> tasks = new HashMap<>(); + TaskInstance taskInstance1 = createTaskInstance(mockStreamTask1, taskName1, ssp1, offsetManager, consumers); + TaskInstance taskInstance2 = createTaskInstance(mockStreamTask2, taskName2, ssp2, offsetManager, consumers); + Map<TaskName, TaskInstance> tasks = new HashMap<>(); tasks.put(taskName1, taskInstance1); tasks.put(taskName2, taskInstance2); http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala index d83c7e2..0304fc8 100644 --- a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala +++ b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala @@ -31,7 +31,6 @@ import org.apache.samza.system.SystemConsumers import org.apache.samza.system.SystemStreamPartition import org.apache.samza.task.TaskCoordinator.RequestScope import org.apache.samza.task.ReadableCoordinator -import org.apache.samza.task.StreamTask import org.apache.samza.util.Clock import org.junit.Assert._ import org.junit.Test @@ -55,12 +54,12 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat val envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0") val envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1") - def getMockTaskInstances: Map[TaskName, TaskInstance[StreamTask]] = { - val ti0 = mock[TaskInstance[StreamTask]] + def getMockTaskInstances: Map[TaskName, TaskInstance] = { + val ti0 = mock[TaskInstance] when(ti0.systemStreamPartitions).thenReturn(Set(ssp0)) when(ti0.taskName).thenReturn(taskName0) - val ti1 = mock[TaskInstance[StreamTask]] + val ti1 = mock[TaskInstance] when(ti1.systemStreamPartitions).thenReturn(Set(ssp1)) when(ti1.taskName).thenReturn(taskName1) @@ -183,7 +182,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat def anyObject[T] = Matchers.anyObject.asInstanceOf[T] // Stub out TaskInstance.process. Mockito really doesn't make this easy. :( - def stubProcess(taskInstance: TaskInstance[StreamTask], process: (IncomingMessageEnvelope, ReadableCoordinator) => Unit) { + def stubProcess(taskInstance: TaskInstance, process: (IncomingMessageEnvelope, ReadableCoordinator) => Unit) { when(taskInstance.process(anyObject, anyObject, anyObject)).thenAnswer(new Answer[Unit]() { override def answer(invocation: InvocationOnMock) { val envelope = invocation.getArguments()(0).asInstanceOf[IncomingMessageEnvelope] @@ -276,9 +275,9 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat @Test def testGetSystemStreamPartitionToTaskInstancesMapping { - val ti0 = mock[TaskInstance[StreamTask]] - val ti1 = mock[TaskInstance[StreamTask]] - val ti2 = mock[TaskInstance[StreamTask]] + val ti0 = mock[TaskInstance] + val ti1 = mock[TaskInstance] + val ti2 = mock[TaskInstance] when(ti0.systemStreamPartitions).thenReturn(Set(ssp0)) when(ti1.systemStreamPartitions).thenReturn(Set(ssp1)) when(ti2.systemStreamPartitions).thenReturn(Set(ssp1)) http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala index 5895037..07055a0 100644 --- a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala +++ b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala @@ -180,7 +180,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar { new SerdeManager) val collector = new TaskInstanceCollector(producerMultiplexer) val containerContext = new SamzaContainerContext(0, config, Set[TaskName](taskName)) - val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask]( + val taskInstance: TaskInstance = new TaskInstance( task, taskName, config, @@ -262,7 +262,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar { } }) - val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask]( + val taskInstance: TaskInstance = new TaskInstance( task, taskName, config, http://git-wip-us.apache.org/repos/asf/samza/blob/be6be8cb/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala index 3c83529..7e35525 100644 --- a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala +++ b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala @@ -71,7 +71,7 @@ class TestTaskInstance { val taskName = new TaskName("taskName") val collector = new TaskInstanceCollector(producerMultiplexer) val containerContext = new SamzaContainerContext(0, config, Set[TaskName](taskName)) - val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask]( + val taskInstance: TaskInstance = new TaskInstance( task, taskName, config, @@ -169,7 +169,7 @@ class TestTaskInstance { val registry = new MetricsRegistryMap val taskMetrics = new TaskInstanceMetrics(registry = registry) - val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask]( + val taskInstance = new TaskInstance( task, taskName, config, @@ -226,7 +226,7 @@ class TestTaskInstance { val registry = new MetricsRegistryMap val taskMetrics = new TaskInstanceMetrics(registry = registry) - val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask]( + val taskInstance = new TaskInstance( task, taskName, config,
