This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 710aa67  KAFKA-6242: Dynamic resize of various broker thread pools (#4471)
710aa67 is described below

commit 710aa678b7a6409296fb3c852a47c1876d8fa8e9
Author: Rajini Sivaram <[email protected]>
AuthorDate: Tue Jan 30 09:29:27 2018 -0800

    KAFKA-6242: Dynamic resize of various broker thread pools (#4471)
    
    Dynamic resize of broker thread pools as described in KIP-226:
      - num.network.threads
      - num.io.threads
      - num.replica.fetchers
      - num.recovery.threads.per.data.dir
      - background.threads
    
    Reviewers: Jason Gustafson <[email protected]>
---
 core/src/main/scala/kafka/log/LogManager.scala     |  14 ++-
 .../main/scala/kafka/network/RequestChannel.scala  |  39 +++++--
 .../main/scala/kafka/network/SocketServer.scala    | 130 +++++++++++++--------
 .../kafka/server/AbstractFetcherManager.scala      |  50 ++++++--
 .../scala/kafka/server/DynamicBrokerConfig.scala   |  59 ++++++++++
 core/src/main/scala/kafka/server/KafkaConfig.scala |  10 +-
 .../scala/kafka/server/KafkaRequestHandler.scala   |  55 ++++++---
 .../main/scala/kafka/utils/KafkaScheduler.scala    |   4 +
 .../server/DynamicBrokerReconfigurationTest.scala  |  85 +++++++++++++-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |   2 +-
 10 files changed, 353 insertions(+), 95 deletions(-)
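
Once a broker is running with this change, the thread pool settings listed in the commit
message become dynamically reconfigurable per KIP-226, without a broker restart. As a rough
illustration only (the bootstrap address and target value below are assumptions, not part of
this commit), an operator could push a cluster-wide default update through the Java AdminClient:

    import java.util.{Collections, Properties}
    import org.apache.kafka.clients.admin.{AdminClient, Config, ConfigEntry}
    import org.apache.kafka.common.config.ConfigResource

    val props = new Properties()
    props.put("bootstrap.servers", "localhost:9092")  // assumed broker address
    val admin = AdminClient.create(props)
    // An empty resource name addresses the cluster-wide default broker config.
    val broker = new ConfigResource(ConfigResource.Type.BROKER, "")
    val update = new Config(Collections.singleton(new ConfigEntry("num.io.threads", "16")))
    admin.alterConfigs(Collections.singletonMap(broker, update)).all().get()
    admin.close()

Note that DynamicThreadPool.validateReconfiguration below rejects a new size that is less than
half or more than double the current value, so larger resizes have to be applied in steps.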

diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala
index adf1b9c..9ae93aa 100755
--- a/core/src/main/scala/kafka/log/LogManager.scala
+++ b/core/src/main/scala/kafka/log/LogManager.scala
@@ -52,7 +52,7 @@ class LogManager(logDirs: Seq[File],
                  val topicConfigs: Map[String, LogConfig], // note that this doesn't get updated after creation
                  val initialDefaultConfig: LogConfig,
                  val cleanerConfig: CleanerConfig,
-                 ioThreads: Int,
+                 recoveryThreadsPerDataDir: Int,
                  val flushCheckMs: Long,
                  val flushRecoveryOffsetCheckpointMs: Long,
                  val flushStartOffsetCheckpointMs: Long,
@@ -79,6 +79,7 @@ class LogManager(logDirs: Seq[File],
 
   private val _liveLogDirs: ConcurrentLinkedQueue[File] = createAndValidateLogDirs(logDirs, initialOfflineDirs)
   @volatile var currentDefaultConfig = initialDefaultConfig
+  @volatile private var numRecoveryThreadsPerDataDir = recoveryThreadsPerDataDir
 
   def reconfigureDefaultLogConfig(logConfig: LogConfig): Unit = {
     this.currentDefaultConfig = logConfig
@@ -172,6 +173,11 @@ class LogManager(logDirs: Seq[File],
     liveLogDirs
   }
 
+  def resizeRecoveryThreadPool(newSize: Int): Unit = {
+    info(s"Resizing recovery thread pool size for each data dir from $numRecoveryThreadsPerDataDir to $newSize")
+    numRecoveryThreadsPerDataDir = newSize
+  }
+
   // dir should be an absolute path
   def handleLogDirFailure(dir: String) {
     info(s"Stopping serving logs in dir $dir")
@@ -286,7 +292,7 @@ class LogManager(logDirs: Seq[File],
 
     for (dir <- liveLogDirs) {
       try {
-        val pool = Executors.newFixedThreadPool(ioThreads)
+        val pool = Executors.newFixedThreadPool(numRecoveryThreadsPerDataDir)
         threadPools.append(pool)
 
         val cleanShutdownFile = new File(dir, Log.CleanShutdownFile)
@@ -423,7 +429,7 @@ class LogManager(logDirs: Seq[File],
     for (dir <- liveLogDirs) {
       debug("Flushing and closing logs at " + dir)
 
-      val pool = Executors.newFixedThreadPool(ioThreads)
+      val pool = Executors.newFixedThreadPool(numRecoveryThreadsPerDataDir)
       threadPools.append(pool)
 
       val logsInDir = logsByDir.getOrElse(dir.toString, Map()).values
@@ -921,7 +927,7 @@ object LogManager {
       topicConfigs = topicConfigs,
       initialDefaultConfig = defaultLogConfig,
       cleanerConfig = cleanerConfig,
-      ioThreads = config.numRecoveryThreadsPerDataDir,
+      recoveryThreadsPerDataDir = config.numRecoveryThreadsPerDataDir,
       flushCheckMs = config.logFlushSchedulerIntervalMs,
       flushRecoveryOffsetCheckpointMs = config.logFlushOffsetCheckpointIntervalMs,
       flushStartOffsetCheckpointMs = config.logFlushStartOffsetCheckpointIntervalMs,
diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala
index 561ec8d..6ec42a6 100644
--- a/core/src/main/scala/kafka/network/RequestChannel.scala
+++ b/core/src/main/scala/kafka/network/RequestChannel.scala
@@ -34,6 +34,7 @@ import org.apache.kafka.common.security.auth.KafkaPrincipal
 import org.apache.kafka.common.utils.{Sanitizer, Time}
 
 import scala.collection.mutable
+import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
 object RequestChannel extends Logging {
@@ -239,13 +240,11 @@ object RequestChannel extends Logging {
   case object CloseConnectionAction extends ResponseAction
 }
 
-class RequestChannel(val numProcessors: Int, val queueSize: Int) extends KafkaMetricsGroup {
+class RequestChannel(val queueSize: Int) extends KafkaMetricsGroup {
   val metrics = new RequestChannel.Metrics
   private var responseListeners: List[(Int) => Unit] = Nil
   private val requestQueue = new ArrayBlockingQueue[BaseRequest](queueSize)
-  private val responseQueues = new Array[BlockingQueue[RequestChannel.Response]](numProcessors)
-  for(i <- 0 until numProcessors)
-    responseQueues(i) = new LinkedBlockingQueue[RequestChannel.Response]()
+  private val responseQueues = new ConcurrentHashMap[Int, BlockingQueue[RequestChannel.Response]]()
 
   newGauge(
     "RequestQueueSize",
@@ -255,18 +254,26 @@ class RequestChannel(val numProcessors: Int, val queueSize: Int) extends KafkaMe
   )
 
   newGauge("ResponseQueueSize", new Gauge[Int]{
-    def value = responseQueues.foldLeft(0) {(total, q) => total + q.size()}
+    def value = responseQueues.values.asScala.foldLeft(0) {(total, q) => total + q.size()}
   })
 
-  for (i <- 0 until numProcessors) {
+  def addProcessor(processorId: Int): Unit = {
+    val responseQueue = new LinkedBlockingQueue[RequestChannel.Response]()
+    if (responseQueues.putIfAbsent(processorId, responseQueue) != null)
+      warn(s"Unexpected processor with processorId $processorId")
     newGauge("ResponseQueueSize",
       new Gauge[Int] {
-        def value = responseQueues(i).size()
+        def value = responseQueue.size()
       },
-      Map("processor" -> i.toString)
+      Map("processor" -> processorId.toString)
     )
   }
 
+  def removeProcessor(processorId: Int): Unit = {
+    removeMetric("ResponseQueueSize", Map("processor" -> processorId.toString))
+    responseQueues.remove(processorId)
+  }
+
   /** Send a request to be handled, potentially blocking until there is room in the queue for the request */
   def sendRequest(request: RequestChannel.Request) {
     requestQueue.put(request)
@@ -287,9 +294,14 @@ class RequestChannel(val numProcessors: Int, val queueSize: Int) extends KafkaMe
       trace(message)
     }
 
-    responseQueues(response.processor).put(response)
-    for(onResponse <- responseListeners)
-      onResponse(response.processor)
+    val responseQueue = responseQueues.get(response.processor)
+    // `responseQueue` may be null if the processor was shutdown. In this case, the connections
+    // are closed, so the response is dropped.
+    if (responseQueue != null) {
+      responseQueue.put(response)
+      for (onResponse <- responseListeners)
+        onResponse(response.processor)
+    }
   }
 
   /** Get the next request or block until specified time has elapsed */
@@ -302,7 +314,10 @@ class RequestChannel(val numProcessors: Int, val queueSize: Int) extends KafkaMe
 
   /** Get a response for the given processor if there is one */
   def receiveResponse(processor: Int): RequestChannel.Response = {
-    val response = responseQueues(processor).poll()
+    val responseQueue = responseQueues.get(processor)
+    if (responseQueue == null)
+      throw new IllegalStateException(s"receiveResponse with invalid processor $processor: processors=${responseQueues.keySet}")
+    val response = responseQueue.poll()
     if (response != null)
       response.request.responseDequeueTimeNanos = Time.SYSTEM.nanoseconds
     response
diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala
index b40bf84..f6ec974 100644
--- a/core/src/main/scala/kafka/network/SocketServer.scala
+++ b/core/src/main/scala/kafka/network/SocketServer.scala
@@ -43,6 +43,7 @@ import org.slf4j.event.Level
 
 import scala.collection._
 import JavaConverters._
+import scala.collection.mutable.{ArrayBuffer, Buffer}
 import scala.util.control.ControlThrowable
 
 /**
@@ -54,9 +55,7 @@ import scala.util.control.ControlThrowable
 class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time, val credentialProvider: CredentialProvider) extends Logging with KafkaMetricsGroup {
 
   private val endpoints = config.listeners.map(l => l.listenerName -> l).toMap
-  private val numProcessorThreads = config.numNetworkThreads
   private val maxQueuedRequests = config.queuedMaxRequests
-  private val totalProcessorThreads = numProcessorThreads * endpoints.size
 
   private val maxConnectionsPerIp = config.maxConnectionsPerIp
   private val maxConnectionsPerIpOverrides = config.maxConnectionsPerIpOverrides
@@ -69,8 +68,9 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time
   private val memoryPoolDepletedTimeMetricName = metrics.metricName("MemoryPoolDepletedTimeTotal", "socket-server-metrics")
   memoryPoolSensor.add(new Meter(TimeUnit.MILLISECONDS, memoryPoolDepletedPercentMetricName, memoryPoolDepletedTimeMetricName))
   private val memoryPool = if (config.queuedMaxBytes > 0) new SimpleMemoryPool(config.queuedMaxBytes, config.socketRequestMaxBytes, false, memoryPoolSensor) else MemoryPool.NONE
-  val requestChannel = new RequestChannel(totalProcessorThreads, maxQueuedRequests)
-  private val processors = new Array[Processor](totalProcessorThreads)
+  val requestChannel = new RequestChannel(maxQueuedRequests)
+  private val processors = new ConcurrentHashMap[Int, Processor]()
+  private var nextProcessorId = 0
 
   private[network] val acceptors = mutable.Map[EndPoint, Acceptor]()
   private var connectionQuotas: ConnectionQuotas = _
@@ -81,41 +81,21 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time
    */
   def startup() {
     this.synchronized {
-
       connectionQuotas = new ConnectionQuotas(maxConnectionsPerIp, 
maxConnectionsPerIpOverrides)
-
-      val sendBufferSize = config.socketSendBufferBytes
-      val recvBufferSize = config.socketReceiveBufferBytes
-      val brokerId = config.brokerId
-
-      var processorBeginIndex = 0
-      config.listeners.foreach { endpoint =>
-        val listenerName = endpoint.listenerName
-        val securityProtocol = endpoint.securityProtocol
-        val processorEndIndex = processorBeginIndex + numProcessorThreads
-
-        for (i <- processorBeginIndex until processorEndIndex)
-          processors(i) = newProcessor(i, connectionQuotas, listenerName, securityProtocol, memoryPool)
-
-        val acceptor = new Acceptor(endpoint, sendBufferSize, recvBufferSize, brokerId,
-          processors.slice(processorBeginIndex, processorEndIndex), connectionQuotas)
-        acceptors.put(endpoint, acceptor)
-        KafkaThread.nonDaemon(s"kafka-socket-acceptor-$listenerName-$securityProtocol-${endpoint.port}", acceptor).start()
-        acceptor.awaitStartup()
-
-        processorBeginIndex = processorEndIndex
-      }
+      createProcessors(config.numNetworkThreads)
     }
 
     newGauge("NetworkProcessorAvgIdlePercent",
       new Gauge[Double] {
-        private val ioWaitRatioMetricNames = processors.map { p =>
-          metrics.metricName("io-wait-ratio", "socket-server-metrics", p.metricTags)
-        }
 
-        def value = ioWaitRatioMetricNames.map { metricName =>
-          Option(metrics.metric(metricName)).fold(0.0)(_.value)
-        }.sum / totalProcessorThreads
+        def value = SocketServer.this.synchronized {
+          val ioWaitRatioMetricNames = processors.values.asScala.map { p =>
+            metrics.metricName("io-wait-ratio", "socket-server-metrics", p.metricTags)
+          }
+          ioWaitRatioMetricNames.map { metricName =>
+            Option(metrics.metric(metricName)).fold(0.0)(_.value)
+          }.sum / processors.size
+        }
       }
     )
     newGauge("MemoryPoolAvailable",
@@ -131,8 +111,41 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time
     info("Started " + acceptors.size + " acceptor threads")
   }
 
+  private def createProcessors(newProcessorsPerListener: Int): Unit = synchronized {
+
+    val sendBufferSize = config.socketSendBufferBytes
+    val recvBufferSize = config.socketReceiveBufferBytes
+    val brokerId = config.brokerId
+
+    val numProcessorThreads = config.numNetworkThreads
+    config.listeners.foreach { endpoint =>
+      val listenerName = endpoint.listenerName
+      val securityProtocol = endpoint.securityProtocol
+      val listenerProcessors = new ArrayBuffer[Processor]()
+
+      for (i <- 0 until newProcessorsPerListener) {
+        listenerProcessors += newProcessor(nextProcessorId, connectionQuotas, listenerName, securityProtocol, memoryPool)
+        requestChannel.addProcessor(nextProcessorId)
+        nextProcessorId += 1
+      }
+      listenerProcessors.foreach(p => processors.put(p.id, p))
+
+      val acceptor = acceptors.getOrElseUpdate(endpoint, {
+        val acceptor = new Acceptor(endpoint, sendBufferSize, recvBufferSize, brokerId, connectionQuotas)
+        KafkaThread.nonDaemon(s"kafka-socket-acceptor-$listenerName-$securityProtocol-${endpoint.port}", acceptor).start()
+        acceptor.awaitStartup()
+        acceptor
+      })
+      acceptor.addProcessors(listenerProcessors)
+    }
+  }
+
   // register the processor threads for notification of responses
-  requestChannel.addResponseListener(id => processors(id).wakeup())
+  requestChannel.addResponseListener(id => {
+    val processor = processors.get(id)
+    if (processor != null)
+      processor.wakeup()
+  })
 
   /**
     * Stop processing requests and new connections.
@@ -141,13 +154,21 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time
     info("Stopping socket server request processors")
     this.synchronized {
       acceptors.values.foreach(_.shutdown)
-      processors.foreach(_.shutdown)
+      processors.asScala.values.foreach(_.shutdown)
       requestChannel.clear()
       stoppedProcessingRequests = true
     }
     info("Stopped socket server request processors")
   }
 
+  def resizeThreadPool(oldNumNetworkThreads: Int, newNumNetworkThreads: Int): Unit = {
+    info(s"Resizing network thread pool size for each listener from $oldNumNetworkThreads to $newNumNetworkThreads")
+    if (newNumNetworkThreads > oldNumNetworkThreads)
+      createProcessors(newNumNetworkThreads - oldNumNetworkThreads)
+    else if (newNumNetworkThreads < oldNumNetworkThreads)
+      acceptors.values.foreach(_.removeProcessors(oldNumNetworkThreads - newNumNetworkThreads, requestChannel))
+  }
+
   /**
     * Shutdown the socket server. If still processing requests, shutdown
     * acceptors and processors first.
@@ -194,7 +215,7 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time
     Option(connectionQuotas).fold(0)(_.get(address))
 
   /* For test usage */
-  private[network] def processor(index: Int): Processor = processors(index)
+  private[network] def processor(index: Int): Processor = processors.get(index)
 
 }
 
@@ -267,17 +288,28 @@ private[kafka] class Acceptor(val endPoint: EndPoint,
                               val sendBufferSize: Int,
                               val recvBufferSize: Int,
                               brokerId: Int,
-                              processors: Array[Processor],
                               connectionQuotas: ConnectionQuotas) extends AbstractServerThread(connectionQuotas) with KafkaMetricsGroup {
 
   private val nioSelector = NSelector.open()
   val serverChannel = openServerSocket(endPoint.host, endPoint.port)
+  private val processors = new ArrayBuffer[Processor]()
 
-  this.synchronized {
-    processors.foreach { processor =>
+  private[network] def addProcessors(newProcessors: Buffer[Processor]): Unit = synchronized {
+    newProcessors.foreach { processor =>
       KafkaThread.nonDaemon(s"kafka-network-thread-$brokerId-${endPoint.listenerName}-${endPoint.securityProtocol}-${processor.id}",
         processor).start()
     }
+    processors ++= newProcessors
+  }
+
+  private[network] def removeProcessors(removeCount: Int, requestChannel: RequestChannel): Unit = synchronized {
+    // Shutdown `removeCount` processors. Remove them from the processor list first so that no more
+    // connections are assigned. Shutdown the removed processors, closing the selector and its connections.
+    // The processors are then removed from `requestChannel` and any pending responses to these processors are dropped.
+    val toRemove = processors.takeRight(removeCount)
+    processors.remove(processors.size - removeCount, removeCount)
+    toRemove.foreach(_.shutdown())
+    toRemove.foreach(processor => requestChannel.removeProcessor(processor.id))
   }
 
   /**
@@ -298,13 +330,17 @@ private[kafka] class Acceptor(val endPoint: EndPoint,
               try {
                 val key = iter.next
                 iter.remove()
-                if (key.isAcceptable)
-                  accept(key, processors(currentProcessor))
-                else
+                if (key.isAcceptable) {
+                  val processor = synchronized {
+                    currentProcessor = currentProcessor % processors.size
+                    processors(currentProcessor)
+                  }
+                  accept(key, processor)
+                } else
                   throw new IllegalStateException("Unrecognized key state for acceptor thread.")
 
-                // round robin to the next processor thread
-                currentProcessor = (currentProcessor + 1) % processors.length
+                // round robin to the next processor thread, mod(numProcessors) will be done later
+                currentProcessor = currentProcessor + 1
               } catch {
                 case e: Throwable => error("Error while accepting connection", e)
               }
@@ -446,8 +482,10 @@ private[kafka] class Processor(val id: Int,
       credentialProvider.tokenCache))
   // Visible to override for testing
   protected[network] def createSelector(channelBuilder: ChannelBuilder): KSelector = {
-    if (channelBuilder.isInstanceOf[Reconfigurable])
-      config.addReconfigurable(channelBuilder.asInstanceOf[Reconfigurable])
+    channelBuilder match {
+      case reconfigurable: Reconfigurable => config.addReconfigurable(reconfigurable)
+      case _ =>
+    }
     new KSelector(
       maxRequestSize,
       connectionsMaxIdleMs,
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
index c385d4f..6d88d8d 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala
@@ -17,9 +17,6 @@
 
 package kafka.server
 
-import scala.collection.mutable
-import scala.collection.Set
-import scala.collection.Map
 import kafka.utils.Logging
 import kafka.cluster.BrokerEndPoint
 import kafka.metrics.KafkaMetricsGroup
@@ -27,12 +24,17 @@ import com.yammer.metrics.core.Gauge
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.utils.Utils
 
+import scala.collection.mutable
+import scala.collection.{Map, Set}
+import scala.collection.JavaConverters._
+
 abstract class AbstractFetcherManager(protected val name: String, clientId: String, numFetchers: Int = 1)
   extends Logging with KafkaMetricsGroup {
   // map of (source broker_id, fetcher_id per source broker) => fetcher.
   // package private for test
   private[server] val fetcherThreadMap = new mutable.HashMap[BrokerIdAndFetcherId, AbstractFetcherThread]
-  private val mapLock = new Object
+  private val lock = new Object
+  private var numFetchersPerBroker = numFetchers
   this.logIdent = "[" + name + "] "
 
   newGauge(
@@ -65,13 +67,41 @@ abstract class AbstractFetcherManager(protected val name: String, clientId: Stri
   Map("clientId" -> clientId)
   )
 
+  def resizeThreadPool(newSize: Int): Unit = {
+    def migratePartitions(newSize: Int): Unit = {
+      fetcherThreadMap.foreach { case (id, thread) =>
+        val removedPartitions = thread.partitionStates.partitionStates.asScala.map { case state =>
+          state.topicPartition -> new BrokerAndInitialOffset(thread.sourceBroker, state.value.fetchOffset)
+        }.toMap
+        removeFetcherForPartitions(removedPartitions.keySet)
+        if (id.fetcherId >= newSize)
+          thread.shutdown()
+        addFetcherForPartitions(removedPartitions)
+      }
+    }
+    lock synchronized {
+      val currentSize = numFetchersPerBroker
+      info(s"Resizing fetcher thread pool size from $currentSize to $newSize")
+      numFetchersPerBroker = newSize
+      if (newSize != currentSize) {
+        // We could just migrate some partitions explicitly to new threads. But this is currently
+        // reassigning all partitions using the new thread size so that hash-based allocation
+        // works with partition add/delete as it did before.
+        migratePartitions(newSize)
+      }
+      shutdownIdleFetcherThreads()
+    }
+  }
+
   private def getFetcherId(topic: String, partitionId: Int) : Int = {
-    Utils.abs(31 * topic.hashCode() + partitionId) % numFetchers
+    lock synchronized {
+      Utils.abs(31 * topic.hashCode() + partitionId) % numFetchersPerBroker
+    }
   }
 
   // This method is only needed by ReplicaAlterDirManager
   def markPartitionsForTruncation(brokerId: Int, topicPartition: TopicPartition, truncationOffset: Long) {
-    mapLock synchronized {
+    lock synchronized {
       val fetcherId = getFetcherId(topicPartition.topic, topicPartition.partition)
       val brokerIdAndFetcherId = BrokerIdAndFetcherId(brokerId, fetcherId)
       fetcherThreadMap.get(brokerIdAndFetcherId).foreach { thread =>
@@ -84,7 +114,7 @@ abstract class AbstractFetcherManager(protected val name: String, clientId: Stri
   def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): AbstractFetcherThread
 
   def addFetcherForPartitions(partitionAndOffsets: Map[TopicPartition, BrokerAndInitialOffset]) {
-    mapLock synchronized {
+    lock synchronized {
       val partitionsPerFetcher = partitionAndOffsets.groupBy { case(topicPartition, brokerAndInitialFetchOffset) =>
         BrokerAndFetcherId(brokerAndInitialFetchOffset.broker, getFetcherId(topicPartition.topic, topicPartition.partition))}
 
@@ -117,7 +147,7 @@ abstract class AbstractFetcherManager(protected val name: String, clientId: Stri
   }
 
   def removeFetcherForPartitions(partitions: Set[TopicPartition]) {
-    mapLock synchronized {
+    lock synchronized {
       for (fetcher <- fetcherThreadMap.values)
         fetcher.removePartitions(partitions)
     }
@@ -125,7 +155,7 @@ abstract class AbstractFetcherManager(protected val name: String, clientId: Stri
   }
 
   def shutdownIdleFetcherThreads() {
-    mapLock synchronized {
+    lock synchronized {
       val keysToBeRemoved = new mutable.HashSet[BrokerIdAndFetcherId]
       for ((key, fetcher) <- fetcherThreadMap) {
         if (fetcher.partitionCount <= 0) {
@@ -138,7 +168,7 @@ abstract class AbstractFetcherManager(protected val name: String, clientId: Stri
   }
 
   def closeAllFetchers() {
-    mapLock synchronized {
+    lock synchronized {
       for ( (_, fetcher) <- fetcherThreadMap) {
         fetcher.initiateShutdown()
       }
diff --git a/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala b/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala
index b0dd7c0..58fa583 100755
--- a/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala
+++ b/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala
@@ -80,6 +80,7 @@ object DynamicBrokerConfig {
   AllDynamicConfigs ++= DynamicSecurityConfigs
   AllDynamicConfigs ++= LogCleaner.ReconfigurableConfigs
   AllDynamicConfigs ++= DynamicLogConfig.ReconfigurableConfigs
+  AllDynamicConfigs ++= DynamicThreadPool.ReconfigurableConfigs
 
   private val PerBrokerConfigs = DynamicSecurityConfigs
 
@@ -129,6 +130,7 @@ class DynamicBrokerConfig(private val kafkaConfig: KafkaConfig) extends Logging
   }
 
   def addReconfigurables(kafkaServer: KafkaServer): Unit = {
+    addBrokerReconfigurable(new DynamicThreadPool(kafkaServer))
     if (kafkaServer.logManager.cleaner != null)
       addBrokerReconfigurable(kafkaServer.logManager.cleaner)
     addReconfigurable(new DynamicLogConfig(kafkaServer.logManager))
@@ -449,3 +451,60 @@ class DynamicLogConfig(logManager: LogManager) extends Reconfigurable with Loggi
     }
   }
 }
+
+object DynamicThreadPool {
+  val ReconfigurableConfigs = Set(
+    KafkaConfig.NumIoThreadsProp,
+    KafkaConfig.NumNetworkThreadsProp,
+    KafkaConfig.NumReplicaFetchersProp,
+    KafkaConfig.NumRecoveryThreadsPerDataDirProp,
+    KafkaConfig.BackgroundThreadsProp)
+}
+
+class DynamicThreadPool(server: KafkaServer) extends BrokerReconfigurable {
+
+  override def reconfigurableConfigs(): Set[String] = {
+    DynamicThreadPool.ReconfigurableConfigs
+  }
+
+  override def validateReconfiguration(newConfig: KafkaConfig): Boolean = {
+    newConfig.values.asScala.filterKeys(DynamicThreadPool.ReconfigurableConfigs.contains).forall { case (k, v) =>
+      val newValue = v.asInstanceOf[Int]
+      val oldValue = currentValue(k)
+      if (newValue != oldValue) {
+        val errorMsg = s"Dynamic thread count update validation failed for $k=$v"
+        if (newValue <= 0)
+          throw new ConfigException(s"$errorMsg, value should be at least 1")
+        if (newValue < oldValue / 2)
+          throw new ConfigException(s"$errorMsg, value should be at least half the current value $oldValue")
+        if (newValue > oldValue * 2)
+          throw new ConfigException(s"$errorMsg, value should not be greater than double the current value $oldValue")
+      }
+      true
+    }
+  }
+
+  override def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = {
+    if (newConfig.numIoThreads != oldConfig.numIoThreads)
+      server.requestHandlerPool.resizeThreadPool(newConfig.numIoThreads)
+    if (newConfig.numNetworkThreads != oldConfig.numNetworkThreads)
+      server.socketServer.resizeThreadPool(oldConfig.numNetworkThreads, newConfig.numNetworkThreads)
+    if (newConfig.numReplicaFetchers != oldConfig.numReplicaFetchers)
+      server.replicaManager.replicaFetcherManager.resizeThreadPool(newConfig.numReplicaFetchers)
+    if (newConfig.numRecoveryThreadsPerDataDir != oldConfig.numRecoveryThreadsPerDataDir)
+      server.getLogManager.resizeRecoveryThreadPool(newConfig.numRecoveryThreadsPerDataDir)
+    if (newConfig.backgroundThreads != oldConfig.backgroundThreads)
+      server.kafkaScheduler.resizeThreadPool(newConfig.backgroundThreads)
+  }
+
+  private def currentValue(name: String): Int = {
+    name match {
+      case KafkaConfig.NumIoThreadsProp => server.config.numIoThreads
+      case KafkaConfig.NumNetworkThreadsProp => server.config.numNetworkThreads
+      case KafkaConfig.NumReplicaFetchersProp => server.config.numReplicaFetchers
+      case KafkaConfig.NumRecoveryThreadsPerDataDirProp => server.config.numRecoveryThreadsPerDataDir
+      case KafkaConfig.BackgroundThreadsProp => server.config.backgroundThreads
+      case n => throw new IllegalStateException(s"Unexpected config $n")
+    }
+  }
+}
diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala
index bc78214..144dd65 100755
--- a/core/src/main/scala/kafka/server/KafkaConfig.scala
+++ b/core/src/main/scala/kafka/server/KafkaConfig.scala
@@ -982,11 +982,11 @@ class KafkaConfig(val props: java.util.Map[_, _], doLog: Boolean, dynamicConfigO
   val maxReservedBrokerId: Int = getInt(KafkaConfig.MaxReservedBrokerIdProp)
   var brokerId: Int = getInt(KafkaConfig.BrokerIdProp)
 
-  val numNetworkThreads = getInt(KafkaConfig.NumNetworkThreadsProp)
-  val backgroundThreads = getInt(KafkaConfig.BackgroundThreadsProp)
+  def numNetworkThreads = getInt(KafkaConfig.NumNetworkThreadsProp)
+  def backgroundThreads = getInt(KafkaConfig.BackgroundThreadsProp)
   val queuedMaxRequests = getInt(KafkaConfig.QueuedMaxRequestsProp)
   val queuedMaxBytes = getLong(KafkaConfig.QueuedMaxBytesProp)
-  val numIoThreads = getInt(KafkaConfig.NumIoThreadsProp)
+  def numIoThreads = getInt(KafkaConfig.NumIoThreadsProp)
   def messageMaxBytes = getInt(KafkaConfig.MessageMaxBytesProp)
   val requestTimeoutMs = getInt(KafkaConfig.RequestTimeoutMsProp)
 
@@ -1022,7 +1022,7 @@ class KafkaConfig(val props: java.util.Map[_, _], doLog: Boolean, dynamicConfigO
   def logSegmentBytes = getInt(KafkaConfig.LogSegmentBytesProp)
   def logFlushIntervalMessages = getLong(KafkaConfig.LogFlushIntervalMessagesProp)
   val logCleanerThreads = getInt(KafkaConfig.LogCleanerThreadsProp)
-  val numRecoveryThreadsPerDataDir = getInt(KafkaConfig.NumRecoveryThreadsPerDataDirProp)
+  def numRecoveryThreadsPerDataDir = getInt(KafkaConfig.NumRecoveryThreadsPerDataDirProp)
   val logFlushSchedulerIntervalMs = getLong(KafkaConfig.LogFlushSchedulerIntervalMsProp)
   val logFlushOffsetCheckpointIntervalMs = getInt(KafkaConfig.LogFlushOffsetCheckpointIntervalMsProp).toLong
   val logFlushStartOffsetCheckpointIntervalMs = getInt(KafkaConfig.LogFlushStartOffsetCheckpointIntervalMsProp).toLong
@@ -1066,7 +1066,7 @@ class KafkaConfig(val props: java.util.Map[_, _], doLog: Boolean, dynamicConfigO
   val replicaFetchMinBytes = getInt(KafkaConfig.ReplicaFetchMinBytesProp)
   val replicaFetchResponseMaxBytes = getInt(KafkaConfig.ReplicaFetchResponseMaxBytesProp)
   val replicaFetchBackoffMs = getInt(KafkaConfig.ReplicaFetchBackoffMsProp)
-  val numReplicaFetchers = getInt(KafkaConfig.NumReplicaFetchersProp)
+  def numReplicaFetchers = getInt(KafkaConfig.NumReplicaFetchersProp)
   val replicaHighWatermarkCheckpointIntervalMs = getLong(KafkaConfig.ReplicaHighWatermarkCheckpointIntervalMsProp)
   val fetchPurgatoryPurgeIntervalRequests = getInt(KafkaConfig.FetchPurgatoryPurgeIntervalRequestsProp)
   val producerPurgatoryPurgeIntervalRequests = getInt(KafkaConfig.ProducerPurgatoryPurgeIntervalRequestsProp)
diff --git a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
index a498781..d0d4121 100755
--- a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
+++ b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
@@ -21,26 +21,30 @@ import kafka.network._
 import kafka.utils._
 import kafka.metrics.KafkaMetricsGroup
 import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
 
 import com.yammer.metrics.core.Meter
 import org.apache.kafka.common.internals.FatalExitError
 import org.apache.kafka.common.utils.{KafkaThread, Time}
 
+import scala.collection.mutable
+
 /**
  * A thread that answers kafka requests.
  */
 class KafkaRequestHandler(id: Int,
                           brokerId: Int,
                           val aggregateIdleMeter: Meter,
-                          val totalHandlerThreads: Int,
+                          val totalHandlerThreads: AtomicInteger,
                           val requestChannel: RequestChannel,
                           apis: KafkaApis,
                           time: Time) extends Runnable with Logging {
   this.logIdent = "[Kafka Request Handler " + id + " on Broker " + brokerId + "], "
-  private val latch = new CountDownLatch(1)
+  private val shutdownComplete = new CountDownLatch(1)
+  @volatile private var stopped = false
 
   def run() {
-    while(true) {
+    while (!stopped) {
       // We use a single meter for aggregate idle percentage for the thread pool.
       // Since meter is calculated as total_recorded_value / time_window and
       // time_window is independent of the number of threads, each recorded idle
@@ -50,12 +54,12 @@ class KafkaRequestHandler(id: Int,
       val req = requestChannel.receiveRequest(300)
       val endTime = time.nanoseconds
       val idleTime = endTime - startSelectTime
-      aggregateIdleMeter.mark(idleTime / totalHandlerThreads)
+      aggregateIdleMeter.mark(idleTime / totalHandlerThreads.get)
 
       req match {
         case RequestChannel.ShutdownRequest =>
           debug(s"Kafka request handler $id on broker $brokerId received shut down command")
-          latch.countDown()
+          shutdownComplete.countDown()
           return
 
         case request: RequestChannel.Request =>
@@ -65,21 +69,26 @@ class KafkaRequestHandler(id: Int,
             apis.handle(request)
           } catch {
             case e: FatalExitError =>
-              latch.countDown()
+              shutdownComplete.countDown()
               Exit.exit(e.statusCode)
             case e: Throwable => error("Exception when handling request", e)
           } finally {
-              request.releaseBuffer()
+            request.releaseBuffer()
           }
 
         case null => // continue
       }
     }
+    shutdownComplete.countDown()
+  }
+
+  def stop(): Unit = {
+    stopped = true
   }
 
   def initiateShutdown(): Unit = requestChannel.sendShutdownRequest()
 
-  def awaitShutdown(): Unit = latch.await()
+  def awaitShutdown(): Unit = shutdownComplete.await()
 
 }
 
@@ -89,17 +98,37 @@ class KafkaRequestHandlerPool(val brokerId: Int,
                               time: Time,
                               numThreads: Int) extends Logging with KafkaMetricsGroup {
 
+  private val threadPoolSize: AtomicInteger = new AtomicInteger(numThreads)
   /* a meter to track the average free capacity of the request handlers */
   private val aggregateIdleMeter = newMeter("RequestHandlerAvgIdlePercent", "percent", TimeUnit.NANOSECONDS)
 
   this.logIdent = "[Kafka Request Handler on Broker " + brokerId + "], "
-  val runnables = new Array[KafkaRequestHandler](numThreads)
-  for(i <- 0 until numThreads) {
-    runnables(i) = new KafkaRequestHandler(i, brokerId, aggregateIdleMeter, numThreads, requestChannel, apis, time)
-    KafkaThread.daemon("kafka-request-handler-" + i, runnables(i)).start()
+  val runnables = new mutable.ArrayBuffer[KafkaRequestHandler](numThreads)
+  for (i <- 0 until numThreads) {
+    createHandler(i)
+  }
+
+  def createHandler(id: Int): Unit = synchronized {
+    runnables += new KafkaRequestHandler(id, brokerId, aggregateIdleMeter, threadPoolSize, requestChannel, apis, time)
+    KafkaThread.daemon("kafka-request-handler-" + id, runnables(id)).start()
+  }
+
+  def resizeThreadPool(newSize: Int): Unit = synchronized {
+    val currentSize = threadPoolSize.get
+    info(s"Resizing request handler thread pool size from $currentSize to $newSize")
+    if (newSize > currentSize) {
+      for (i <- currentSize until newSize) {
+        createHandler(i)
+      }
+    } else if (newSize < currentSize) {
+      for (i <- 1 to (currentSize - newSize)) {
+        runnables.remove(currentSize - i).stop()
+      }
+    }
+    threadPoolSize.set(newSize)
   }
 
-  def shutdown() {
+  def shutdown(): Unit = synchronized {
     info("shutting down")
     for (handler <- runnables)
       handler.initiateShutdown()
diff --git a/core/src/main/scala/kafka/utils/KafkaScheduler.scala b/core/src/main/scala/kafka/utils/KafkaScheduler.scala
index d20fdd7..5407934 100755
--- a/core/src/main/scala/kafka/utils/KafkaScheduler.scala
+++ b/core/src/main/scala/kafka/utils/KafkaScheduler.scala
@@ -120,6 +120,10 @@ class KafkaScheduler(val threads: Int,
         executor.schedule(runnable, delay, unit)
     }
   }
+
+  def resizeThreadPool(newSize: Int): Unit = {
+    executor.setCorePoolSize(newSize)
+  }
   
   def isStarted: Boolean = {
     this synchronized {
diff --git a/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala b/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala
index 30db7e3..819d672 100644
--- a/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala
+++ b/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala
@@ -377,6 +377,78 @@ class DynamicBrokerReconfigurationTest extends ZooKeeperTestHarness with SaslSet
     stopAndVerifyProduceConsume(producerThread, consumerThread, mayFailRequests = false)
   }
 
+  @Test
+  def testThreadPoolResize(): Unit = {
+    val requestHandlerPrefix = "kafka-request-handler-"
+    val networkThreadPrefix = "kafka-network-thread-"
+    val fetcherThreadPrefix = "ReplicaFetcherThread-"
+    // Executor threads and recovery threads are not verified since threads may not be running
+    // For others, thread count should be configuredCount * threadMultiplier * numBrokers
+    val threadMultiplier = Map(
+      requestHandlerPrefix -> 1,
+      networkThreadPrefix ->  2, // 2 endpoints
+      fetcherThreadPrefix -> (servers.size - 1)
+    )
+
+    // Tolerate threads left over from previous tests
+    def leftOverThreadCount(prefix: String, perBrokerCount: Int) : Int = {
+      val count = matchingThreads(prefix).size - perBrokerCount * servers.size * threadMultiplier(prefix)
+      if (count > 0) count else 0
+    }
+    val leftOverThreads = Map(
+      requestHandlerPrefix -> leftOverThreadCount(requestHandlerPrefix, servers.head.config.numIoThreads),
+      networkThreadPrefix ->  leftOverThreadCount(networkThreadPrefix, servers.head.config.numNetworkThreads),
+      fetcherThreadPrefix ->  leftOverThreadCount(fetcherThreadPrefix, servers.head.config.numReplicaFetchers)
+    )
+
+    def maybeVerifyThreadPoolSize(propName: String, size: Int, threadPrefix: String): Unit = {
+      val ignoreCount = leftOverThreads.getOrElse(threadPrefix, 0)
+      val expectedCountPerBroker = threadMultiplier.getOrElse(threadPrefix, 0) * size
+      if (expectedCountPerBroker > 0)
+        verifyThreads(threadPrefix, expectedCountPerBroker, ignoreCount)
+    }
+    def reducePoolSize(propName: String, currentSize: => Int, threadPrefix: String): Int = {
+      val newSize = if (currentSize / 2 == 0) 1 else currentSize / 2
+      resizeThreadPool(propName, newSize, threadPrefix)
+      newSize
+    }
+    def increasePoolSize(propName: String, currentSize: => Int, threadPrefix: String): Int = {
+      resizeThreadPool(propName, currentSize * 2, threadPrefix)
+      currentSize * 2
+    }
+    def resizeThreadPool(propName: String, newSize: Int, threadPrefix: String): Unit = {
+      val props = new Properties
+      props.put(propName, newSize.toString)
+      reconfigureServers(props, perBrokerConfig = false, (propName, newSize.toString))
+      maybeVerifyThreadPoolSize(propName, newSize, threadPrefix)
+    }
+    def verifyThreadPoolResize(propName: String, currentSize: => Int, threadPrefix: String, mayFailRequests: Boolean): Unit = {
+      maybeVerifyThreadPoolSize(propName, currentSize, threadPrefix)
+      val numRetries = if (mayFailRequests) 100 else 0
+      val (producerThread, consumerThread) = startProduceConsume(numRetries)
+      var threadPoolSize = currentSize
+      (1 to 2).foreach { _ =>
+        threadPoolSize = reducePoolSize(propName, threadPoolSize, threadPrefix)
+        Thread.sleep(100)
+        threadPoolSize = increasePoolSize(propName, threadPoolSize, threadPrefix)
+        Thread.sleep(100)
+      }
+      stopAndVerifyProduceConsume(producerThread, consumerThread, mayFailRequests)
+    }
+
+    val config = servers.head.config
+    verifyThreadPoolResize(KafkaConfig.NumIoThreadsProp, config.numIoThreads,
+      requestHandlerPrefix, mayFailRequests = false)
+    verifyThreadPoolResize(KafkaConfig.NumNetworkThreadsProp, config.numNetworkThreads,
+      networkThreadPrefix, mayFailRequests = true)
+    verifyThreadPoolResize(KafkaConfig.NumReplicaFetchersProp, config.numReplicaFetchers,
+      fetcherThreadPrefix, mayFailRequests = false)
+    verifyThreadPoolResize(KafkaConfig.BackgroundThreadsProp, config.backgroundThreads,
+      "kafka-scheduler-", mayFailRequests = false)
+    verifyThreadPoolResize(KafkaConfig.NumRecoveryThreadsPerDataDirProp, config.numRecoveryThreadsPerDataDir,
+      "", mayFailRequests = false)
+  }
+
   private def createProducer(trustStore: File, retries: Int,
                              clientId: String = "test-producer"): KafkaProducer[String, String] = {
     val bootstrapServers = TestUtils.bootstrapServers(servers, new ListenerName(SecureExternal))
@@ -560,15 +632,18 @@ class DynamicBrokerReconfigurationTest extends ZooKeeperTestHarness with SaslSet
     Thread.getAllStackTraces.keySet.asScala.toList.map(_.getName)
   }
 
-  private def verifyThreads(threadPrefix: String, countPerBroker: Int): Unit = {
+  private def matchingThreads(threadPrefix: String): List[String] = {
+    currentThreads.filter(_.startsWith(threadPrefix))
+  }
+
+  private def verifyThreads(threadPrefix: String, countPerBroker: Int, leftOverThreads: Int = 0): Unit = {
     val expectedCount = countPerBroker * servers.size
-    val (threads, resized) = TestUtils.computeUntilTrue(currentThreads.filter(_.startsWith(threadPrefix))) {
-      _.size == expectedCount
+    val (threads, resized) = TestUtils.computeUntilTrue(matchingThreads(threadPrefix)) { matching =>
+      matching.size >= expectedCount &&  matching.size <= expectedCount + leftOverThreads
     }
     assertTrue(s"Invalid threads: expected $expectedCount, got ${threads.size}: $threads", resized)
   }
 
-
   private def startProduceConsume(retries: Int): (ProducerThread, ConsumerThread) = {
     val producerThread = new ProducerThread(retries)
     clientThreads += producerThread
@@ -576,11 +651,13 @@ class DynamicBrokerReconfigurationTest extends ZooKeeperTestHarness with SaslSet
     clientThreads += consumerThread
     consumerThread.start()
     producerThread.start()
+    TestUtils.waitUntilTrue(() => producerThread.sent >= 10, "Messages not sent")
     (producerThread, consumerThread)
   }
 
   private def stopAndVerifyProduceConsume(producerThread: ProducerThread, consumerThread: ConsumerThread,
                                           mayFailRequests: Boolean): Unit = {
+    TestUtils.waitUntilTrue(() => producerThread.sent >= 10, "Messages not sent")
     producerThread.shutdown()
     consumerThread.initiateShutdown()
     consumerThread.awaitShutdown()
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index ea94c76..407bdb5 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -1024,7 +1024,7 @@ object TestUtils extends Logging {
                    topicConfigs = Map(),
                    initialDefaultConfig = defaultConfig,
                    cleanerConfig = cleanerConfig,
-                   ioThreads = 4,
+                   recoveryThreadsPerDataDir = 4,
                    flushCheckMs = 1000L,
                    flushRecoveryOffsetCheckpointMs = 10000L,
                    flushStartOffsetCheckpointMs = 10000L,

-- 
To stop receiving notification emails like this one, please contact [email protected].