http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala index 18c0922..b8600d5 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala @@ -20,12 +20,19 @@ package org.apache.samza.container import java.io.File +import java.lang.Thread.UncaughtExceptionHandler +import java.net.URL +import java.net.UnknownHostException import java.nio.file.Path import java.util -import java.lang.Thread.UncaughtExceptionHandler -import java.net.{URL, UnknownHostException} +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit + import org.apache.samza.SamzaException -import org.apache.samza.checkpoint.{CheckpointManagerFactory, OffsetManager, OffsetManagerMetrics} +import org.apache.samza.checkpoint.CheckpointManagerFactory +import org.apache.samza.checkpoint.OffsetManager +import org.apache.samza.checkpoint.OffsetManagerMetrics import org.apache.samza.config.JobConfig.Config2Job import org.apache.samza.config.MetricsConfig.Config2Metrics import org.apache.samza.config.SerializerConfig.Config2Serializer @@ -34,18 +41,45 @@ import org.apache.samza.config.StorageConfig.Config2Storage import org.apache.samza.config.StreamConfig.Config2Stream import org.apache.samza.config.SystemConfig.Config2System import org.apache.samza.config.TaskConfig.Config2Task +import org.apache.samza.container.disk.DiskQuotaPolicyFactory +import org.apache.samza.container.disk.DiskSpaceMonitor import org.apache.samza.container.disk.DiskSpaceMonitor.Listener -import org.apache.samza.container.disk.{NoThrottlingDiskQuotaPolicyFactory, DiskQuotaPolicyFactory, PollingScanDiskSpaceMonitor, DiskSpaceMonitor} +import org.apache.samza.container.disk.NoThrottlingDiskQuotaPolicyFactory +import org.apache.samza.container.disk.PollingScanDiskSpaceMonitor import org.apache.samza.coordinator.stream.CoordinatorStreamSystemFactory -import org.apache.samza.job.model.{ContainerModel, JobModel} -import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistryMap, MetricsReporter, MetricsReporterFactory} -import org.apache.samza.serializers.{SerdeFactory, SerdeManager} +import org.apache.samza.job.model.ContainerModel +import org.apache.samza.job.model.JobModel +import org.apache.samza.metrics.JmxServer +import org.apache.samza.metrics.JvmMetrics +import org.apache.samza.metrics.MetricsRegistryMap +import org.apache.samza.metrics.MetricsReporter +import org.apache.samza.metrics.MetricsReporterFactory +import org.apache.samza.serializers.SerdeFactory +import org.apache.samza.serializers.SerdeManager import org.apache.samza.serializers.model.SamzaObjectMapper -import org.apache.samza.storage.{StorageEngineFactory, TaskStorageManager} -import org.apache.samza.system.{StreamMetadataCache, SystemConsumers, SystemConsumersMetrics, SystemFactory, SystemProducers, SystemProducersMetrics, SystemStream, SystemStreamPartition} -import org.apache.samza.system.chooser.{DefaultChooser, MessageChooserFactory, RoundRobinChooserFactory} -import org.apache.samza.task.{StreamTask, TaskInstanceCollector} -import org.apache.samza.util.{ThrottlingExecutor, ExponentialSleepStrategy, Logging, Util} +import org.apache.samza.storage.StorageEngineFactory +import org.apache.samza.storage.TaskStorageManager +import org.apache.samza.system.StreamMetadataCache +import org.apache.samza.system.SystemConsumers +import org.apache.samza.system.SystemConsumersMetrics +import org.apache.samza.system.SystemFactory +import org.apache.samza.system.SystemProducers +import org.apache.samza.system.SystemProducersMetrics +import org.apache.samza.system.SystemStream +import org.apache.samza.system.SystemStreamPartition +import org.apache.samza.system.chooser.DefaultChooser +import org.apache.samza.system.chooser.MessageChooserFactory +import org.apache.samza.system.chooser.RoundRobinChooserFactory +import org.apache.samza.task.AsyncRunLoop +import org.apache.samza.task.AsyncStreamTask +import org.apache.samza.task.AsyncStreamTaskAdapter +import org.apache.samza.task.StreamTask +import org.apache.samza.task.TaskInstanceCollector +import org.apache.samza.util.ExponentialSleepStrategy +import org.apache.samza.util.Logging +import org.apache.samza.util.ThrottlingExecutor +import org.apache.samza.util.Util + import scala.collection.JavaConversions._ object SamzaContainer extends Logging { @@ -164,6 +198,12 @@ object SamzaContainer extends Logging { info("Got input stream metadata: %s" format inputStreamMetadata) + val taskClassName = config + .getTaskClass + .getOrElse(throw new SamzaException("No task class defined in configuration.")) + + info("Got stream task class: %s" format taskClassName) + val consumers = inputSystems .map(systemName => { val systemFactory = systemFactories(systemName) @@ -181,6 +221,9 @@ object SamzaContainer extends Logging { info("Got system consumers: %s" format consumers.keys) + val isAsyncTask = classOf[AsyncStreamTask].isAssignableFrom(Class.forName(taskClassName)) + info("%s is AsyncStreamTask" format taskClassName) + val producers = systemFactories .map { case (systemName, systemFactory) => @@ -360,26 +403,22 @@ object SamzaContainer extends Logging { info("Got storage engines: %s" format storageEngineFactories.keys) - val taskClassName = config - .getTaskClass - .getOrElse(throw new SamzaException("No task class defined in configuration.")) - - info("Got stream task class: %s" format taskClassName) - - val taskWindowMs = config.getWindowMs.getOrElse(-1L) - - info("Got window milliseconds: %s" format taskWindowMs) + val singleThreadMode = config.getSingleThreadMode + info("Got single thread mode: " + singleThreadMode) - val taskCommitMs = config.getCommitMs.getOrElse(60000L) - - info("Got commit milliseconds: %s" format taskCommitMs) + if(singleThreadMode && isAsyncTask) { + throw new SamzaException("AsyncStreamTask %s cannot run on single thread mode." format taskClassName) + } - val taskShutdownMs = config.getShutdownMs.getOrElse(5000L) + val threadPoolSize = config.getThreadPoolSize + info("Got thread pool size: " + threadPoolSize) - info("Got shutdown timeout milliseconds: %s" format taskShutdownMs) + val taskThreadPool = if (!singleThreadMode && threadPoolSize > 0) + Executors.newFixedThreadPool(threadPoolSize) + else + null // Wire up all task-instance-level (unshared) objects. - val taskNames = containerModel .getTasks .values @@ -395,12 +434,18 @@ object SamzaContainer extends Logging { val storeWatchPaths = new util.HashSet[Path]() storeWatchPaths.add(defaultStoreBaseDir.toPath) - val taskInstances: Map[TaskName, TaskInstance] = containerModel.getTasks.values.map(taskModel => { + val taskInstances: Map[TaskName, TaskInstance[_]] = containerModel.getTasks.values.map(taskModel => { debug("Setting up task instance: %s" format taskModel) val taskName = taskModel.getTaskName - val task = Util.getObj[StreamTask](taskClassName) + val taskObj = Class.forName(taskClassName).newInstance + + val task = if (!singleThreadMode && !isAsyncTask) + // Wrap the StreamTask into a AsyncStreamTask with the build-in thread pool + new AsyncStreamTaskAdapter(taskObj.asInstanceOf[StreamTask], taskThreadPool) + else + taskObj val taskInstanceMetrics = new TaskInstanceMetrics("TaskName-%s" format taskName) @@ -487,20 +532,22 @@ object SamzaContainer extends Logging { info("Retrieved SystemStreamPartitions " + systemStreamPartitions + " for " + taskName) - val taskInstance = new TaskInstance( - task = task, - taskName = taskName, - config = config, - metrics = taskInstanceMetrics, - systemAdmins = systemAdmins, - consumerMultiplexer = consumerMultiplexer, - collector = collector, - containerContext = containerContext, - offsetManager = offsetManager, - storageManager = storageManager, - reporters = reporters, - systemStreamPartitions = systemStreamPartitions, - exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config)) + def createTaskInstance[T] (task: T ): TaskInstance[T] = new TaskInstance[T]( + task = task, + taskName = taskName, + config = config, + metrics = taskInstanceMetrics, + systemAdmins = systemAdmins, + consumerMultiplexer = consumerMultiplexer, + collector = collector, + containerContext = containerContext, + offsetManager = offsetManager, + storageManager = storageManager, + reporters = reporters, + systemStreamPartitions = systemStreamPartitions, + exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config)) + + val taskInstance = createTaskInstance(task) (taskName, taskInstance) }).toMap @@ -533,14 +580,13 @@ object SamzaContainer extends Logging { info(s"Disk quotas disabled because polling interval is not set ($DISK_POLL_INTERVAL_KEY)") } - val runLoop = new RunLoop( - taskInstances = taskInstances, - consumerMultiplexer = consumerMultiplexer, - metrics = samzaContainerMetrics, - windowMs = taskWindowMs, - commitMs = taskCommitMs, - shutdownMs = taskShutdownMs, - executor = executor) + val runLoop = RunLoopFactory.createRunLoop( + taskInstances, + consumerMultiplexer, + taskThreadPool, + executor, + samzaContainerMetrics, + config) info("Samza container setup complete.") @@ -557,14 +603,15 @@ object SamzaContainer extends Logging { reporters = reporters, jvm = jvm, jmxServer = jmxServer, - diskSpaceMonitor = diskSpaceMonitor) + diskSpaceMonitor = diskSpaceMonitor, + taskThreadPool = taskThreadPool) } } class SamzaContainer( containerContext: SamzaContainerContext, - taskInstances: Map[TaskName, TaskInstance], - runLoop: RunLoop, + taskInstances: Map[TaskName, TaskInstance[_]], + runLoop: Runnable, consumerMultiplexer: SystemConsumers, producerMultiplexer: SystemProducers, metrics: SamzaContainerMetrics, @@ -574,7 +621,10 @@ class SamzaContainer( localityManager: LocalityManager = null, securityManager: SecurityManager = null, reporters: Map[String, MetricsReporter] = Map(), - jvm: JvmMetrics = null) extends Runnable with Logging { + jvm: JvmMetrics = null, + taskThreadPool: ExecutorService = null) extends Runnable with Logging { + + val shutdownMs = containerContext.config.getShutdownMs.getOrElse(5000L) def run { try { @@ -591,6 +641,7 @@ class SamzaContainer( startSecurityManger info("Entering run loop.") + addShutdownHook runLoop.run } catch { case e: Exception => @@ -710,7 +761,7 @@ class SamzaContainer( consumerMultiplexer.start } - def startSecurityManger: Unit = { + def startSecurityManger { if (securityManager != null) { info("Starting security manager.") @@ -718,6 +769,25 @@ class SamzaContainer( } } + def addShutdownHook { + val runLoopThread = Thread.currentThread() + Runtime.getRuntime().addShutdownHook(new Thread() { + override def run() = { + info("Shutting down, will wait up to %s ms" format shutdownMs) + runLoop match { + case runLoop: RunLoop => runLoop.shutdown + case asyncRunLoop: AsyncRunLoop => asyncRunLoop.shutdown() + } + runLoopThread.join(shutdownMs) + if (runLoopThread.isAlive) { + warn("Did not shut down within %s ms, exiting" format shutdownMs) + } else { + info("Shutdown complete") + } + } + }) + } + def shutdownConsumers { info("Shutting down consumer multiplexer.") @@ -733,6 +803,19 @@ class SamzaContainer( def shutdownTask { info("Shutting down task instance stream tasks.") + + if (taskThreadPool != null) { + info("Shutting down task thread pool") + try { + taskThreadPool.shutdown() + if(taskThreadPool.awaitTermination(shutdownMs, TimeUnit.MILLISECONDS)) { + taskThreadPool.shutdownNow() + } + } catch { + case e: Exception => error(e.getMessage, e) + } + } + taskInstances.values.foreach(_.shutdownTask) }
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala index 2044ce0..e3891cf 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala @@ -34,9 +34,11 @@ class SamzaContainerMetrics( val envelopes = newCounter("process-envelopes") val nullEnvelopes = newCounter("process-null-envelopes") val chooseNs = newTimer("choose-ns") + val chooserUpdateNs = newTimer("chooser-update-ns") val windowNs = newTimer("window-ns") val processNs = newTimer("process-ns") val commitNs = newTimer("commit-ns") + val blockNs = newTimer("block-ns") val utilization = newGauge("event-loop-utilization", 0.0F) val diskUsageBytes = newGauge("disk-usage-bytes", 0L) val diskQuotaBytes = newGauge("disk-quota-bytes", Long.MaxValue) http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala index d32a929..89f6857 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala @@ -19,36 +19,39 @@ package org.apache.samza.container + import org.apache.samza.SamzaException import org.apache.samza.checkpoint.OffsetManager import org.apache.samza.config.Config -import org.apache.samza.config.TaskConfig.Config2Task import org.apache.samza.metrics.MetricsReporter import org.apache.samza.storage.TaskStorageManager import org.apache.samza.system.IncomingMessageEnvelope -import org.apache.samza.system.SystemStreamPartition +import org.apache.samza.system.SystemAdmin import org.apache.samza.system.SystemConsumers -import org.apache.samza.task.TaskContext +import org.apache.samza.system.SystemStreamPartition +import org.apache.samza.task.AsyncStreamTask import org.apache.samza.task.ClosableTask import org.apache.samza.task.InitableTask -import org.apache.samza.task.WindowableTask -import org.apache.samza.task.StreamTask import org.apache.samza.task.ReadableCoordinator +import org.apache.samza.task.StreamTask +import org.apache.samza.task.TaskCallbackFactory +import org.apache.samza.task.TaskContext import org.apache.samza.task.TaskInstanceCollector +import org.apache.samza.task.WindowableTask import org.apache.samza.util.Logging + import scala.collection.JavaConversions._ -import org.apache.samza.system.SystemAdmin -class TaskInstance( - task: StreamTask, +class TaskInstance[T]( + task: T, val taskName: TaskName, config: Config, - metrics: TaskInstanceMetrics, + val metrics: TaskInstanceMetrics, systemAdmins: Map[String, SystemAdmin], consumerMultiplexer: SystemConsumers, collector: TaskInstanceCollector, containerContext: SamzaContainerContext, - offsetManager: OffsetManager = new OffsetManager, + val offsetManager: OffsetManager = new OffsetManager, storageManager: TaskStorageManager = null, reporters: Map[String, MetricsReporter] = Map(), val systemStreamPartitions: Set[SystemStreamPartition] = Set(), @@ -56,6 +59,8 @@ class TaskInstance( val isInitableTask = task.isInstanceOf[InitableTask] val isWindowableTask = task.isInstanceOf[WindowableTask] val isClosableTask = task.isInstanceOf[ClosableTask] + val isAsyncTask = task.isInstanceOf[AsyncStreamTask] + val context = new TaskContext { def getMetricsRegistry = metrics.registry def getSystemStreamPartitions = systemStreamPartitions @@ -133,7 +138,7 @@ class TaskInstance( }) } - def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator) { + def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator, callbackFactory: TaskCallbackFactory = null) { metrics.processes.inc if (!ssp2catchedupMapping.getOrElse(envelope.getSystemStreamPartition, @@ -146,13 +151,20 @@ class TaskInstance( trace("Processing incoming message envelope for taskName and SSP: %s, %s" format (taskName, envelope.getSystemStreamPartition)) - exceptionHandler.maybeHandle { - task.process(envelope, collector, coordinator) - } + if (isAsyncTask) { + exceptionHandler.maybeHandle { + val callback = callbackFactory.createCallback() + task.asInstanceOf[AsyncStreamTask].processAsync(envelope, collector, coordinator, callback) + } + } else { + exceptionHandler.maybeHandle { + task.asInstanceOf[StreamTask].process(envelope, collector, coordinator) + } - trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" format (taskName, envelope.getSystemStreamPartition, envelope.getOffset)) + trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" format (taskName, envelope.getSystemStreamPartition, envelope.getOffset)) - offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset) + offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset) + } } } http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala index 8b86388..7bedadf 100644 --- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala +++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala @@ -35,6 +35,8 @@ class TaskInstanceMetrics( val sends = newCounter("send-calls") val flushes = newCounter("flush-calls") val messagesSent = newCounter("messages-sent") + val pendingMessages = newGauge("pending-messages", 0) + val messagesInFlight = newGauge("messages-in-flight", 0) def addOffsetGauge(systemStreamPartition: SystemStreamPartition, getValue: () => String) { newGauge("%s-%s-%d-offset" format (systemStreamPartition.getSystem, systemStreamPartition.getStream, systemStreamPartition.getPartition.getPartitionId), getValue) http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala b/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala index d3bd9b7..ba38b5c 100644 --- a/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala +++ b/samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala @@ -71,9 +71,12 @@ object JobModelManager extends Logging { coordinatorSystemConsumer.start debug("Bootstrapping coordinator system stream.") coordinatorSystemConsumer.bootstrap + val source = "Job-coordinator" + coordinatorSystemProducer.register(source) + info("Registering coordinator system stream producer.") val config = coordinatorSystemConsumer.getConfig info("Got config: %s" format config) - val changelogManager = new ChangelogPartitionManager(coordinatorSystemProducer, coordinatorSystemConsumer, "Job-coordinator") + val changelogManager = new ChangelogPartitionManager(coordinatorSystemProducer, coordinatorSystemConsumer, source) val localityManager = new LocalityManager(coordinatorSystemProducer, coordinatorSystemConsumer) val systemNames = getSystemNames(config) http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala index 2efe836..a8355b9 100644 --- a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala +++ b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala @@ -99,7 +99,7 @@ class SystemConsumers ( * with no remaining unprocessed messages, the SystemConsumers will poll for * it within 50ms of its availability in the stream system.</p> */ - pollIntervalMs: Int, + val pollIntervalMs: Int, /** * Clock can be used to inject a custom clock when mocking this class in @@ -203,28 +203,31 @@ class SystemConsumers ( } } - def choose: IncomingMessageEnvelope = { + def choose (updateChooser: Boolean = true): IncomingMessageEnvelope = { val envelopeFromChooser = chooser.choose updateTimer(metrics.deserializationNs) { if (envelopeFromChooser == null) { - trace("Chooser returned null.") + trace("Chooser returned null.") - metrics.choseNull.inc + metrics.choseNull.inc - // Sleep for a while so we don't poll in a tight loop. - timeout = noNewMessagesTimeout + // Sleep for a while so we don't poll in a tight loop. + timeout = noNewMessagesTimeout } else { - val systemStreamPartition = envelopeFromChooser.getSystemStreamPartition + val systemStreamPartition = envelopeFromChooser.getSystemStreamPartition - trace("Chooser returned an incoming message envelope: %s" format envelopeFromChooser) + trace("Chooser returned an incoming message envelope: %s" format envelopeFromChooser) - // Ok to give the chooser a new message from this stream. - timeout = 0 - metrics.choseObject.inc - metrics.systemStreamMessagesChosen(envelopeFromChooser.getSystemStreamPartition).inc + // Ok to give the chooser a new message from this stream. + timeout = 0 + metrics.choseObject.inc + metrics.systemStreamMessagesChosen(envelopeFromChooser.getSystemStreamPartition).inc - tryUpdate(systemStreamPartition) + if (updateChooser) { + trace("Update chooser for " + systemStreamPartition.getPartition) + tryUpdate(systemStreamPartition) + } } } @@ -287,7 +290,7 @@ class SystemConsumers ( } } - private def tryUpdate(ssp: SystemStreamPartition) { + def tryUpdate(ssp: SystemStreamPartition) { var updated = false try { updated = update(ssp) http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java new file mode 100644 index 0000000..ca913de --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.task; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.samza.Partition; +import org.apache.samza.checkpoint.OffsetManager; +import org.apache.samza.config.Config; +import org.apache.samza.container.SamzaContainerContext; +import org.apache.samza.container.SamzaContainerMetrics; +import org.apache.samza.container.TaskInstance; +import org.apache.samza.container.TaskInstanceExceptionHandler; +import org.apache.samza.container.TaskInstanceMetrics; +import org.apache.samza.container.TaskName; +import org.apache.samza.metrics.MetricsRegistryMap; +import org.apache.samza.system.IncomingMessageEnvelope; +import org.apache.samza.system.SystemConsumers; +import org.apache.samza.system.SystemStreamPartition; +import org.junit.Before; +import org.junit.Test; +import scala.collection.JavaConversions; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TestAsyncRunLoop { + + Map<TaskName, TaskInstance<AsyncStreamTask>> tasks; + ExecutorService executor; + SystemConsumers consumerMultiplexer; + SamzaContainerMetrics containerMetrics; + OffsetManager offsetManager; + long windowMs; + long commitMs; + long callbackTimeoutMs; + int maxMessagesInFlight; + TaskCoordinator.RequestScope commitRequest; + TaskCoordinator.RequestScope shutdownRequest; + + Partition p0 = new Partition(0); + Partition p1 = new Partition(1); + TaskName taskName0 = new TaskName(p0.toString()); + TaskName taskName1 = new TaskName(p1.toString()); + SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0); + SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1); + IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0"); + IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1"); + IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0"); + + TestTask task0; + TestTask task1; + TaskInstance<AsyncStreamTask> t0; + TaskInstance<AsyncStreamTask> t1; + + AsyncRunLoop createRunLoop() { + return new AsyncRunLoop(tasks, + executor, + consumerMultiplexer, + maxMessagesInFlight, + windowMs, + commitMs, + callbackTimeoutMs, + containerMetrics); + } + + TaskInstance<AsyncStreamTask> createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp) { + TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap()); + scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConversions.asScalaSet(Collections.singleton(ssp)).toSet(); + return new TaskInstance<AsyncStreamTask>(task, taskName, mock(Config.class), taskInstanceMetrics, + null, consumerMultiplexer, mock(TaskInstanceCollector.class), mock(SamzaContainerContext.class), + offsetManager, null, null, sspSet, new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>())); + } + + ExecutorService callbackExecutor; + void triggerCallback(final TestTask task, final TaskCallback callback, final boolean success) { + callbackExecutor.submit(new Runnable() { + @Override + public void run() { + if (task.code != null) { + task.code.run(callback); + } + + task.completed.incrementAndGet(); + + if (success) { + callback.complete(); + } else { + callback.failure(new Exception("process failure")); + } + } + }); + } + + interface TestCode { + void run(TaskCallback callback); + } + + class TestTask implements AsyncStreamTask, WindowableTask { + boolean shutdown = false; + boolean commit = false; + boolean success; + int processed = 0; + volatile int windowCount = 0; + + AtomicInteger completed = new AtomicInteger(0); + TestCode code = null; + + TestTask(boolean success, boolean commit, boolean shutdown) { + this.success = success; + this.shutdown = shutdown; + this.commit = commit; + } + + @Override + public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, + TaskCallback callback) { + + if (maxMessagesInFlight == 1) { + assertEquals(processed, completed.get()); + } + + processed++; + + if (commit) { + coordinator.commit(commitRequest); + } + + if (shutdown) { + coordinator.shutdown(shutdownRequest); + } + triggerCallback(this, callback, success); + } + + @Override + public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception { + windowCount++; + + if (shutdown && windowCount == 4) { + coordinator.shutdown(shutdownRequest); + } + } + } + + @Before + public void setup() { + executor = null; + consumerMultiplexer = mock(SystemConsumers.class); + windowMs = -1; + commitMs = -1; + maxMessagesInFlight = 1; + containerMetrics = new SamzaContainerMetrics("container", new MetricsRegistryMap()); + callbackExecutor = Executors.newFixedThreadPool(2); + offsetManager = mock(OffsetManager.class); + shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER; + + when(consumerMultiplexer.pollIntervalMs()).thenReturn(1000000); + + tasks = new HashMap<>(); + task0 = new TestTask(true, true, false); + task1 = new TestTask(true, false, true); + t0 = createTaskInstance(task0, taskName0, ssp0); + t1 = createTaskInstance(task1, taskName1, ssp1); + tasks.put(taskName0, t0); + tasks.put(taskName1, t1); + } + + @Test + public void testProcessMultipleTasks() throws Exception { + AsyncRunLoop runLoop = createRunLoop(); + when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null); + runLoop.run(); + + callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS); + + assertEquals(1, task0.processed); + assertEquals(1, task0.completed.get()); + assertEquals(1, task1.processed); + assertEquals(1, task1.completed.get()); + assertEquals(2L, containerMetrics.envelopes().getCount()); + assertEquals(2L, containerMetrics.processes().getCount()); + } + + @Test + public void testProcessInOrder() throws Exception { + AsyncRunLoop runLoop = createRunLoop(); + when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null); + runLoop.run(); + + callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS); + + assertEquals(2, task0.processed); + assertEquals(2, task0.completed.get()); + assertEquals(1, task1.processed); + assertEquals(1, task1.completed.get()); + assertEquals(3L, containerMetrics.envelopes().getCount()); + assertEquals(3L, containerMetrics.processes().getCount()); + } + + @Test + public void testProcessOutOfOrder() throws Exception { + maxMessagesInFlight = 2; + + final CountDownLatch latch = new CountDownLatch(1); + task0.code = new TestCode() { + @Override + public void run(TaskCallback callback) { + IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).envelope; + if (envelope == envelope0) { + // process first message will wait till the second one is processed + try { + latch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } else { + // second envelope complete first + assertEquals(0, task0.completed.get()); + latch.countDown(); + } + } + }; + + AsyncRunLoop runLoop = createRunLoop(); + when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null); + runLoop.run(); + + callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS); + + assertEquals(2, task0.processed); + assertEquals(2, task0.completed.get()); + assertEquals(1, task1.processed); + assertEquals(1, task1.completed.get()); + assertEquals(3L, containerMetrics.envelopes().getCount()); + assertEquals(3L, containerMetrics.processes().getCount()); + } + + @Test + public void testWindow() throws Exception { + windowMs = 1; + + AsyncRunLoop runLoop = createRunLoop(); + when(consumerMultiplexer.choose(false)).thenReturn(null); + runLoop.run(); + + callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS); + + assertEquals(4, task1.windowCount); + } + + @Test + public void testCommitSingleTask() throws Exception { + commitRequest = TaskCoordinator.RequestScope.CURRENT_TASK; + + AsyncRunLoop runLoop = createRunLoop(); + //have a null message in between to make sure task0 finishes processing and invoke the commit + when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(null).thenReturn(envelope1).thenReturn(null); + runLoop.run(); + + callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS); + + verify(offsetManager).checkpoint(taskName0); + verify(offsetManager, never()).checkpoint(taskName1); + } + + @Test + public void testCommitAllTasks() throws Exception { + commitRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER; + + AsyncRunLoop runLoop = createRunLoop(); + //have a null message in between to make sure task0 finishes processing and invoke the commit + when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(null).thenReturn(envelope1).thenReturn(null); + runLoop.run(); + + callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS); + + verify(offsetManager).checkpoint(taskName0); + verify(offsetManager).checkpoint(taskName1); + } + + @Test + public void testShutdownOnConsensus() throws Exception { + shutdownRequest = TaskCoordinator.RequestScope.CURRENT_TASK; + + tasks = new HashMap<>(); + task0 = new TestTask(true, true, true); + task1 = new TestTask(true, false, true); + t0 = createTaskInstance(task0, taskName0, ssp0); + t1 = createTaskInstance(task1, taskName1, ssp1); + tasks.put(taskName0, t0); + tasks.put(taskName1, t1); + + AsyncRunLoop runLoop = createRunLoop(); + // consensus is reached after envelope1 is processed. + when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null); + runLoop.run(); + + callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS); + + assertEquals(1, task0.processed); + assertEquals(1, task0.completed.get()); + assertEquals(1, task1.processed); + assertEquals(1, task1.completed.get()); + assertEquals(2L, containerMetrics.envelopes().getCount()); + assertEquals(2L, containerMetrics.processes().getCount()); + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java b/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java new file mode 100644 index 0000000..99e1e18 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.task; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.apache.samza.config.Config; +import org.apache.samza.system.IncomingMessageEnvelope; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + + +public class TestAsyncStreamAdapter { + TestStreamTask task; + AsyncStreamTaskAdapter taskAdaptor; + Exception e; + IncomingMessageEnvelope envelope; + + class TestCallbackListener implements TaskCallbackListener { + boolean callbackComplete = false; + boolean callbackFailure = false; + + @Override + public void onComplete(TaskCallback callback) { + callbackComplete = true; + } + + @Override + public void onFailure(TaskCallback callback, Throwable t) { + callbackFailure = true; + } + } + + class TestStreamTask implements StreamTask, InitableTask, ClosableTask, WindowableTask { + boolean inited = false; + boolean closed = false; + boolean processed = false; + boolean windowed = false; + + @Override + public void close() throws Exception { + closed = true; + } + + @Override + public void init(Config config, TaskContext context) throws Exception { + inited = true; + } + + @Override + public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator) throws Exception { + processed = true; + if (e != null) { + throw e; + } + } + + @Override + public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception { + windowed = true; + } + } + + @Before + public void setup() { + task = new TestStreamTask(); + e = null; + envelope = mock(IncomingMessageEnvelope.class); + } + + @Test + public void testAdapterWithoutThreadPool() throws Exception { + taskAdaptor = new AsyncStreamTaskAdapter(task, null); + TestCallbackListener listener = new TestCallbackListener(); + TaskCallback callback = new TaskCallbackImpl(listener, null, envelope, null, 0L); + + taskAdaptor.init(null, null); + assertTrue(task.inited); + + taskAdaptor.processAsync(null, null, null, callback); + assertTrue(task.processed); + assertTrue(listener.callbackComplete); + + e = new Exception("dummy exception"); + taskAdaptor.processAsync(null, null, null, callback); + assertTrue(listener.callbackFailure); + + taskAdaptor.window(null, null); + assertTrue(task.windowed); + + taskAdaptor.close(); + assertTrue(task.closed); + } + + @Test + public void testAdapterWithThreadPool() throws Exception { + TestCallbackListener listener1 = new TestCallbackListener(); + TaskCallback callback1 = new TaskCallbackImpl(listener1, null, envelope, null, 0L); + + TestCallbackListener listener2 = new TestCallbackListener(); + TaskCallback callback2 = new TaskCallbackImpl(listener2, null, envelope, null, 1L); + + ExecutorService executor = Executors.newFixedThreadPool(2); + taskAdaptor = new AsyncStreamTaskAdapter(task, executor); + taskAdaptor.processAsync(null, null, null, callback1); + taskAdaptor.processAsync(null, null, null, callback2); + + executor.awaitTermination(1, TimeUnit.SECONDS); + assertTrue(listener1.callbackComplete); + assertTrue(listener2.callbackComplete); + + e = new Exception("dummy exception"); + taskAdaptor.processAsync(null, null, null, callback1); + taskAdaptor.processAsync(null, null, null, callback2); + + executor.awaitTermination(1, TimeUnit.SECONDS); + assertTrue(listener1.callbackFailure); + assertTrue(listener2.callbackFailure); + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java b/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java new file mode 100644 index 0000000..d9c68d7 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.task; + +import java.util.HashSet; +import java.util.Set; +import org.apache.samza.container.TaskName; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class TestCoordinatorRequests { + CoordinatorRequests coordinatorRequests; + TaskName taskA = new TaskName("a"); + TaskName taskB = new TaskName("b"); + TaskName taskC = new TaskName("c"); + + + @Before + public void setup() { + Set<TaskName> taskNames = new HashSet<>(); + taskNames.add(taskA); + taskNames.add(taskB); + taskNames.add(taskC); + + coordinatorRequests = new CoordinatorRequests(taskNames); + } + + @Test + public void testUpdateCommit() { + ReadableCoordinator coordinator = new ReadableCoordinator(taskA); + coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK); + coordinatorRequests.update(coordinator); + assertTrue(coordinatorRequests.commitRequests().contains(taskA)); + + coordinator = new ReadableCoordinator(taskC); + coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK); + coordinatorRequests.update(coordinator); + assertTrue(coordinatorRequests.commitRequests().contains(taskC)); + assertFalse(coordinatorRequests.commitRequests().contains(taskB)); + assertTrue(coordinatorRequests.commitRequests().size() == 2); + + coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER); + coordinatorRequests.update(coordinator); + assertTrue(coordinatorRequests.commitRequests().contains(taskB)); + assertTrue(coordinatorRequests.commitRequests().size() == 3); + } + + @Test + public void testUpdateShutdownOnConsensus() { + ReadableCoordinator coordinator = new ReadableCoordinator(taskA); + coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK); + coordinatorRequests.update(coordinator); + assertFalse(coordinatorRequests.shouldShutdownNow()); + + coordinator = new ReadableCoordinator(taskB); + coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK); + coordinatorRequests.update(coordinator); + assertFalse(coordinatorRequests.shouldShutdownNow()); + + coordinator = new ReadableCoordinator(taskC); + coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK); + coordinatorRequests.update(coordinator); + assertTrue(coordinatorRequests.shouldShutdownNow()); + } + + @Test + public void testUpdateShutdownNow() { + ReadableCoordinator coordinator = new ReadableCoordinator(taskA); + coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER); + coordinatorRequests.update(coordinator); + assertTrue(coordinatorRequests.shouldShutdownNow()); + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java new file mode 100644 index 0000000..f1dbf35 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.task; + +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.samza.system.IncomingMessageEnvelope; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + + +public class TestTaskCallbackImpl { + + TaskCallbackListener listener = null; + AtomicInteger completeCount; + AtomicInteger failureCount; + TaskCallback callback = null; + Throwable throwable = null; + + @Before + public void setup() { + completeCount = new AtomicInteger(0); + failureCount = new AtomicInteger(0); + throwable = null; + + listener = new TaskCallbackListener() { + + @Override + public void onComplete(TaskCallback callback) { + completeCount.incrementAndGet(); + } + + @Override + public void onFailure(TaskCallback callback, Throwable t) { + throwable = t; + failureCount.incrementAndGet(); + } + }; + + callback = new TaskCallbackImpl(listener, null, mock(IncomingMessageEnvelope.class), null, 0); + } + + @Test + public void testComplete() { + callback.complete(); + assertEquals(1L, completeCount.get()); + assertEquals(0L, failureCount.get()); + } + + @Test + public void testFailure() { + callback.failure(new Exception("dummy exception")); + assertEquals(0L, completeCount.get()); + assertEquals(1L, failureCount.get()); + } + + @Test + public void testCallbackMultipleComplete() { + callback.complete(); + assertEquals(1L, completeCount.get()); + + callback.complete(); + assertEquals(1L, failureCount.get()); + assertTrue(throwable instanceof IllegalStateException); + } + + @Test + public void testCallbackFailureAfterComplete() { + callback.complete(); + assertEquals(1L, completeCount.get()); + + callback.failure(new Exception("dummy exception")); + assertEquals(1L, failureCount.get()); + assertTrue(throwable instanceof IllegalStateException); + } + + + @Test + public void testMultithreadedCallbacks() throws Exception { + final CyclicBarrier barrier = new CyclicBarrier(2); + ExecutorService executor = Executors.newFixedThreadPool(2); + + for (int i = 0; i < 2; i++) { + executor.submit(new Runnable() { + @Override + public void run() { + try { + barrier.await(); + callback.complete(); + } catch (Exception e) { + e.printStackTrace(); + } + } + }); + } + executor.awaitTermination(1, TimeUnit.SECONDS); + assertEquals(1L, completeCount.get()); + assertEquals(1L, failureCount.get()); + assertTrue(throwable instanceof IllegalStateException); + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java new file mode 100644 index 0000000..d7110f3 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.task; + +import org.apache.samza.Partition; +import org.apache.samza.container.TaskInstanceMetrics; +import org.apache.samza.container.TaskName; +import org.apache.samza.metrics.MetricsRegistryMap; +import org.apache.samza.system.IncomingMessageEnvelope; +import org.apache.samza.system.SystemStreamPartition; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + + +public class TestTaskCallbackManager { + TaskCallbackManager callbackManager = null; + TaskCallbackListener listener = null; + + @Before + public void setup() { + TaskInstanceMetrics metrics = new TaskInstanceMetrics("Partition 0", new MetricsRegistryMap()); + listener = new TaskCallbackListener() { + @Override + public void onComplete(TaskCallback callback) { + } + @Override + public void onFailure(TaskCallback callback, Throwable t) { + } + }; + callbackManager = new TaskCallbackManager(listener, metrics, null, -1); + + } + + @Test + public void testCreateCallback() { + TaskCallbackImpl callback = callbackManager.createCallback(new TaskName("Partition 0"), null, null); + assertTrue(callback.matchSeqNum(0)); + + callback = callbackManager.createCallback(new TaskName("Partition 0"), null, null); + assertTrue(callback.matchSeqNum(1)); + } + + @Test + public void testUpdateCallbackInOrder() { + TaskName taskName = new TaskName("Partition 0"); + SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0)); + ReadableCoordinator coordinator = new ReadableCoordinator(taskName); + + IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp, "0", null, null); + TaskCallbackImpl callback0 = new TaskCallbackImpl(listener, taskName, envelope0, coordinator, 0); + TaskCallbackImpl callbackToCommit = callbackManager.updateCallback(callback0, true); + assertTrue(callbackToCommit.matchSeqNum(0)); + assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition()); + assertEquals("0", callbackToCommit.envelope.getOffset()); + + IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp, "1", null, null); + TaskCallbackImpl callback1 = new TaskCallbackImpl(listener, taskName, envelope1, coordinator, 1); + callbackToCommit = callbackManager.updateCallback(callback1, true); + assertTrue(callbackToCommit.matchSeqNum(1)); + assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition()); + assertEquals("1", callbackToCommit.envelope.getOffset()); + } + + @Test + public void testUpdateCallbackOutofOrder() { + TaskName taskName = new TaskName("Partition 0"); + SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0)); + ReadableCoordinator coordinator = new ReadableCoordinator(taskName); + + // simulate out of order + IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp, "2", null, null); + TaskCallbackImpl callback2 = new TaskCallbackImpl(listener, taskName, envelope2, coordinator, 2); + TaskCallbackImpl callbackToCommit = callbackManager.updateCallback(callback2, true); + assertNull(callbackToCommit); + + IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp, "1", null, null); + TaskCallbackImpl callback1 = new TaskCallbackImpl(listener, taskName, envelope1, coordinator, 1); + callbackToCommit = callbackManager.updateCallback(callback1, true); + assertNull(callbackToCommit); + + IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp, "0", null, null); + TaskCallbackImpl callback0 = new TaskCallbackImpl(listener, taskName, envelope0, coordinator, 0); + callbackToCommit = callbackManager.updateCallback(callback0, true); + assertTrue(callbackToCommit.matchSeqNum(2)); + assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition()); + assertEquals("2", callbackToCommit.envelope.getOffset()); + } + + @Test + public void testUpdateCallbackWithCoordinatorRequests() { + TaskName taskName = new TaskName("Partition 0"); + SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0)); + + + // simulate out of order + IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp, "2", null, null); + ReadableCoordinator coordinator2 = new ReadableCoordinator(taskName); + coordinator2.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER); + TaskCallbackImpl callback2 = new TaskCallbackImpl(listener, taskName, envelope2, coordinator2, 2); + TaskCallbackImpl callbackToCommit = callbackManager.updateCallback(callback2, true); + assertNull(callbackToCommit); + + IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp, "1", null, null); + ReadableCoordinator coordinator1 = new ReadableCoordinator(taskName); + coordinator1.commit(TaskCoordinator.RequestScope.CURRENT_TASK); + TaskCallbackImpl callback1 = new TaskCallbackImpl(listener, taskName, envelope1, coordinator1, 1); + callbackToCommit = callbackManager.updateCallback(callback1, true); + assertNull(callbackToCommit); + + IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp, "0", null, null); + ReadableCoordinator coordinator = new ReadableCoordinator(taskName); + TaskCallbackImpl callback0 = new TaskCallbackImpl(listener, taskName, envelope0, coordinator, 0); + callbackToCommit = callbackManager.updateCallback(callback0, true); + assertTrue(callbackToCommit.matchSeqNum(1)); + assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition()); + assertEquals("1", callbackToCommit.envelope.getOffset()); + assertTrue(callbackToCommit.coordinator.requestedShutdownNow()); + } + +} http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala index e280daa..aa1a8d6 100644 --- a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala +++ b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala @@ -20,22 +20,26 @@ package org.apache.samza.container -import org.apache.samza.metrics.{Timer, SlidingTimeWindowReservoir, MetricsRegistryMap} +import org.apache.samza.Partition +import org.apache.samza.metrics.MetricsRegistryMap +import org.apache.samza.metrics.SlidingTimeWindowReservoir +import org.apache.samza.metrics.Timer +import org.apache.samza.system.IncomingMessageEnvelope +import org.apache.samza.system.SystemConsumers +import org.apache.samza.system.SystemStreamPartition +import org.apache.samza.task.TaskCoordinator.RequestScope +import org.apache.samza.task.ReadableCoordinator +import org.apache.samza.task.StreamTask import org.apache.samza.util.Clock -import org.junit.Test import org.junit.Assert._ +import org.junit.Test import org.mockito.Matchers import org.mockito.Mockito._ -import org.mockito.internal.util.reflection.Whitebox import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.junit.AssertionsForJUnit -import org.scalatest.{Matchers => ScalaTestMatchers} import org.scalatest.mock.MockitoSugar -import org.apache.samza.Partition -import org.apache.samza.system.{ IncomingMessageEnvelope, SystemConsumers, SystemStreamPartition } -import org.apache.samza.task.ReadableCoordinator -import org.apache.samza.task.TaskCoordinator.RequestScope +import org.scalatest.{Matchers => ScalaTestMatchers} class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMatchers { class StopRunLoop extends RuntimeException @@ -49,12 +53,12 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat val envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0") val envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1") - def getMockTaskInstances: Map[TaskName, TaskInstance] = { - val ti0 = mock[TaskInstance] + def getMockTaskInstances: Map[TaskName, TaskInstance[StreamTask]] = { + val ti0 = mock[TaskInstance[StreamTask]] when(ti0.systemStreamPartitions).thenReturn(Set(ssp0)) when(ti0.taskName).thenReturn(taskName0) - val ti1 = mock[TaskInstance] + val ti1 = mock[TaskInstance[StreamTask]] when(ti1.systemStreamPartitions).thenReturn(Set(ssp1)) when(ti1.taskName).thenReturn(taskName1) @@ -67,10 +71,10 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat val consumers = mock[SystemConsumers] val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics) - when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop) + when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop) intercept[StopRunLoop] { runLoop.run } - verify(taskInstances(taskName0)).process(Matchers.eq(envelope0), anyObject) - verify(taskInstances(taskName1)).process(Matchers.eq(envelope1), anyObject) + verify(taskInstances(taskName0)).process(Matchers.eq(envelope0), anyObject, anyObject) + verify(taskInstances(taskName1)).process(Matchers.eq(envelope1), anyObject, anyObject) runLoop.metrics.envelopes.getCount should equal(2L) runLoop.metrics.nullEnvelopes.getCount should equal(0L) } @@ -80,7 +84,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat val consumers = mock[SystemConsumers] val map = getMockTaskInstances - taskName1 // This test only needs p0 val runLoop = new RunLoop(map, consumers, new SamzaContainerMetrics) - when(consumers.choose).thenReturn(null).thenReturn(null).thenThrow(new StopRunLoop) + when(consumers.choose()).thenReturn(null).thenReturn(null).thenThrow(new StopRunLoop) intercept[StopRunLoop] { runLoop.run } runLoop.metrics.envelopes.getCount should equal(0L) runLoop.metrics.nullEnvelopes.getCount should equal(2L) @@ -90,7 +94,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat def testWindowAndCommitAreCalledRegularly { var now = 1400000000000L val consumers = mock[SystemConsumers] - when(consumers.choose).thenReturn(envelope0) + when(consumers.choose()).thenReturn(envelope0) val runLoop = new RunLoop( taskInstances = getMockTaskInstances, @@ -118,7 +122,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat val consumers = mock[SystemConsumers] val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1) - when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop) + when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop) stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.commit(RequestScope.CURRENT_TASK)) intercept[StopRunLoop] { runLoop.run } @@ -132,7 +136,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat val consumers = mock[SystemConsumers] val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1) - when(consumers.choose).thenReturn(envelope0).thenThrow(new StopRunLoop) + when(consumers.choose()).thenReturn(envelope0).thenThrow(new StopRunLoop) stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.commit(RequestScope.ALL_TASKS_IN_CONTAINER)) intercept[StopRunLoop] { runLoop.run } @@ -146,13 +150,13 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat val consumers = mock[SystemConsumers] val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1) - when(consumers.choose).thenReturn(envelope0).thenReturn(envelope0).thenReturn(envelope1) + when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope0).thenReturn(envelope1) stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK)) stubProcess(taskInstances(taskName1), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK)) runLoop.run - verify(taskInstances(taskName0), times(2)).process(Matchers.eq(envelope0), anyObject) - verify(taskInstances(taskName1), times(1)).process(Matchers.eq(envelope1), anyObject) + verify(taskInstances(taskName0), times(2)).process(Matchers.eq(envelope0), anyObject, anyObject) + verify(taskInstances(taskName1), times(1)).process(Matchers.eq(envelope1), anyObject, anyObject) } @Test @@ -161,19 +165,19 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat val consumers = mock[SystemConsumers] val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1) - when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1) + when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1) stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.shutdown(RequestScope.ALL_TASKS_IN_CONTAINER)) runLoop.run - verify(taskInstances(taskName0), times(1)).process(anyObject, anyObject) - verify(taskInstances(taskName1), times(0)).process(anyObject, anyObject) + verify(taskInstances(taskName0), times(1)).process(anyObject, anyObject, anyObject) + verify(taskInstances(taskName1), times(0)).process(anyObject, anyObject, anyObject) } def anyObject[T] = Matchers.anyObject.asInstanceOf[T] // Stub out TaskInstance.process. Mockito really doesn't make this easy. :( - def stubProcess(taskInstance: TaskInstance, process: (IncomingMessageEnvelope, ReadableCoordinator) => Unit) { - when(taskInstance.process(anyObject, anyObject)).thenAnswer(new Answer[Unit]() { + def stubProcess(taskInstance: TaskInstance[StreamTask], process: (IncomingMessageEnvelope, ReadableCoordinator) => Unit) { + when(taskInstance.process(anyObject, anyObject, anyObject)).thenAnswer(new Answer[Unit]() { override def answer(invocation: InvocationOnMock) { val envelope = invocation.getArguments()(0).asInstanceOf[IncomingMessageEnvelope] val coordinator = invocation.getArguments()(1).asInstanceOf[ReadableCoordinator] @@ -186,7 +190,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat def testUpdateTimerCorrectly { var now = 0L val consumers = mock[SystemConsumers] - when(consumers.choose).thenReturn(envelope0) + when(consumers.choose()).thenReturn(envelope0) val clock = new Clock { var c = 0L def currentTimeMillis: Long = { @@ -263,9 +267,9 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat @Test def testGetSystemStreamPartitionToTaskInstancesMapping { - val ti0 = mock[TaskInstance] - val ti1 = mock[TaskInstance] - val ti2 = mock[TaskInstance] + val ti0 = mock[TaskInstance[StreamTask]] + val ti1 = mock[TaskInstance[StreamTask]] + val ti2 = mock[TaskInstance[StreamTask]] when(ti0.systemStreamPartitions).thenReturn(Set(ssp0)) when(ti1.systemStreamPartitions).thenReturn(Set(ssp1)) when(ti2.systemStreamPartitions).thenReturn(Set(ssp1)) http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala index 1358fdd..cff6b96 100644 --- a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala +++ b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala @@ -180,7 +180,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar { new SerdeManager) val collector = new TaskInstanceCollector(producerMultiplexer) val containerContext = new SamzaContainerContext(0, config, Set[TaskName](taskName)) - val taskInstance: TaskInstance = new TaskInstance( + val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask]( task, taskName, config, @@ -261,7 +261,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar { } }) - val taskInstance: TaskInstance = new TaskInstance( + val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask]( task, taskName, config, @@ -314,4 +314,4 @@ class MockJobServlet(exceptionLimit: Int, jobModelRef: AtomicReference[JobModel] jobModel } } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala index 5457f0e..3c83529 100644 --- a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala +++ b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala @@ -71,7 +71,7 @@ class TestTaskInstance { val taskName = new TaskName("taskName") val collector = new TaskInstanceCollector(producerMultiplexer) val containerContext = new SamzaContainerContext(0, config, Set[TaskName](taskName)) - val taskInstance: TaskInstance = new TaskInstance( + val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask]( task, taskName, config, @@ -169,7 +169,7 @@ class TestTaskInstance { val registry = new MetricsRegistryMap val taskMetrics = new TaskInstanceMetrics(registry = registry) - val taskInstance: TaskInstance = new TaskInstance( + val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask]( task, taskName, config, @@ -226,7 +226,7 @@ class TestTaskInstance { val registry = new MetricsRegistryMap val taskMetrics = new TaskInstanceMetrics(registry = registry) - val taskInstance: TaskInstance = new TaskInstance( + val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask]( task, taskName, config, http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala ---------------------------------------------------------------------- diff --git a/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala b/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala index 09da62e..db2249b 100644 --- a/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala +++ b/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala @@ -54,14 +54,14 @@ class TestSystemConsumers { consumer.setResponseSizes(numEnvelopes) // Choose to trigger a refresh with data. - assertNull(consumers.choose) + assertNull(consumers.choose()) // 2: First on start, second on choose. assertEquals(2, consumer.polls) assertEquals(2, consumer.lastPoll.size) assertTrue(consumer.lastPoll.contains(systemStreamPartition0)) assertTrue(consumer.lastPoll.contains(systemStreamPartition1)) - assertEquals(envelope, consumers.choose) - assertEquals(envelope, consumers.choose) + assertEquals(envelope, consumers.choose()) + assertEquals(envelope, consumers.choose()) // We aren't polling because we're getting non-null envelopes. assertEquals(2, consumer.polls) @@ -69,7 +69,7 @@ class TestSystemConsumers { // messages. now = SystemConsumers.DEFAULT_POLL_INTERVAL_MS - assertEquals(envelope, consumers.choose) + assertEquals(envelope, consumers.choose()) // We polled even though there are still 997 messages in the unprocessed // message buffer. @@ -82,11 +82,11 @@ class TestSystemConsumers { // Now drain all messages for SSP0. There should be exactly 997 messages, // since we have chosen 3 already, and we started with 1000. (0 until (numEnvelopes - 3)).foreach { i => - assertEquals(envelope, consumers.choose) + assertEquals(envelope, consumers.choose()) } // Nothing left. Should trigger a poll here. - assertNull(consumers.choose) + assertNull(consumers.choose()) assertEquals(4, consumer.polls) assertEquals(2, consumer.lastPoll.size) @@ -117,31 +117,31 @@ class TestSystemConsumers { consumer.setResponseSizes(1) // Choose to trigger a refresh with data. - assertNull(consumers.choose) + assertNull(consumers.choose()) // Choose should have triggered a second poll, since no messages are available. assertEquals(2, consumer.polls) // Choose a few times. This time there is no data. - assertEquals(envelope, consumers.choose) - assertNull(consumers.choose) - assertNull(consumers.choose) + assertEquals(envelope, consumers.choose()) + assertNull(consumers.choose()) + assertNull(consumers.choose()) // Return more than one message this time. consumer.setResponseSizes(2) // Choose to trigger a refresh with data. - assertNull(consumers.choose) + assertNull(consumers.choose()) // Increase clock interval. now = SystemConsumers.DEFAULT_POLL_INTERVAL_MS // We get two messages now. - assertEquals(envelope, consumers.choose) + assertEquals(envelope, consumers.choose()) // Should not poll even though clock interval increases past interval threshold. assertEquals(2, consumer.polls) - assertEquals(envelope, consumers.choose) - assertNull(consumers.choose) + assertEquals(envelope, consumers.choose()) + assertNull(consumers.choose()) } @Test @@ -238,7 +238,7 @@ class TestSystemConsumers { var caughtRightException = false try { - consumers.choose + consumers.choose() } catch { case e: SystemConsumersException => caughtRightException = true case _: Throwable => caughtRightException = false @@ -256,13 +256,13 @@ class TestSystemConsumers { var notThrowException = true; try { - consumers2.choose + consumers2.choose() } catch { case e: Throwable => notThrowException = false } assertTrue("it should not throw any exception", notThrowException) - var msgEnvelope = Some(consumers2.choose) + var msgEnvelope = Some(consumers2.choose()) assertTrue("Consumer did not succeed in receiving the second message after Serde exception in choose", msgEnvelope.get != null) consumers2.stop @@ -279,7 +279,7 @@ class TestSystemConsumers { assertTrue("SystemConsumer start should not throw any Serde exception", notThrowException) msgEnvelope = null - msgEnvelope = Some(consumers2.choose) + msgEnvelope = Some(consumers2.choose()) assertTrue("Consumer did not succeed in receiving the second message after Serde exception in poll", msgEnvelope.get != null) consumers2.stop http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala ---------------------------------------------------------------------- diff --git a/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala b/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala index 1f4b5c4..24bc8b5 100644 --- a/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala +++ b/samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala @@ -36,6 +36,7 @@ class HdfsSystemProducer( val clock: () => Long = () => System.currentTimeMillis) extends SystemProducer with Logging with TimerUtils { val dfs = FileSystem.get(new Configuration(true)) val writers: MMap[String, HdfsWriter[_]] = MMap.empty[String, HdfsWriter[_]] + private val lock = new Object //synchronization lock for thread safe access def start(): Unit = { info("entering HdfsSystemProducer.start() call for system: " + systemName + ", client: " + clientId) @@ -43,52 +44,65 @@ class HdfsSystemProducer( def stop(): Unit = { info("entering HdfsSystemProducer.stop() for system: " + systemName + ", client: " + clientId) - writers.values.map { _.close } - dfs.close + + lock.synchronized { + writers.values.map(_.close) + dfs.close + } } def register(source: String): Unit = { info("entering HdfsSystemProducer.register(" + source + ") " + "call for system: " + systemName + ", client: " + clientId) - writers += (source -> HdfsWriter.getInstance(dfs, systemName, config)) + + lock.synchronized { + writers += (source -> HdfsWriter.getInstance(dfs, systemName, config)) + } } def flush(source: String): Unit = { debug("entering HdfsSystemProducer.flush(" + source + ") " + "call for system: " + systemName + ", client: " + clientId) - try { - metrics.flushes.inc - updateTimer(metrics.flushMs) { writers.get(source).head.flush } - metrics.flushSuccess.inc - } catch { - case e: Exception => { - metrics.flushFailed.inc - warn("Exception thrown while client " + clientId + " flushed HDFS out stream, msg: " + e.getMessage) - debug("Detailed message from exception thrown by client " + clientId + " in HDFS flush: ", e) - writers.get(source).head.close - throw e + + metrics.flushes.inc + lock.synchronized { + try { + updateTimer(metrics.flushMs) { + writers.get(source).head.flush + } + } catch { + case e: Exception => { + metrics.flushFailed.inc + warn("Exception thrown while client " + clientId + " flushed HDFS out stream, msg: " + e.getMessage) + debug("Detailed message from exception thrown by client " + clientId + " in HDFS flush: ", e) + writers.get(source).head.close + throw e + } } } + metrics.flushSuccess.inc } def send(source: String, ome: OutgoingMessageEnvelope) = { debug("entering HdfsSystemProducer.send(source = " + source + ", envelope) " + "call for system: " + systemName + ", client: " + clientId) + metrics.sends.inc - try { - updateTimer(metrics.sendMs) { - writers.get(source).head.write(ome) - } - metrics.sendSuccess.inc - } catch { - case e: Exception => { - metrics.sendFailed.inc - warn("Exception thrown while client " + clientId + " wrote to HDFS, msg: " + e.getMessage) - debug("Detailed message from exception thrown by client " + clientId + " in HDFS write: ", e) - writers.get(source).head.close - throw e + lock.synchronized { + try { + updateTimer(metrics.sendMs) { + writers.get(source).head.write(ome) + } + } catch { + case e: Exception => { + metrics.sendFailed.inc + warn("Exception thrown while client " + clientId + " wrote to HDFS, msg: " + e.getMessage) + debug("Detailed message from exception thrown by client " + clientId + " in HDFS write: ", e) + writers.get(source).head.close + throw e + } } } + metrics.sendSuccess.inc } - -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala ---------------------------------------------------------------------- diff --git a/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala b/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala index 5e8cc65..5d2641a 100644 --- a/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala +++ b/samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala @@ -140,6 +140,7 @@ class KafkaCheckpointMigration extends MigrationPlan with Logging { def migrationCompletionMark(coordinatorSystemProducer: CoordinatorStreamSystemProducer) = { info("Marking completion of migration %s" format migrationKey) val message = new SetMigrationMetaMessage(source, migrationKey, migrationVal) + coordinatorSystemProducer.register(source) coordinatorSystemProducer.start() coordinatorSystemProducer.send(message) coordinatorSystemProducer.stop()
