This is an automated email from the ASF dual-hosted git repository.

cmccabe pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 5514f372b3e MINOR: extract jointly owned parts of BrokerServer and ControllerServer (#12837)
5514f372b3e is described below

commit 5514f372b3e12db1df35b257068f6bb5083111c7
Author: Colin Patrick McCabe <[email protected]>
AuthorDate: Fri Dec 2 00:27:22 2022 -0800

    MINOR: extract jointly owned parts of BrokerServer and ControllerServer (#12837)
    
    Extract jointly owned parts of BrokerServer and ControllerServer into SharedServer. Shut down
    SharedServer when the last component using it shuts down. But make sure to stop the raft manager
    before closing the ControllerServer's sockets.
    
    This PR also fixes a memory leak where ReplicaManager was not removing some topic metric callbacks
    during shutdown. Finally, we now release memory from the BatchMemoryPool in KafkaRaftClient#close.
    These changes should reduce memory consumption while running junit tests.
    
    Reviewers: Jason Gustafson <[email protected]>, Ismael Juma <[email protected]>
---
 .../java/org/apache/kafka/common/utils/Utils.java  |  12 +
 .../src/main/scala/kafka/server/BrokerServer.scala |  48 ++--
 .../main/scala/kafka/server/ControllerServer.scala |  60 ++---
 .../main/scala/kafka/server/KafkaRaftServer.scala  |  66 +----
 .../main/scala/kafka/server/ReplicaManager.scala   |  10 +
 .../src/main/scala/kafka/server/SharedServer.scala | 256 ++++++++++++++++++
 .../java/kafka/testkit/KafkaClusterTestKit.java    | 289 +++++++++++----------
 .../kafka/api/BaseAdminIntegrationTest.scala       |   5 -
 .../kafka/api/PlaintextAdminIntegrationTest.scala  |  32 ++-
 .../kafka/server/KRaftClusterTest.scala            |  17 ++
 .../kafka/server/QuorumTestHarness.scala           | 101 ++++---
 .../metadata/BrokerMetadataPublisherTest.scala     |   7 +-
 .../kafka/controller/FeatureControlManager.java    |   4 +
 .../apache/kafka/controller/QuorumController.java  |  13 +-
 .../apache/kafka/controller/QuorumFeatures.java    |   4 +
 .../kafka/controller/QuorumFeaturesTest.java       |  11 +
 .../org/apache/kafka/raft/KafkaRaftClient.java     |   4 +
 .../kafka/raft/internals/BatchMemoryPool.java      |  12 +
 .../server/fault/ProcessExitingFaultHandler.java   |  15 ++
 19 files changed, 632 insertions(+), 334 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java 
b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
index 79a907d25ab..42dcb60357b 100755
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -986,6 +986,18 @@ public final class Utils {
             throw exception;
     }
 
+    public static void swallow(
+        Logger log,
+        String what,
+        Runnable runnable
+    ) {
+        try {
+            runnable.run();
+        } catch (Throwable e) {
+            log.warn("{} error", what, e);
+        }
+    }
+
     /**
      * An {@link AutoCloseable} interface without a throws clause in the 
signature
      *
diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala 
b/core/src/main/scala/kafka/server/BrokerServer.scala
index 677a122e5b8..d6b4fa92c3a 100644
--- a/core/src/main/scala/kafka/server/BrokerServer.scala
+++ b/core/src/main/scala/kafka/server/BrokerServer.scala
@@ -22,22 +22,19 @@ import java.util
 import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent.locks.ReentrantLock
 import java.util.concurrent.{CompletableFuture, ExecutionException, TimeUnit, 
TimeoutException}
-
 import kafka.cluster.Broker.ServerInfo
 import kafka.coordinator.group.{GroupCoordinator, GroupCoordinatorAdapter}
 import kafka.coordinator.transaction.{ProducerIdManager, 
TransactionCoordinator}
 import kafka.log.LogManager
 import kafka.network.{DataPlaneAcceptor, SocketServer}
-import kafka.raft.RaftManager
+import kafka.raft.KafkaRaftManager
 import kafka.security.CredentialProvider
 import kafka.server.KafkaRaftServer.ControllerRole
-import kafka.server.metadata.BrokerServerMetrics
 import kafka.server.metadata.{BrokerMetadataListener, BrokerMetadataPublisher, 
BrokerMetadataSnapshotter, ClientQuotaMetadataManager, KRaftMetadataCache, 
SnapshotWriterBuilder}
 import kafka.utils.{CoreUtils, KafkaScheduler}
 import org.apache.kafka.common.feature.SupportedVersionRange
 import org.apache.kafka.common.message.ApiMessageType.ListenerType
 import 
org.apache.kafka.common.message.BrokerRegistrationRequestData.{Listener, 
ListenerCollection}
-import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.network.ListenerName
 import org.apache.kafka.common.security.auth.SecurityProtocol
 import org.apache.kafka.common.security.scram.internals.ScramMechanism
@@ -46,11 +43,9 @@ import org.apache.kafka.common.utils.{AppInfoParser, 
LogContext, Time, Utils}
 import org.apache.kafka.common.{ClusterResource, Endpoint}
 import org.apache.kafka.metadata.authorizer.ClusterMetadataAuthorizer
 import org.apache.kafka.metadata.{BrokerState, VersionRange}
-import org.apache.kafka.raft.RaftConfig.AddressSpec
 import org.apache.kafka.raft.{RaftClient, RaftConfig}
 import org.apache.kafka.server.authorizer.Authorizer
 import org.apache.kafka.server.common.ApiMessageAndVersion
-import org.apache.kafka.server.fault.FaultHandler
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
 import org.apache.kafka.snapshot.SnapshotWriter
 
@@ -72,19 +67,16 @@ class BrokerSnapshotWriterBuilder(raftClient: 
RaftClient[ApiMessageAndVersion])
  * A Kafka broker that runs in KRaft (Kafka Raft) mode.
  */
 class BrokerServer(
-  val config: KafkaConfig,
-  val metaProps: MetaProperties,
-  val raftManager: RaftManager[ApiMessageAndVersion],
-  val time: Time,
-  val metrics: Metrics,
-  val brokerMetrics: BrokerServerMetrics,
-  val threadNamePrefix: Option[String],
+  val sharedServer: SharedServer,
   val initialOfflineDirs: Seq[String],
-  val controllerQuorumVotersFuture: CompletableFuture[util.Map[Integer, 
AddressSpec]],
-  val fatalFaultHandler: FaultHandler,
-  val metadataLoadingFaultHandler: FaultHandler,
-  val metadataPublishingFaultHandler: FaultHandler
 ) extends KafkaBroker {
+  val threadNamePrefix = sharedServer.threadNamePrefix
+  val config = sharedServer.config
+  val time = sharedServer.time
+  def metrics = sharedServer.metrics
+
+  // Get raftManager from SharedServer. It will be initialized during startup.
+  def raftManager: KafkaRaftManager[ApiMessageAndVersion] = 
sharedServer.raftManager
 
   override def brokerState: BrokerState = Option(lifecycleManager).
     flatMap(m => Some(m.state)).getOrElse(BrokerState.NOT_RUNNING)
@@ -144,7 +136,7 @@ class BrokerServer(
 
   @volatile var brokerTopicStats: BrokerTopicStats = _
 
-  val clusterId: String = metaProps.clusterId
+  val clusterId: String = sharedServer.metaProps.clusterId
 
   var metadataSnapshotter: Option[BrokerMetadataSnapshotter] = None
 
@@ -180,6 +172,8 @@ class BrokerServer(
   override def startup(): Unit = {
     if (!maybeChangeStatus(SHUTDOWN, STARTING)) return
     try {
+      sharedServer.startForBroker()
+
       info("Starting broker")
 
       config.dynamicConfig.initialize(zkClientOpt = None)
@@ -211,7 +205,7 @@ class BrokerServer(
       tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames)
       credentialProvider = new 
CredentialProvider(ScramMechanism.mechanismNames, tokenCache)
 
-      val controllerNodes = 
RaftConfig.voterConnectionsToNodes(controllerQuorumVotersFuture.get()).asScala
+      val controllerNodes = 
RaftConfig.voterConnectionsToNodes(sharedServer.controllerQuorumVotersFuture.get()).asScala
       val controllerNodeProvider = RaftControllerNodeProvider(raftManager, 
config, controllerNodes)
 
       clientToControllerChannelManager = BrokerToControllerChannelManager(
@@ -320,8 +314,8 @@ class BrokerServer(
         threadNamePrefix,
         config.metadataSnapshotMaxNewRecordBytes,
         metadataSnapshotter,
-        brokerMetrics,
-        metadataLoadingFaultHandler)
+        sharedServer.brokerMetrics,
+        sharedServer.metadataLoaderFaultHandler)
 
       val networkListeners = new ListenerCollection()
       config.effectiveAdvertisedListeners.foreach { ep =>
@@ -349,7 +343,7 @@ class BrokerServer(
       lifecycleManager.start(
         () => metadataListener.highestMetadataOffset,
         brokerLifecycleChannelManager,
-        metaProps.clusterId,
+        sharedServer.metaProps.clusterId,
         networkListeners,
         featuresRemapped
       )
@@ -439,8 +433,8 @@ class BrokerServer(
         clientQuotaMetadataManager,
         dynamicConfigHandlers.toMap,
         authorizer,
-        fatalFaultHandler,
-        metadataPublishingFaultHandler)
+        sharedServer.initialBrokerMetadataLoadFaultHandler,
+        sharedServer.metadataPublishingFaultHandler)
 
       // Tell the metadata listener to start publishing its output, and wait 
for the first
       // publish operation to complete. This first operation will initialize 
logManager,
@@ -567,17 +561,13 @@ class BrokerServer(
 
       if (socketServer != null)
         CoreUtils.swallow(socketServer.shutdown(), this)
-      if (metrics != null)
-        CoreUtils.swallow(metrics.close(), this)
       if (brokerTopicStats != null)
         CoreUtils.swallow(brokerTopicStats.close(), this)
 
-      // Clear all reconfigurable instances stored in DynamicBrokerConfig
-      config.dynamicConfig.clear()
-
       isShuttingDown.set(false)
 
       CoreUtils.swallow(lifecycleManager.close(), this)
+      sharedServer.stopForBroker()
 
       CoreUtils.swallow(AppInfoParser.unregisterAppInfo(MetricsPrefix, 
config.nodeId.toString, metrics), this)
       info("shut down completed")
diff --git a/core/src/main/scala/kafka/server/ControllerServer.scala 
b/core/src/main/scala/kafka/server/ControllerServer.scala
index e95c870b902..f73088b30f0 100644
--- a/core/src/main/scala/kafka/server/ControllerServer.scala
+++ b/core/src/main/scala/kafka/server/ControllerServer.scala
@@ -17,36 +17,31 @@
 
 package kafka.server
 
-import java.util
 import java.util.OptionalLong
 import java.util.concurrent.locks.ReentrantLock
 import java.util.concurrent.{CompletableFuture, TimeUnit}
 import kafka.cluster.Broker.ServerInfo
 import kafka.metrics.{KafkaMetricsGroup, LinuxIoMetricsCollector}
 import kafka.network.{DataPlaneAcceptor, SocketServer}
-import kafka.raft.RaftManager
+import kafka.raft.KafkaRaftManager
 import kafka.security.CredentialProvider
 import kafka.server.KafkaConfig.{AlterConfigPolicyClassNameProp, 
CreateTopicPolicyClassNameProp}
 import kafka.server.KafkaRaftServer.BrokerRole
 import kafka.server.QuotaFactory.QuotaManagers
 import kafka.utils.{CoreUtils, Logging}
-import org.apache.kafka.clients.ApiVersions
 import org.apache.kafka.common.message.ApiMessageType.ListenerType
-import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.security.scram.internals.ScramMechanism
 import 
org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache
-import org.apache.kafka.common.utils.{LogContext, Time}
+import org.apache.kafka.common.utils.LogContext
 import org.apache.kafka.common.{ClusterResource, Endpoint}
-import org.apache.kafka.controller.{Controller, ControllerMetrics, 
QuorumController, QuorumFeatures}
+import org.apache.kafka.controller.{Controller, QuorumController, 
QuorumFeatures}
 import org.apache.kafka.metadata.KafkaConfigSchema
 import org.apache.kafka.raft.RaftConfig
-import org.apache.kafka.raft.RaftConfig.AddressSpec
 import org.apache.kafka.server.authorizer.Authorizer
 import org.apache.kafka.server.common.ApiMessageAndVersion
 import org.apache.kafka.common.config.ConfigException
 import org.apache.kafka.metadata.authorizer.ClusterMetadataAuthorizer
 import org.apache.kafka.metadata.bootstrap.BootstrapMetadata
-import org.apache.kafka.server.fault.FaultHandler
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
 import org.apache.kafka.server.policy.{AlterConfigPolicy, CreateTopicPolicy}
 
@@ -57,22 +52,19 @@ import scala.compat.java8.OptionConverters._
  * A Kafka controller that runs in KRaft (Kafka Raft) mode.
  */
 class ControllerServer(
-  val metaProperties: MetaProperties,
-  val config: KafkaConfig,
-  val raftManager: RaftManager[ApiMessageAndVersion],
-  val time: Time,
-  val metrics: Metrics,
-  val controllerMetrics: ControllerMetrics,
-  val threadNamePrefix: Option[String],
-  val controllerQuorumVotersFuture: CompletableFuture[util.Map[Integer, 
AddressSpec]],
+  val sharedServer: SharedServer,
   val configSchema: KafkaConfigSchema,
-  val raftApiVersions: ApiVersions,
   val bootstrapMetadata: BootstrapMetadata,
-  val metadataFaultHandler: FaultHandler,
-  val fatalFaultHandler: FaultHandler,
 ) extends Logging with KafkaMetricsGroup {
+
   import kafka.server.Server._
 
+  val config = sharedServer.config
+  val time = sharedServer.time
+  def metrics = sharedServer.metrics
+  val threadNamePrefix = sharedServer.threadNamePrefix.getOrElse("")
+  def raftManager: KafkaRaftManager[ApiMessageAndVersion] = 
sharedServer.raftManager
+
   config.dynamicConfig.initialize(zkClientOpt = None)
 
   val lock = new ReentrantLock()
@@ -111,7 +103,7 @@ class ControllerServer(
     new DynamicMetricReporterState(config.nodeId, config, metrics, clusterId)
   }
 
-  def clusterId: String = metaProperties.clusterId
+  def clusterId: String = sharedServer.metaProps.clusterId
 
   def startup(): Unit = {
     if (!maybeChangeStatus(SHUTDOWN, STARTING)) return
@@ -171,15 +163,18 @@ class ControllerServer(
         throw new ConfigException("No controller.listener.names defined for 
controller")
       }
 
-      val threadNamePrefixAsString = threadNamePrefix.getOrElse("")
+      sharedServer.startForController()
 
       createTopicPolicy = Option(config.
         getConfiguredInstance(CreateTopicPolicyClassNameProp, 
classOf[CreateTopicPolicy]))
       alterConfigPolicy = Option(config.
         getConfiguredInstance(AlterConfigPolicyClassNameProp, 
classOf[AlterConfigPolicy]))
 
-      val controllerNodes = 
RaftConfig.voterConnectionsToNodes(controllerQuorumVotersFuture.get())
-      val quorumFeatures = QuorumFeatures.create(config.nodeId, 
raftApiVersions, QuorumFeatures.defaultFeatureMap(), controllerNodes)
+      val controllerNodes = 
RaftConfig.voterConnectionsToNodes(sharedServer.controllerQuorumVotersFuture.get())
+      val quorumFeatures = QuorumFeatures.create(config.nodeId,
+        sharedServer.raftManager.apiVersions,
+        QuorumFeatures.defaultFeatureMap(),
+        controllerNodes)
 
       val controllerBuilder = {
         val leaderImbalanceCheckIntervalNs = if 
(config.autoLeaderRebalanceEnable) {
@@ -190,9 +185,9 @@ class ControllerServer(
 
         val maxIdleIntervalNs = 
config.metadataMaxIdleIntervalNs.fold(OptionalLong.empty)(OptionalLong.of)
 
-        new QuorumController.Builder(config.nodeId, metaProperties.clusterId).
+        new QuorumController.Builder(config.nodeId, 
sharedServer.metaProps.clusterId).
           setTime(time).
-          setThreadNamePrefix(threadNamePrefixAsString).
+          setThreadNamePrefix(threadNamePrefix).
           setConfigSchema(configSchema).
           setRaftClient(raftManager.client).
           setQuorumFeatures(quorumFeatures).
@@ -204,13 +199,13 @@ class ControllerServer(
           setSnapshotMaxIntervalMs(config.metadataSnapshotMaxIntervalMs).
           setLeaderImbalanceCheckIntervalNs(leaderImbalanceCheckIntervalNs).
           setMaxIdleIntervalNs(maxIdleIntervalNs).
-          setMetrics(controllerMetrics).
+          setMetrics(sharedServer.controllerMetrics).
           setCreateTopicPolicy(createTopicPolicy.asJava).
           setAlterConfigPolicy(alterConfigPolicy.asJava).
           setConfigurationValidator(new ControllerConfigurationValidator()).
           setStaticConfig(config.originals).
           setBootstrapMetadata(bootstrapMetadata).
-          setFatalFaultHandler(fatalFaultHandler)
+          setFatalFaultHandler(sharedServer.quorumControllerFaultHandler)
       }
       authorizer match {
         case Some(a: ClusterMetadataAuthorizer) => 
controllerBuilder.setAuthorizer(a)
@@ -223,7 +218,10 @@ class ControllerServer(
         doRemoteKraftSetup()
       }
 
-      quotaManagers = QuotaFactory.instantiate(config, metrics, time, 
threadNamePrefix.getOrElse(""))
+      quotaManagers = QuotaFactory.instantiate(config,
+        metrics,
+        time,
+        threadNamePrefix)
       controllerApis = new ControllerApis(socketServer.dataPlaneRequestChannel,
         authorizer,
         quotaManagers,
@@ -231,7 +229,7 @@ class ControllerServer(
         controller,
         raftManager,
         config,
-        metaProperties,
+        sharedServer.metaProps,
         controllerNodes.asScala.toSeq,
         apiVersionManager)
       controllerApisHandlerPool = new KafkaRequestHandlerPool(config.nodeId,
@@ -265,6 +263,9 @@ class ControllerServer(
     if (!maybeChangeStatus(STARTED, SHUTTING_DOWN)) return
     try {
       info("shutting down")
+      // Ensure that we're not the Raft leader prior to shutting down our 
socket server, for a
+      // smoother transition.
+      sharedServer.ensureNotRaftLeader()
       if (socketServer != null)
         CoreUtils.swallow(socketServer.stopProcessingRequests(), this)
       if (controller != null)
@@ -283,6 +284,7 @@ class ControllerServer(
       createTopicPolicy.foreach(policy => CoreUtils.swallow(policy.close(), 
this))
       alterConfigPolicy.foreach(policy => CoreUtils.swallow(policy.close(), 
this))
       socketServerFirstBoundPortFuture.completeExceptionally(new 
RuntimeException("shutting down"))
+      sharedServer.stopForController()
     } catch {
       case e: Throwable =>
         fatal("Fatal error during controller shutdown.", e)
diff --git a/core/src/main/scala/kafka/server/KafkaRaftServer.scala 
b/core/src/main/scala/kafka/server/KafkaRaftServer.scala
index 76a874b2197..1c5f3f92648 100644
--- a/core/src/main/scala/kafka/server/KafkaRaftServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaRaftServer.scala
@@ -21,20 +21,15 @@ import java.util.concurrent.CompletableFuture
 import kafka.common.InconsistentNodeIdException
 import kafka.log.{LogConfig, UnifiedLog}
 import kafka.metrics.KafkaMetricsReporter
-import kafka.raft.KafkaRaftManager
 import kafka.server.KafkaRaftServer.{BrokerRole, ControllerRole}
-import kafka.server.metadata.BrokerServerMetrics
 import kafka.utils.{CoreUtils, Logging, Mx4jLoader, VerifiableProperties}
 import org.apache.kafka.common.config.{ConfigDef, ConfigResource}
 import org.apache.kafka.common.internals.Topic
 import org.apache.kafka.common.utils.{AppInfoParser, Time}
 import org.apache.kafka.common.{KafkaException, Uuid}
-import org.apache.kafka.controller.QuorumControllerMetrics
 import org.apache.kafka.metadata.bootstrap.{BootstrapDirectory, 
BootstrapMetadata}
-import org.apache.kafka.metadata.{KafkaConfigSchema, MetadataRecordSerde}
+import org.apache.kafka.metadata.KafkaConfigSchema
 import org.apache.kafka.raft.RaftConfig
-import org.apache.kafka.server.common.ApiMessageAndVersion
-import org.apache.kafka.server.fault.{LoggingFaultHandler, 
ProcessExitingFaultHandler}
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
 
 import java.util.Optional
@@ -69,62 +64,30 @@ class KafkaRaftServer(
   private val controllerQuorumVotersFuture = CompletableFuture.completedFuture(
     RaftConfig.parseVoterConnections(config.quorumVoters))
 
-  private val raftManager = new KafkaRaftManager[ApiMessageAndVersion](
-    metaProps,
+  private val sharedServer = new SharedServer(
     config,
-    new MetadataRecordSerde,
-    KafkaRaftServer.MetadataPartition,
-    KafkaRaftServer.MetadataTopicId,
+    metaProps,
     time,
     metrics,
     threadNamePrefix,
-    controllerQuorumVotersFuture
+    controllerQuorumVotersFuture,
+    new StandardFaultHandlerFactory(),
   )
 
   private val broker: Option[BrokerServer] = if 
(config.processRoles.contains(BrokerRole)) {
-    val brokerMetrics = BrokerServerMetrics(metrics)
-    val fatalFaultHandler = new ProcessExitingFaultHandler()
-    val metadataLoadingFaultHandler = new LoggingFaultHandler("metadata 
loading",
-        () => brokerMetrics.metadataLoadErrorCount.getAndIncrement())
-    val metadataApplyingFaultHandler = new LoggingFaultHandler("metadata 
application",
-      () => brokerMetrics.metadataApplyErrorCount.getAndIncrement())
     Some(new BrokerServer(
-      config,
-      metaProps,
-      raftManager,
-      time,
-      metrics,
-      brokerMetrics,
-      threadNamePrefix,
-      offlineDirs,
-      controllerQuorumVotersFuture,
-      fatalFaultHandler,
-      metadataLoadingFaultHandler,
-      metadataApplyingFaultHandler
+      sharedServer,
+      offlineDirs
     ))
   } else {
     None
   }
 
   private val controller: Option[ControllerServer] = if 
(config.processRoles.contains(ControllerRole)) {
-    val controllerMetrics = new 
QuorumControllerMetrics(KafkaYammerMetrics.defaultRegistry(), time)
-    val metadataFaultHandler = new LoggingFaultHandler("controller metadata",
-      () => controllerMetrics.incrementMetadataErrorCount())
-    val fatalFaultHandler = new ProcessExitingFaultHandler()
     Some(new ControllerServer(
-      metaProps,
-      config,
-      raftManager,
-      time,
-      metrics,
-      controllerMetrics,
-      threadNamePrefix,
-      controllerQuorumVotersFuture,
+      sharedServer,
       KafkaRaftServer.configSchema,
-      raftManager.apiVersions,
       bootstrapMetadata,
-      metadataFaultHandler,
-      fatalFaultHandler
     ))
   } else {
     None
@@ -132,9 +95,6 @@ class KafkaRaftServer(
 
   override def startup(): Unit = {
     Mx4jLoader.maybeLoad()
-    // Note that we startup `RaftManager` first so that the controller and 
broker
-    // can register listeners during initialization.
-    raftManager.startup()
     controller.foreach(_.startup())
     broker.foreach(_.startup())
     AppInfoParser.registerAppInfo(Server.MetricsPrefix, 
config.brokerId.toString, metrics, time.milliseconds())
@@ -142,22 +102,18 @@ class KafkaRaftServer(
   }
 
   override def shutdown(): Unit = {
+    // In combined mode, we want to shut down the broker first, since the 
controller may be
+    // needed for controlled shutdown. Additionally, the controller shutdown 
process currently
+    // stops the raft client early on, which would disrupt broker shutdown.
     broker.foreach(_.shutdown())
-    // The order of shutdown for `RaftManager` and `ControllerServer` is 
backwards
-    // compared to `startup()`. This is because the `SocketServer` 
implementation that
-    // we rely on to receive requests is owned by `ControllerServer`, so we 
need it
-    // to stick around until graceful shutdown of `RaftManager` can be 
completed.
-    raftManager.shutdown()
     controller.foreach(_.shutdown())
     CoreUtils.swallow(AppInfoParser.unregisterAppInfo(Server.MetricsPrefix, 
config.brokerId.toString, metrics), this)
-
   }
 
   override def awaitShutdown(): Unit = {
     broker.foreach(_.awaitShutdown())
     controller.foreach(_.awaitShutdown())
   }
-
 }
 
 object KafkaRaftServer {
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala 
b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 910d17267ca..b2a37479bae 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -63,6 +63,7 @@ import org.apache.kafka.metadata.LeaderConstants.NO_LEADER
 import org.apache.kafka.server.common.MetadataVersion._
 
 import java.nio.file.{Files, Paths}
+import java.util
 import scala.jdk.CollectionConverters._
 import scala.collection.{Map, Seq, Set, mutable}
 import scala.compat.java8.OptionConverters._
@@ -1928,9 +1929,18 @@ class ReplicaManager(val config: KafkaConfig,
     if (checkpointHW)
       checkpointHighWatermarks()
     replicaSelectorOpt.foreach(_.close)
+    removeAllTopicMetrics()
     info("Shut down completely")
   }
 
+  private def removeAllTopicMetrics(): Unit = {
+    val allTopics = new util.HashSet[String]
+    allPartitions.keys.foreach(partition =>
+      if (allTopics.add(partition.topic())) {
+        brokerTopicStats.removeMetrics(partition.topic())
+      })
+  }
+
   protected def createReplicaFetcherManager(metrics: Metrics, time: Time, 
threadNamePrefix: Option[String], quotaManager: ReplicationQuotaManager) = {
     new ReplicaFetcherManager(config, this, metrics, time, threadNamePrefix, 
quotaManager, () => metadataCache.metadataVersion())
   }
diff --git a/core/src/main/scala/kafka/server/SharedServer.scala 
b/core/src/main/scala/kafka/server/SharedServer.scala
new file mode 100644
index 00000000000..a420c9afa38
--- /dev/null
+++ b/core/src/main/scala/kafka/server/SharedServer.scala
@@ -0,0 +1,256 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.server
+
+import kafka.raft.KafkaRaftManager
+import kafka.server.KafkaRaftServer.{BrokerRole, ControllerRole}
+import kafka.server.metadata.BrokerServerMetrics
+import kafka.utils.{CoreUtils, Logging}
+import org.apache.kafka.common.metrics.Metrics
+import org.apache.kafka.common.utils.{LogContext, Time}
+import org.apache.kafka.controller.QuorumControllerMetrics
+import org.apache.kafka.metadata.MetadataRecordSerde
+import org.apache.kafka.raft.RaftConfig.AddressSpec
+import org.apache.kafka.server.common.ApiMessageAndVersion
+import org.apache.kafka.server.fault.{FaultHandler, LoggingFaultHandler, 
ProcessExitingFaultHandler}
+import org.apache.kafka.server.metrics.KafkaYammerMetrics
+
+import java.util
+import java.util.concurrent.CompletableFuture
+
+
+/**
+ * Creates a fault handler.
+ */
+trait FaultHandlerFactory {
+  def build(
+     name: String,
+     fatal: Boolean,
+     action: Runnable
+  ): FaultHandler
+}
+
+/**
+ * The standard FaultHandlerFactory which is used when we're not in a junit 
test.
+ */
+class StandardFaultHandlerFactory extends FaultHandlerFactory {
+  override def build(
+    name: String,
+    fatal: Boolean,
+    action: Runnable
+  ): FaultHandler = {
+    if (fatal) {
+      new ProcessExitingFaultHandler(action)
+    } else {
+      new LoggingFaultHandler(name, action)
+    }
+  }
+}
+
+/**
+ * The SharedServer manages the components which are shared between the 
BrokerServer and
+ * ControllerServer. These shared components include the Raft manager, 
snapshot generator,
+ * and metadata loader. A KRaft server running in combined mode as both a 
broker and a controller
+ * will still contain only a single SharedServer instance.
+ *
+ * The SharedServer will be started as soon as either the broker or the 
controller needs it,
+ * via the appropriate function (startForBroker or startForController). 
Similarly, it will be
+ * stopped as soon as neither the broker nor the controller need it, via 
stopForBroker or
+ * stopForController. One way of thinking about this is that both the broker 
and the controller
+ * could hold a "reference" to this class, and we don't truly stop it until 
both have dropped
+ * their reference. We opted to use two booleans here rather than a reference 
count in order to
+ * make debugging easier and reduce the chance of resource leaks.
+ */
+class SharedServer(
+  val config: KafkaConfig,
+  val metaProps: MetaProperties,
+  val time: Time,
+  private val _metrics: Metrics,
+  val threadNamePrefix: Option[String],
+  val controllerQuorumVotersFuture: CompletableFuture[util.Map[Integer, 
AddressSpec]],
+  val faultHandlerFactory: FaultHandlerFactory
+) extends Logging {
+  private val logContext: LogContext = new LogContext(s"[SharedServer 
id=${config.nodeId}] ")
+  this.logIdent = logContext.logPrefix
+  private var started = false
+  private var usedByBroker: Boolean = false
+  private var usedByController: Boolean = false
+  @volatile var metrics: Metrics = _metrics
+  @volatile var raftManager: KafkaRaftManager[ApiMessageAndVersion] = _
+  @volatile var brokerMetrics: BrokerServerMetrics = _
+  @volatile var controllerMetrics: QuorumControllerMetrics = _
+
+  def isUsed(): Boolean = synchronized {
+    usedByController || usedByBroker
+  }
+
+  /**
+   * The start function called by the broker.
+   */
+  def startForBroker(): Unit = synchronized {
+    if (!isUsed()) {
+      start()
+    }
+    usedByBroker = true
+  }
+
+  /**
+   * The start function called by the controller.
+   */
+  def startForController(): Unit = synchronized {
+    if (!isUsed()) {
+      start()
+    }
+    usedByController = true
+  }
+
+  /**
+   * The stop function called by the broker.
+   */
+  def stopForBroker(): Unit = synchronized {
+    if (usedByBroker) {
+      usedByBroker = false
+      if (!isUsed()) stop()
+    }
+  }
+
+  /**
+   * The stop function called by the controller.
+   */
+  def stopForController(): Unit = synchronized {
+    if (usedByController) {
+      usedByController = false
+      if (!isUsed()) stop()
+    }
+  }
+
+  /**
+   * The fault handler to use when metadata loading fails.
+   */
+  def metadataLoaderFaultHandler: FaultHandler = 
faultHandlerFactory.build("metadata loading",
+    fatal = config.processRoles.contains(ControllerRole),
+    action = () => SharedServer.this.synchronized {
+      if (brokerMetrics != null) 
brokerMetrics.metadataLoadErrorCount.getAndIncrement()
+      if (controllerMetrics != null) 
controllerMetrics.incrementMetadataErrorCount()
+    })
+
+  /**
+   * The fault handler to use when the initial broker metadata load fails.
+   */
+  def initialBrokerMetadataLoadFaultHandler: FaultHandler = 
faultHandlerFactory.build("initial metadata loading",
+    fatal = true,
+    action = () => SharedServer.this.synchronized {
+      if (brokerMetrics != null) 
brokerMetrics.metadataApplyErrorCount.getAndIncrement()
+      if (controllerMetrics != null) 
controllerMetrics.incrementMetadataErrorCount()
+    })
+
+  /**
+   * The fault handler to use when the QuorumController experiences a fault.
+   */
+  def quorumControllerFaultHandler: FaultHandler = 
faultHandlerFactory.build("quorum controller",
+    fatal = true,
+    action = () => {}
+  )
+
+  /**
+   * The fault handler to use when metadata cannot be published.
+   */
+  def metadataPublishingFaultHandler: FaultHandler = 
faultHandlerFactory.build("metadata publishing",
+    fatal = false,
+    action = () => SharedServer.this.synchronized {
+      if (brokerMetrics != null) 
brokerMetrics.metadataApplyErrorCount.getAndIncrement()
+      if (controllerMetrics != null) 
controllerMetrics.incrementMetadataErrorCount()
+    })
+
+  private def start(): Unit = synchronized {
+    if (started) {
+      debug("SharedServer has already been started.")
+    } else {
+      info("Starting SharedServer")
+      try {
+        if (metrics == null) {
+          // Recreate the metrics object if we're restarting a stopped 
SharedServer object.
+          // This is only done in tests.
+          metrics = new Metrics()
+        }
+        config.dynamicConfig.initialize(zkClientOpt = None)
+
+        if (config.processRoles.contains(BrokerRole)) {
+          brokerMetrics = BrokerServerMetrics(metrics)
+        }
+        if (config.processRoles.contains(ControllerRole)) {
+          controllerMetrics = new 
QuorumControllerMetrics(KafkaYammerMetrics.defaultRegistry(), time)
+        }
+        raftManager = new KafkaRaftManager[ApiMessageAndVersion](
+          metaProps,
+          config,
+          new MetadataRecordSerde,
+          KafkaRaftServer.MetadataPartition,
+          KafkaRaftServer.MetadataTopicId,
+          time,
+          metrics,
+          threadNamePrefix,
+          controllerQuorumVotersFuture)
+        raftManager.startup()
+        debug("Completed SharedServer startup.")
+        started = true
+      } catch {
+        case e: Throwable => {
+          error("Got exception while starting SharedServer", e)
+          stop()
+        }
+      }
+    }
+  }
+
+  def ensureNotRaftLeader(): Unit = synchronized {
+    // Ideally, this would just resign our leadership, if we had it. But we 
don't have an API in
+    // RaftManager for that yet, so shut down the RaftManager.
+    if (raftManager != null) {
+      CoreUtils.swallow(raftManager.shutdown(), this)
+      raftManager = null
+    }
+  }
+
+  private def stop(): Unit = synchronized {
+    if (!started) {
+      debug("SharedServer is not running.")
+    } else {
+      info("Stopping SharedServer")
+      if (raftManager != null) {
+        CoreUtils.swallow(raftManager.shutdown(), this)
+        raftManager = null
+      }
+      if (controllerMetrics != null) {
+        CoreUtils.swallow(controllerMetrics.close(), this)
+        controllerMetrics = null
+      }
+      if (brokerMetrics != null) {
+        CoreUtils.swallow(brokerMetrics.close(), this)
+        brokerMetrics = null
+      }
+      if (metrics != null) {
+        CoreUtils.swallow(metrics.close(), this)
+        metrics = null
+      }
+      // Clear all reconfigurable instances stored in DynamicBrokerConfig
+      CoreUtils.swallow(config.dynamicConfig.clear(), this)
+      started = false
+    }
+  }
+}
diff --git a/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java 
b/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
index 417c083457f..01a96689a2f 100644
--- a/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
+++ b/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
@@ -20,28 +20,27 @@ package kafka.testkit;
 import kafka.raft.KafkaRaftManager;
 import kafka.server.BrokerServer;
 import kafka.server.ControllerServer;
+import kafka.server.FaultHandlerFactory;
+import kafka.server.SharedServer;
 import kafka.server.KafkaConfig;
 import kafka.server.KafkaConfig$;
 import kafka.server.KafkaRaftServer;
 import kafka.server.MetaProperties;
-import kafka.server.metadata.BrokerServerMetrics$;
 import kafka.tools.StorageTool;
 import kafka.utils.Logging;
 import org.apache.kafka.clients.CommonClientConfigs;
 import org.apache.kafka.common.Node;
-import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.network.ListenerName;
 import org.apache.kafka.common.utils.ThreadUtils;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.controller.Controller;
-import org.apache.kafka.controller.MockControllerMetrics;
-import org.apache.kafka.metadata.MetadataRecordSerde;
 import org.apache.kafka.metadata.bootstrap.BootstrapMetadata;
 import org.apache.kafka.raft.RaftConfig;
 import org.apache.kafka.server.common.ApiMessageAndVersion;
 import org.apache.kafka.server.common.MetadataVersion;
+import org.apache.kafka.server.fault.FaultHandler;
 import org.apache.kafka.server.fault.MockFaultHandler;
 import org.apache.kafka.test.TestUtils;
 import org.slf4j.Logger;
@@ -115,11 +114,32 @@ public class KafkaClusterTestKit implements AutoCloseable 
{
         }
     }
 
+    /**
+     * A {@link FaultHandlerFactory} for tests that hands out one of two shared
+     * {@link MockFaultHandler} instances: one for every fatal fault handler requested and
+     * one for every non-fatal fault handler requested. The requested name and action are
+     * ignored, so all faults of a given severity are recorded on the same mock handler.
+     */
+    static class SimpleFaultHandlerFactory implements FaultHandlerFactory {
+        private final MockFaultHandler fatalHandler = new MockFaultHandler("fatalFaultHandler");
+        private final MockFaultHandler nonFatalHandler = new MockFaultHandler("nonFatalFaultHandler");
+
+        /** @return the handler shared by every fatal fault handler this factory builds. */
+        MockFaultHandler fatalFaultHandler() {
+            return fatalHandler;
+        }
+
+        /** @return the handler shared by every non-fatal fault handler this factory builds. */
+        MockFaultHandler nonFatalFaultHandler() {
+            return nonFatalHandler;
+        }
+
+        @Override
+        public FaultHandler build(String name, boolean fatal, Runnable action) {
+            // name and action are intentionally unused: tests only care about severity.
+            return fatal ? fatalHandler : nonFatalHandler;
+        }
+    }
+
     public static class Builder {
         private TestKitNodes nodes;
         private Map<String, String> configProps = new HashMap<>();
-        private MockFaultHandler metadataFaultHandler = new 
MockFaultHandler("metadataFaultHandler");
-        private MockFaultHandler fatalFaultHandler = new 
MockFaultHandler("fatalFaultHandler");
+        private SimpleFaultHandlerFactory faultHandlerFactory = new 
SimpleFaultHandlerFactory();
 
         public Builder(TestKitNodes nodes) {
             this.nodes = nodes;
@@ -130,18 +150,56 @@ public class KafkaClusterTestKit implements AutoCloseable 
{
             return this;
         }
 
-        public Builder setMetadataFaultHandler(MockFaultHandler 
metadataFaultHandler) {
-            this.metadataFaultHandler = metadataFaultHandler;
-            return this;
+        /**
+         * Build the static {@link KafkaConfig} for one node of the test cluster.
+         *
+         * The same method is used for broker-only, controller-only and combined nodes. A
+         * combined node therefore ends up with a single configuration that merges its
+         * controller settings (metadata log directory) and its broker settings (log.dirs and
+         * property overrides).
+         *
+         * @param node  the broker or controller node to configure
+         * @return      the node's configuration; controller.quorum.voters holds a placeholder
+         */
+        private KafkaConfig createNodeConfig(TestKitNode node) {
+            BrokerNode brokerNode = nodes.brokerNodes().get(node.id());
+            ControllerNode controllerNode = nodes.controllerNodes().get(node.id());
+
+            Map<String, String> props = new HashMap<>(configProps);
+            props.put(KafkaConfig$.MODULE$.ProcessRolesProp(), roles(node.id()));
+            props.put(KafkaConfig$.MODULE$.NodeIdProp(),
+                    Integer.toString(node.id()));
+            // In combined mode, always prefer the metadata log directory of the controller node.
+            if (controllerNode != null) {
+                props.put(KafkaConfig$.MODULE$.MetadataLogDirProp(),
+                        controllerNode.metadataDirectory());
+            } else {
+                props.put(KafkaConfig$.MODULE$.MetadataLogDirProp(),
+                        node.metadataDirectory());
+            }
+            // Set the log.dirs according to the broker node setting (if there is a broker node)
+            if (brokerNode != null) {
+                props.put(KafkaConfig$.MODULE$.LogDirsProp(),
+                        String.join(",", brokerNode.logDataDirectories()));
+            }
+            props.put(KafkaConfig$.MODULE$.ListenerSecurityProtocolMapProp(),
+                    "EXTERNAL:PLAINTEXT,CONTROLLER:PLAINTEXT");
+            props.put(KafkaConfig$.MODULE$.ListenersProp(), listeners(node.id()));
+            props.put(KafkaConfig$.MODULE$.InterBrokerListenerNameProp(),
+                    nodes.interBrokerListenerName().value());
+            props.put(KafkaConfig$.MODULE$.ControllerListenerNamesProp(),
+                    "CONTROLLER");
+            // Note: we can't accurately set controller.quorum.voters yet, since we don't
+            // yet know what ports each controller will pick.  Set it to a dummy string
+            // (one "<id>@0.0.0.0:0" entry per controller) for now as a placeholder.
+            String uninitializedQuorumVotersString = nodes.controllerNodes().keySet().stream().
+                    map(n -> String.format("%d@0.0.0.0:0", n)).
+                    collect(Collectors.joining(","));
+            props.put(RaftConfig.QUORUM_VOTERS_CONFIG, uninitializedQuorumVotersString);
+
+            // reduce log cleaner offset map memory usage
+            props.put(KafkaConfig$.MODULE$.LogCleanerDedupeBufferSizeProp(), "2097152");
+
+            // Add associated broker node property overrides
+            if (brokerNode != null) {
+                props.putAll(brokerNode.propertyOverrides());
+            }
+            return new KafkaConfig(props, false, Option.empty());
         }
 
         public KafkaClusterTestKit build() throws Exception {
             Map<Integer, ControllerServer> controllers = new HashMap<>();
             Map<Integer, BrokerServer> brokers = new HashMap<>();
-            Map<Integer, KafkaRaftManager<ApiMessageAndVersion>> raftManagers 
= new HashMap<>();
-            String uninitializedQuorumVotersString = 
nodes.controllerNodes().keySet().stream().
-                map(controllerNode -> String.format("%[email protected]:0", 
controllerNode)).
-                collect(Collectors.joining(","));
+            Map<Integer, SharedServer> jointServers = new HashMap<>();
             /*
               Number of threads = Total number of brokers + Total number of 
controllers + Total number of Raft Managers
                                 = Total number of brokers + Total number of 
controllers * 2
@@ -159,53 +217,31 @@ public class KafkaClusterTestKit implements AutoCloseable 
{
                 executorService = 
Executors.newFixedThreadPool(numOfExecutorThreads,
                     ThreadUtils.createThreadFactory("KafkaClusterTestKit%d", 
false));
                 for (ControllerNode node : nodes.controllerNodes().values()) {
-                    Map<String, String> props = new HashMap<>(configProps);
-                    props.put(KafkaConfig$.MODULE$.ProcessRolesProp(), 
roles(node.id()));
-                    props.put(KafkaConfig$.MODULE$.NodeIdProp(),
-                        Integer.toString(node.id()));
-                    props.put(KafkaConfig$.MODULE$.MetadataLogDirProp(),
-                        node.metadataDirectory());
-                    
props.put(KafkaConfig$.MODULE$.ListenerSecurityProtocolMapProp(),
-                        "EXTERNAL:PLAINTEXT,CONTROLLER:PLAINTEXT");
-                    props.put(KafkaConfig$.MODULE$.ListenersProp(), 
listeners(node.id()));
-                    
props.put(KafkaConfig$.MODULE$.InterBrokerListenerNameProp(),
-                        nodes.interBrokerListenerName().value());
-                    
props.put(KafkaConfig$.MODULE$.ControllerListenerNamesProp(),
-                        "CONTROLLER");
-                    // Note: we can't accurately set controller.quorum.voters 
yet, since we don't
-                    // yet know what ports each controller will pick.  Set it 
to a dummy string \
-                    // for now as a placeholder.
-                    props.put(RaftConfig.QUORUM_VOTERS_CONFIG, 
uninitializedQuorumVotersString);
-
-                    // reduce log cleaner offset map memory usage
-                    
props.put(KafkaConfig$.MODULE$.LogCleanerDedupeBufferSizeProp(), "2097152");
-
                     setupNodeDirectories(baseDirectory, 
node.metadataDirectory(), Collections.emptyList());
-                    KafkaConfig config = new KafkaConfig(props, false, 
Option.empty());
-
-                    String threadNamePrefix = String.format("controller%d_", 
node.id());
-                    MetaProperties metaProperties = 
MetaProperties.apply(nodes.clusterId().toString(), node.id());
-                    TopicPartition metadataPartition = new 
TopicPartition(KafkaRaftServer.MetadataTopic(), 0);
                     BootstrapMetadata bootstrapMetadata = BootstrapMetadata.
                         fromVersion(nodes.bootstrapMetadataVersion(), 
"testkit");
-                    KafkaRaftManager<ApiMessageAndVersion> raftManager = new 
KafkaRaftManager<>(
-                        metaProperties, config, new MetadataRecordSerde(), 
metadataPartition, KafkaRaftServer.MetadataTopicId(),
-                        Time.SYSTEM, new Metrics(), 
Option.apply(threadNamePrefix), connectFutureManager.future);
-                    ControllerServer controller = new ControllerServer(
-                        nodes.controllerProperties(node.id()),
-                        config,
-                        raftManager,
-                        Time.SYSTEM,
-                        new Metrics(),
-                        new MockControllerMetrics(),
-                        Option.apply(threadNamePrefix),
-                        connectFutureManager.future,
-                        KafkaRaftServer.configSchema(),
-                        raftManager.apiVersions(),
-                        bootstrapMetadata,
-                        metadataFaultHandler,
-                        fatalFaultHandler
-                    );
+                    String threadNamePrefix = 
(nodes.brokerNodes().containsKey(node.id())) ?
+                            String.format("colocated%d", node.id()) :
+                            String.format("controller%d", node.id());
+                    SharedServer sharedServer = new 
SharedServer(createNodeConfig(node),
+                            MetaProperties.apply(nodes.clusterId().toString(), 
node.id()),
+                            Time.SYSTEM,
+                            new Metrics(),
+                            Option.apply(threadNamePrefix),
+                            connectFutureManager.future,
+                            faultHandlerFactory);
+                    ControllerServer controller = null;
+                    try {
+                        controller = new ControllerServer(
+                                sharedServer,
+                                KafkaRaftServer.configSchema(),
+                                bootstrapMetadata);
+                    } catch (Throwable e) {
+                        log.error("Error creating controller {}", node.id(), 
e);
+                        Utils.swallow(log, "sharedServer.stopForController", 
() -> sharedServer.stopForController());
+                        if (controller != null) controller.shutdown();
+                        throw e;
+                    }
                     controllers.put(node.id(), controller);
                     
controller.socketServerFirstBoundPortFuture().whenComplete((port, e) -> {
                         if (e != null) {
@@ -214,61 +250,28 @@ public class KafkaClusterTestKit implements AutoCloseable 
{
                             connectFutureManager.registerPort(node.id(), port);
                         }
                     });
-                    raftManagers.put(node.id(), raftManager);
+                    jointServers.put(node.id(), sharedServer);
                 }
                 for (BrokerNode node : nodes.brokerNodes().values()) {
-                    Map<String, String> props = new HashMap<>(configProps);
-                    props.put(KafkaConfig$.MODULE$.ProcessRolesProp(), 
roles(node.id()));
-                    props.put(KafkaConfig$.MODULE$.BrokerIdProp(),
-                        Integer.toString(node.id()));
-                    props.put(KafkaConfig$.MODULE$.MetadataLogDirProp(),
-                        node.metadataDirectory());
-                    props.put(KafkaConfig$.MODULE$.LogDirsProp(),
-                        String.join(",", node.logDataDirectories()));
-                    
props.put(KafkaConfig$.MODULE$.ListenerSecurityProtocolMapProp(),
-                        "EXTERNAL:PLAINTEXT,CONTROLLER:PLAINTEXT");
-                    props.put(KafkaConfig$.MODULE$.ListenersProp(), 
listeners(node.id()));
-                    
props.put(KafkaConfig$.MODULE$.InterBrokerListenerNameProp(),
-                        nodes.interBrokerListenerName().value());
-                    
props.put(KafkaConfig$.MODULE$.ControllerListenerNamesProp(),
-                        "CONTROLLER");
-
-                    setupNodeDirectories(baseDirectory, 
node.metadataDirectory(),
-                        node.logDataDirectories());
-
-                    // Just like above, we set a placeholder voter list here 
until we
-                    // find out what ports the controllers picked.
-                    props.put(RaftConfig.QUORUM_VOTERS_CONFIG, 
uninitializedQuorumVotersString);
-                    props.putAll(node.propertyOverrides());
-                    KafkaConfig config = new KafkaConfig(props, false, 
Option.empty());
-
-                    String threadNamePrefix = String.format("broker%d_", 
node.id());
-                    MetaProperties metaProperties = 
MetaProperties.apply(nodes.clusterId().toString(), node.id());
-                    TopicPartition metadataPartition = new 
TopicPartition(KafkaRaftServer.MetadataTopic(), 0);
-                    KafkaRaftManager<ApiMessageAndVersion> raftManager;
-                    if (raftManagers.containsKey(node.id())) {
-                        raftManager = raftManagers.get(node.id());
-                    } else {
-                        raftManager = new KafkaRaftManager<>(
-                            metaProperties, config, new MetadataRecordSerde(), 
metadataPartition, KafkaRaftServer.MetadataTopicId(),
-                            Time.SYSTEM, new Metrics(), 
Option.apply(threadNamePrefix), connectFutureManager.future);
-                        raftManagers.put(node.id(), raftManager);
+                    SharedServer sharedServer = 
jointServers.computeIfAbsent(node.id(),
+                        id -> new SharedServer(createNodeConfig(node),
+                            MetaProperties.apply(nodes.clusterId().toString(), 
id),
+                            Time.SYSTEM,
+                            new Metrics(),
+                            Option.apply(String.format("broker%d_", id)),
+                            connectFutureManager.future,
+                            faultHandlerFactory));
+                    BrokerServer broker = null;
+                    try {
+                        broker = new BrokerServer(
+                                sharedServer,
+                                
JavaConverters.asScalaBuffer(Collections.<String>emptyList()).toSeq());
+                    } catch (Throwable e) {
+                        log.error("Error creating broker {}", node.id(), e);
+                        Utils.swallow(log, "sharedServer.stopForBroker", () -> 
sharedServer.stopForBroker());
+                        if (broker != null) broker.shutdown();
+                        throw e;
                     }
-                    Metrics metrics = new Metrics();
-                    BrokerServer broker = new BrokerServer(
-                        config,
-                        nodes.brokerProperties(node.id()),
-                        raftManager,
-                        Time.SYSTEM,
-                        metrics,
-                        BrokerServerMetrics$.MODULE$.apply(metrics),
-                        Option.apply(threadNamePrefix),
-                        
JavaConverters.asScalaBuffer(Collections.<String>emptyList()).toSeq(),
-                        connectFutureManager.future,
-                        fatalFaultHandler,
-                        metadataFaultHandler,
-                        metadataFaultHandler
-                    );
                     brokers.put(node.id(), broker);
                 }
             } catch (Exception e) {
@@ -279,9 +282,6 @@ public class KafkaClusterTestKit implements AutoCloseable {
                 for (BrokerServer brokerServer : brokers.values()) {
                     brokerServer.shutdown();
                 }
-                for (KafkaRaftManager<ApiMessageAndVersion> raftManager : 
raftManagers.values()) {
-                    raftManager.shutdown();
-                }
                 for (ControllerServer controller : controllers.values()) {
                     controller.shutdown();
                 }
@@ -291,9 +291,13 @@ public class KafkaClusterTestKit implements AutoCloseable {
                 }
                 throw e;
             }
-            return new KafkaClusterTestKit(executorService, nodes, controllers,
-                brokers, raftManagers, connectFutureManager, baseDirectory,
-                metadataFaultHandler, fatalFaultHandler);
+            return new KafkaClusterTestKit(executorService,
+                    nodes,
+                    controllers,
+                    brokers,
+                    connectFutureManager,
+                    baseDirectory,
+                    faultHandlerFactory);
         }
 
         private String listeners(int node) {
@@ -331,32 +335,26 @@ public class KafkaClusterTestKit implements AutoCloseable 
{
     private final TestKitNodes nodes;
     private final Map<Integer, ControllerServer> controllers;
     private final Map<Integer, BrokerServer> brokers;
-    private final Map<Integer, KafkaRaftManager<ApiMessageAndVersion>> 
raftManagers;
     private final ControllerQuorumVotersFutureManager 
controllerQuorumVotersFutureManager;
     private final File baseDirectory;
-    private final MockFaultHandler metadataFaultHandler;
-    private final MockFaultHandler fatalFaultHandler;
+    private final SimpleFaultHandlerFactory faultHandlerFactory;
 
     private KafkaClusterTestKit(
         ExecutorService executorService,
         TestKitNodes nodes,
         Map<Integer, ControllerServer> controllers,
         Map<Integer, BrokerServer> brokers,
-        Map<Integer, KafkaRaftManager<ApiMessageAndVersion>> raftManagers,
         ControllerQuorumVotersFutureManager 
controllerQuorumVotersFutureManager,
         File baseDirectory,
-        MockFaultHandler metadataFaultHandler,
-        MockFaultHandler fatalFaultHandler
+        SimpleFaultHandlerFactory faultHandlerFactory
     ) {
         this.executorService = executorService;
         this.nodes = nodes;
         this.controllers = controllers;
         this.brokers = brokers;
-        this.raftManagers = raftManagers;
         this.controllerQuorumVotersFutureManager = 
controllerQuorumVotersFutureManager;
         this.baseDirectory = baseDirectory;
-        this.metadataFaultHandler = metadataFaultHandler;
-        this.fatalFaultHandler = fatalFaultHandler;
+        this.faultHandlerFactory = faultHandlerFactory;
     }
 
     public void format() throws Exception {
@@ -370,9 +368,11 @@ public class KafkaClusterTestKit implements AutoCloseable {
             }
             for (Entry<Integer, BrokerServer> entry : brokers.entrySet()) {
                 int nodeId = entry.getKey();
-                BrokerServer broker = entry.getValue();
-                formatNodeAndLog(nodes.brokerProperties(nodeId), 
broker.config().metadataLogDir(),
-                    broker, futures::add);
+                if (!controllers.containsKey(nodeId)) {
+                    BrokerServer broker = entry.getValue();
+                    formatNodeAndLog(nodes.brokerProperties(nodeId), 
broker.config().metadataLogDir(),
+                            broker, futures::add);
+                }
             }
             for (Future<?> future: futures) {
                 future.get();
@@ -411,10 +411,6 @@ public class KafkaClusterTestKit implements AutoCloseable {
         try {
             // Note the startup order here is chosen to be consistent with
             // `KafkaRaftServer`. See comments in that class for an 
explanation.
-
-            for (KafkaRaftManager<ApiMessageAndVersion> raftManager : 
raftManagers.values()) {
-                
futures.add(controllerQuorumVotersFutureManager.future.thenRunAsync(raftManager::startup));
-            }
             for (ControllerServer controller : controllers.values()) {
                 futures.add(executorService.submit(controller::startup));
             }
@@ -505,13 +501,30 @@ public class KafkaClusterTestKit implements AutoCloseable 
{
     }
 
     public Map<Integer, KafkaRaftManager<ApiMessageAndVersion>> raftManagers() 
{
-        return raftManagers;
+        Map<Integer, KafkaRaftManager<ApiMessageAndVersion>> results = new 
HashMap<>();
+        for (BrokerServer brokerServer : brokers().values()) {
+            results.put(brokerServer.config().brokerId(), 
brokerServer.sharedServer().raftManager());
+        }
+        for (ControllerServer controllerServer : controllers().values()) {
+            if (!results.containsKey(controllerServer.config().nodeId())) {
+                results.put(controllerServer.config().nodeId(), 
controllerServer.sharedServer().raftManager());
+            }
+        }
+        return results;
     }
 
     public TestKitNodes nodes() {
         return nodes;
     }
 
+    /**
+     * @return the mock handler returned for every fatal FaultHandler requested from this
+     *         cluster's fault handler factory.
+     */
+    public MockFaultHandler fatalFaultHandler() {
+        MockFaultHandler handler = this.faultHandlerFactory.fatalFaultHandler();
+        return handler;
+    }
+
+    /**
+     * @return the mock handler returned for every non-fatal FaultHandler requested from this
+     *         cluster's fault handler factory.
+     */
+    public MockFaultHandler nonFatalFaultHandler() {
+        MockFaultHandler handler = this.faultHandlerFactory.nonFatalFaultHandler();
+        return handler;
+    }
+
     @Override
     public void close() throws Exception {
         List<Entry<String, Future<?>>> futureEntries = new ArrayList<>();
@@ -529,14 +542,6 @@ public class KafkaClusterTestKit implements AutoCloseable {
             }
             waitForAllFutures(futureEntries);
             futureEntries.clear();
-            for (Entry<Integer, KafkaRaftManager<ApiMessageAndVersion>> entry 
: raftManagers.entrySet()) {
-                int raftManagerId = entry.getKey();
-                KafkaRaftManager<ApiMessageAndVersion> raftManager = 
entry.getValue();
-                futureEntries.add(new SimpleImmutableEntry<>("raftManager" + 
raftManagerId,
-                    executorService.submit(raftManager::shutdown)));
-            }
-            waitForAllFutures(futureEntries);
-            futureEntries.clear();
             for (Entry<Integer, ControllerServer> entry : 
controllers.entrySet()) {
                 int controllerId = entry.getKey();
                 ControllerServer controller = entry.getValue();
@@ -555,8 +560,8 @@ public class KafkaClusterTestKit implements AutoCloseable {
             executorService.shutdownNow();
             executorService.awaitTermination(5, TimeUnit.MINUTES);
         }
-        metadataFaultHandler.maybeRethrowFirstException();
-        fatalFaultHandler.maybeRethrowFirstException();
+        faultHandlerFactory.fatalFaultHandler().maybeRethrowFirstException();
+        
faultHandlerFactory.nonFatalFaultHandler().maybeRethrowFirstException();
     }
 
     private void waitForAllFutures(List<Entry<String, Future<?>>> 
futureEntries)
diff --git 
a/core/src/test/scala/integration/kafka/api/BaseAdminIntegrationTest.scala 
b/core/src/test/scala/integration/kafka/api/BaseAdminIntegrationTest.scala
index 293b198c0ec..6fa8bc9b2bc 100644
--- a/core/src/test/scala/integration/kafka/api/BaseAdminIntegrationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/BaseAdminIntegrationTest.scala
@@ -215,11 +215,6 @@ abstract class BaseAdminIntegrationTest extends 
IntegrationTestHarness with Logg
 
   override def kraftControllerConfigs(): Seq[Properties] = {
     val controllerConfig = new Properties()
-    if 
(testInfo.getTestMethod.toString.contains("testCreateTopicsReturnsConfigs")) {
-      // For testCreateTopicsReturnsConfigs, set the controller's ID to 1 so 
that the dynamic
-      // config we set for node 1 will apply to it.
-      controllerConfig.setProperty(KafkaConfig.NodeIdProp, "1")
-    }
     val controllerConfigs = Seq(controllerConfig)
     modifyConfigs(controllerConfigs)
     controllerConfigs
diff --git 
a/core/src/test/scala/integration/kafka/api/PlaintextAdminIntegrationTest.scala 
b/core/src/test/scala/integration/kafka/api/PlaintextAdminIntegrationTest.scala
index 1656af08bc1..80f6cd758b3 100644
--- 
a/core/src/test/scala/integration/kafka/api/PlaintextAdminIntegrationTest.scala
+++ 
b/core/src/test/scala/integration/kafka/api/PlaintextAdminIntegrationTest.scala
@@ -44,12 +44,14 @@ import 
org.apache.kafka.common.requests.{DeleteRecordsRequest, MetadataResponse}
 import org.apache.kafka.common.resource.{PatternType, ResourcePattern, 
ResourceType}
 import org.apache.kafka.common.utils.{Time, Utils}
 import org.apache.kafka.common.{ConsumerGroupState, ElectionType, 
TopicCollection, TopicPartition, TopicPartitionInfo, TopicPartitionReplica, 
Uuid}
+import 
org.apache.kafka.controller.ControllerRequestContextUtil.ANONYMOUS_CONTEXT
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Disabled, TestInfo}
 import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
 import org.slf4j.LoggerFactory
 
+import java.util.AbstractMap.SimpleImmutableEntry
 import scala.annotation.nowarn
 import scala.collection.Seq
 import scala.compat.java8.OptionConverters._
@@ -2500,7 +2502,7 @@ class PlaintextAdminIntegrationTest extends 
BaseAdminIntegrationTest {
    * Test that createTopics returns the dynamic configurations of the topics 
that were created.
    *
    * Note: this test requires some custom static broker and controller 
configurations, which are set up in
-   * BaseAdminIntegrationTest.modifyConfigs and 
BaseAdminIntegrationTest.kraftControllerConfigs.
+   * BaseAdminIntegrationTest.modifyConfigs.
    */
   @ParameterizedTest(name = TestInfoUtils.TestWithParameterizedQuorumName)
   @ValueSource(strings = Array("zk", "kraft"))
@@ -2518,19 +2520,27 @@ class PlaintextAdminIntegrationTest extends 
BaseAdminIntegrationTest {
       .all().get(15, TimeUnit.SECONDS)
 
     if (isKRaftTest()) {
+      // In KRaft mode, we don't yet support altering configs on controller 
nodes, except by setting
+      // default node configs. Therefore, we have to set the dynamic config 
directly to test this.
+      val controllerNodeResource = new 
ConfigResource(ConfigResource.Type.BROKER,
+        controllerServer.config.nodeId.toString)
+      controllerServer.controller.incrementalAlterConfigs(ANONYMOUS_CONTEXT,
+        Collections.singletonMap(controllerNodeResource,
+          Collections.singletonMap(KafkaConfig.LogCleanerDeleteRetentionMsProp,
+            new SimpleImmutableEntry(AlterConfigOp.OpType.SET, "34"))), 
false).get()
       ensureConsistentKRaftMetadata()
-    } else {
-      waitUntilTrue(() => brokers.forall(_.config.originals.getOrDefault(
-        KafkaConfig.LogCleanerDeleteRetentionMsProp, 
"").toString.equals("34")),
-        s"Timed out waiting for change to 
${KafkaConfig.LogCleanerDeleteRetentionMsProp}",
-        waitTimeMs = 60000L)
-
-      waitUntilTrue(() => brokers.forall(_.config.originals.getOrDefault(
-        KafkaConfig.LogRetentionTimeMillisProp, 
"").toString.equals("10800000")),
-        s"Timed out waiting for change to 
${KafkaConfig.LogRetentionTimeMillisProp}",
-        waitTimeMs = 60000L)
     }
 
+    waitUntilTrue(() => brokers.forall(_.config.originals.getOrDefault(
+      KafkaConfig.LogCleanerDeleteRetentionMsProp, "").toString.equals("34")),
+      s"Timed out waiting for change to 
${KafkaConfig.LogCleanerDeleteRetentionMsProp}",
+      waitTimeMs = 60000L)
+
+    waitUntilTrue(() => brokers.forall(_.config.originals.getOrDefault(
+      KafkaConfig.LogRetentionTimeMillisProp, "").toString.equals("10800000")),
+      s"Timed out waiting for change to 
${KafkaConfig.LogRetentionTimeMillisProp}",
+      waitTimeMs = 60000L)
+
     val newTopics = Seq(new NewTopic("foo", Map((0: Integer) -> 
Seq[Integer](1, 2).asJava,
       (1: Integer) -> Seq[Integer](2, 0).asJava).asJava).
       configs(Collections.singletonMap(LogConfig.IndexIntervalBytesProp, 
"9999999")),
diff --git 
a/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala 
b/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala
index f52c2a72da2..aa327f153b1 100644
--- a/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala
+++ b/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala
@@ -67,6 +67,23 @@ class KRaftClusterTest {
     }
   }
 
+  /**
+   * Start a one-broker, one-controller cluster, then shut down the broker and start the
+   * same BrokerServer instance again.
+   */
+  @Test
+  def testCreateClusterAndRestartNode(): Unit = {
+    val testKitNodes = new TestKitNodes.Builder().
+      setNumBrokerNodes(1).
+      setNumControllerNodes(1).
+      build()
+    val cluster = new KafkaClusterTestKit.Builder(testKitNodes).build()
+    try {
+      cluster.format()
+      cluster.startup()
+      // Bounce the only broker in the cluster.
+      val onlyBroker = cluster.brokers().values().iterator().next()
+      onlyBroker.shutdown()
+      onlyBroker.startup()
+    } finally {
+      cluster.close()
+    }
+  }
+
   @Test
   def testCreateClusterAndWaitForBrokerInRunningState(): Unit = {
     val cluster = new KafkaClusterTestKit.Builder(
diff --git 
a/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala 
b/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala
index c0a55948cfd..4a5e3ace2a4 100755
--- a/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala
+++ b/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala
@@ -23,23 +23,18 @@ import java.util
 import java.util.{Collections, Properties}
 import java.util.concurrent.CompletableFuture
 import javax.security.auth.login.Configuration
-import kafka.raft.KafkaRaftManager
-import kafka.server.metadata.BrokerServerMetrics
 import kafka.tools.StorageTool
 import kafka.utils.{CoreUtils, Logging, TestInfoUtils, TestUtils}
 import kafka.zk.{AdminZkClient, EmbeddedZookeeper, KafkaZkClient}
 import org.apache.kafka.common.metrics.Metrics
-import org.apache.kafka.common.{TopicPartition, Uuid}
+import org.apache.kafka.common.Uuid
 import org.apache.kafka.common.security.JaasUtils
 import org.apache.kafka.common.security.auth.SecurityProtocol
 import org.apache.kafka.common.utils.{Exit, Time}
-import org.apache.kafka.controller.QuorumControllerMetrics
-import org.apache.kafka.metadata.MetadataRecordSerde
 import org.apache.kafka.metadata.bootstrap.BootstrapMetadata
 import org.apache.kafka.raft.RaftConfig.{AddressSpec, InetAddressSpec}
 import org.apache.kafka.server.common.{ApiMessageAndVersion, MetadataVersion}
 import org.apache.kafka.server.fault.{FaultHandler, MockFaultHandler}
-import org.apache.kafka.server.metrics.KafkaYammerMetrics
 import org.apache.zookeeper.client.ZKClientConfig
 import org.apache.zookeeper.{WatchedEvent, Watcher, ZooKeeper}
 import org.junit.jupiter.api.Assertions._
@@ -85,8 +80,8 @@ class ZooKeeperQuorumImplementation(
 }
 
 class KRaftQuorumImplementation(
-  val raftManager: KafkaRaftManager[ApiMessageAndVersion],
   val controllerServer: ControllerServer,
+  val faultHandlerFactory: FaultHandlerFactory,
   val metadataDir: File,
   val controllerQuorumVotersFuture: CompletableFuture[util.Map[Integer, 
AddressSpec]],
   val clusterId: String,
@@ -99,29 +94,43 @@ class KRaftQuorumImplementation(
     startup: Boolean,
     threadNamePrefix: Option[String],
   ): KafkaBroker = {
-    val metrics = new Metrics()
-    val broker = new BrokerServer(config = config,
-      metaProps = new MetaProperties(clusterId, config.nodeId),
-      raftManager = raftManager,
-      time = time,
-      metrics = metrics,
-      brokerMetrics = BrokerServerMetrics(metrics),
-      threadNamePrefix = Some("Broker%02d_".format(config.nodeId)),
-      initialOfflineDirs = Seq(),
-      controllerQuorumVotersFuture = controllerQuorumVotersFuture,
-      fatalFaultHandler = faultHandler,
-      metadataLoadingFaultHandler = faultHandler,
-      metadataPublishingFaultHandler = faultHandler)
-    if (startup) broker.startup()
-    broker
+    val sharedServer = new SharedServer(config,
+      new MetaProperties(clusterId, config.nodeId),
+      Time.SYSTEM,
+      new Metrics(),
+      Option("Broker%02d_".format(config.nodeId)),
+      controllerQuorumVotersFuture,
+      faultHandlerFactory)
+    var broker: BrokerServer = null
+    try {
+      broker = new BrokerServer(sharedServer,
+        initialOfflineDirs = Seq())
+      if (startup) broker.startup()
+      broker
+    } catch {
+      case e: Throwable => {
+        if (broker != null) CoreUtils.swallow(broker.shutdown(), log)
+        CoreUtils.swallow(sharedServer.stopForBroker(), log)
+        throw e
+      }
+    }
   }
 
   override def shutdown(): Unit = {
-    CoreUtils.swallow(raftManager.shutdown(), log)
     CoreUtils.swallow(controllerServer.shutdown(), log)
   }
 }
 
+class QuorumTestHarnessFaultHandlerFactory(
+  val faultHandler: MockFaultHandler
+) extends FaultHandlerFactory {
+  override def build(
+    name: String,
+    fatal: Boolean,
+    action: Runnable
+  ): FaultHandler = faultHandler
+}
+
 @Tag("integration")
 abstract class QuorumTestHarness extends Logging {
   val zkConnectionTimeout = 10000
@@ -199,7 +208,9 @@ abstract class QuorumTestHarness extends Logging {
     }
   }
 
-  val faultHandler = new MockFaultHandler("quorumTestHarnessFaultHandler")
+  val faultHandlerFactory = new QuorumTestHarnessFaultHandlerFactory(new 
MockFaultHandler("quorumTestHarnessFaultHandler"))
+
+  val faultHandler = faultHandlerFactory.faultHandler
 
   // Note: according to the junit documentation: "JUnit Jupiter does not 
guarantee the execution
   // order of multiple @BeforeEach methods that are declared within a single 
test class or test
@@ -288,7 +299,6 @@ abstract class QuorumTestHarness extends Logging {
     val metadataDir = TestUtils.tempDir()
     val metaProperties = new MetaProperties(Uuid.randomUuid().toString, nodeId)
     formatDirectories(immutable.Seq(metadataDir.getAbsolutePath), 
metaProperties)
-    val controllerMetrics = new Metrics()
     props.setProperty(KafkaConfig.MetadataLogDirProp, 
metadataDir.getAbsolutePath)
     val proto = controllerListenerSecurityProtocol.toString
     props.setProperty(KafkaConfig.ListenerSecurityProtocolMapProp, 
s"CONTROLLER:${proto}")
@@ -296,34 +306,20 @@ abstract class QuorumTestHarness extends Logging {
     props.setProperty(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER")
     props.setProperty(KafkaConfig.QuorumVotersProp, s"${nodeId}@localhost:0")
     val config = new KafkaConfig(props)
-    val threadNamePrefix = "Controller_" + testInfo.getDisplayName
     val controllerQuorumVotersFuture = new CompletableFuture[util.Map[Integer, 
AddressSpec]]
-    val raftManager = new KafkaRaftManager(
-      metaProperties = metaProperties,
-      config = config,
-      recordSerde = MetadataRecordSerde.INSTANCE,
-      topicPartition = new TopicPartition(KafkaRaftServer.MetadataTopic, 0),
-      topicId = KafkaRaftServer.MetadataTopicId,
-      time = Time.SYSTEM,
-      metrics = controllerMetrics,
-      threadNamePrefixOpt = Option(threadNamePrefix),
-      controllerQuorumVotersFuture = controllerQuorumVotersFuture)
+    val sharedServer = new SharedServer(config,
+      metaProperties,
+      Time.SYSTEM,
+      new Metrics(),
+      Option("Controller_" + testInfo.getDisplayName),
+      controllerQuorumVotersFuture,
+      faultHandlerFactory)
     var controllerServer: ControllerServer = null
     try {
       controllerServer = new ControllerServer(
-        metaProperties = metaProperties,
-        config = config,
-        raftManager = raftManager,
-        time = Time.SYSTEM,
-        metrics = controllerMetrics,
-        controllerMetrics = new 
QuorumControllerMetrics(KafkaYammerMetrics.defaultRegistry(), Time.SYSTEM),
-        threadNamePrefix = Option(threadNamePrefix),
-        controllerQuorumVotersFuture = controllerQuorumVotersFuture,
-        configSchema = KafkaRaftServer.configSchema,
-        raftApiVersions = raftManager.apiVersions,
-        bootstrapMetadata = BootstrapMetadata.fromVersion(metadataVersion, 
"test harness"),
-        metadataFaultHandler = faultHandler,
-        fatalFaultHandler = faultHandler
+        sharedServer,
+        KafkaRaftServer.configSchema,
+        BootstrapMetadata.fromVersion(metadataVersion, "test harness")
       )
       controllerServer.socketServerFirstBoundPortFuture.whenComplete((port, e) 
=> {
         if (e != null) {
@@ -335,15 +331,14 @@ abstract class QuorumTestHarness extends Logging {
         }
       })
       controllerServer.startup()
-      raftManager.startup()
     } catch {
       case e: Throwable =>
-        CoreUtils.swallow(raftManager.shutdown(), this)
         if (controllerServer != null) 
CoreUtils.swallow(controllerServer.shutdown(), this)
+        CoreUtils.swallow(sharedServer.stopForController(), this)
         throw e
     }
-    new KRaftQuorumImplementation(raftManager,
-      controllerServer,
+    new KRaftQuorumImplementation(controllerServer,
+      faultHandlerFactory,
       metadataDir,
       controllerQuorumVotersFuture,
       metaProperties.clusterId,
diff --git a/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala b/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala
index b0936d12f3e..8874a235a52 100644
--- a/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala
+++ b/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala
@@ -255,12 +255,11 @@ class BrokerMetadataPublisherTest {
 
   @Test
   def testExceptionInUpdateCoordinator(): Unit = {
-    val errorHandler = new MockFaultHandler("publisher")
     val cluster = new KafkaClusterTestKit.Builder(
       new TestKitNodes.Builder().
         setNumBrokerNodes(1).
         setNumControllerNodes(1).build()).
-      setMetadataFaultHandler(errorHandler).build()
+      build()
     try {
       cluster.format()
       cluster.startup()
@@ -279,11 +278,11 @@ class BrokerMetadataPublisherTest {
         admin.close()
       }
       TestUtils.retry(60000) {
-        assertTrue(Option(errorHandler.firstException()).
+        assertTrue(Option(cluster.nonFatalFaultHandler().firstException()).
           flatMap(e => 
Option(e.getMessage())).getOrElse("(none)").contains("injected failure"))
       }
     } finally {
-      errorHandler.setIgnore(true)
+      cluster.nonFatalFaultHandler().setIgnore(true)
       cluster.close()
     }
   }
diff --git a/metadata/src/main/java/org/apache/kafka/controller/FeatureControlManager.java b/metadata/src/main/java/org/apache/kafka/controller/FeatureControlManager.java
index 97942835720..b3758586cb8 100644
--- a/metadata/src/main/java/org/apache/kafka/controller/FeatureControlManager.java
+++ b/metadata/src/main/java/org/apache/kafka/controller/FeatureControlManager.java
@@ -347,4 +347,8 @@ public class FeatureControlManager {
     FeatureControlIterator iterator(long epoch) {
         return new FeatureControlIterator(epoch);
     }
+
+    boolean isControllerId(int nodeId) {
+        return quorumFeatures.isControllerId(nodeId);
+    }
 }
diff --git a/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java b/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java
b/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java
index 8bf46dd8c2b..37e96b294d7 100644
--- a/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java
+++ b/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java
@@ -381,17 +381,18 @@ public final class QuorumController implements Controller {
                     // Cluster configs are always allowed.
                     if (configResource.name().isEmpty()) break;
 
-                    // Otherwise, check that the broker ID is valid.
-                    int brokerId;
+                    // Otherwise, check that the node ID is valid.
+                    int nodeId;
                     try {
-                        brokerId = Integer.parseInt(configResource.name());
+                        nodeId = Integer.parseInt(configResource.name());
                     } catch (NumberFormatException e) {
                         throw new InvalidRequestException("Invalid broker name 
" +
                             configResource.name());
                     }
-                    if 
(!clusterControl.brokerRegistrations().containsKey(brokerId)) {
-                        throw new BrokerIdNotRegisteredException("No broker 
with id " +
-                            brokerId + " found.");
+                    if 
(!(clusterControl.brokerRegistrations().containsKey(nodeId) ||
+                            featureControl.isControllerId(nodeId))) {
+                        throw new BrokerIdNotRegisteredException("No node with 
id " +
+                            nodeId + " found.");
                     }
                     break;
                 case TOPIC:
diff --git a/metadata/src/main/java/org/apache/kafka/controller/QuorumFeatures.java b/metadata/src/main/java/org/apache/kafka/controller/QuorumFeatures.java
b/metadata/src/main/java/org/apache/kafka/controller/QuorumFeatures.java
index 36725c25185..58a7aea2af3 100644
--- a/metadata/src/main/java/org/apache/kafka/controller/QuorumFeatures.java
+++ b/metadata/src/main/java/org/apache/kafka/controller/QuorumFeatures.java
@@ -124,4 +124,8 @@ public class QuorumFeatures {
     VersionRange localSupportedFeature(String featureName) {
         return localSupportedFeatures.getOrDefault(featureName, DISABLED);
     }
+
+    boolean isControllerId(int nodeId) {
+        return quorumNodeIds.contains(nodeId);
+    }
 }
diff --git a/metadata/src/test/java/org/apache/kafka/controller/QuorumFeaturesTest.java b/metadata/src/test/java/org/apache/kafka/controller/QuorumFeaturesTest.java
b/metadata/src/test/java/org/apache/kafka/controller/QuorumFeaturesTest.java
index 7d8ba5bfec2..7cd6e6c5cbd 100644
--- a/metadata/src/test/java/org/apache/kafka/controller/QuorumFeaturesTest.java
+++ b/metadata/src/test/java/org/apache/kafka/controller/QuorumFeaturesTest.java
@@ -35,6 +35,8 @@ import java.util.Optional;
 
 import static java.util.Collections.emptyMap;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class QuorumFeaturesTest {
     private final static Map<String, VersionRange> LOCAL;
@@ -89,4 +91,13 @@ public class QuorumFeaturesTest {
         });
         return new NodeApiVersions(Collections.emptyList(), features);
     }
+
+    @Test
+    public void testIsControllerId() {
+        QuorumFeatures quorumFeatures = new QuorumFeatures(0, new 
ApiVersions(), LOCAL, Arrays.asList(0, 1, 2));
+        assertTrue(quorumFeatures.isControllerId(0));
+        assertTrue(quorumFeatures.isControllerId(1));
+        assertTrue(quorumFeatures.isControllerId(2));
+        assertFalse(quorumFeatures.isControllerId(3));
+    }
 }
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 249350a39ac..1cd6058eb0c 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -2368,6 +2368,10 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         if (kafkaRaftMetrics != null) {
             kafkaRaftMetrics.close();
         }
+        if (memoryPool instanceof BatchMemoryPool) {
+            BatchMemoryPool batchMemoryPool = (BatchMemoryPool) memoryPool;
+            batchMemoryPool.releaseRetained();
+        }
     }
 
     QuorumState quorum() {
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BatchMemoryPool.java b/raft/src/main/java/org/apache/kafka/raft/internals/BatchMemoryPool.java
index 5120d6928d3..0239c84f733 100644
--- a/raft/src/main/java/org/apache/kafka/raft/internals/BatchMemoryPool.java
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/BatchMemoryPool.java
@@ -113,6 +113,18 @@ public class BatchMemoryPool implements MemoryPool {
         }
     }
 
+    /**
+     * Release the retained buffers in the free pool.
+     */
+    public void releaseRetained() {
+        lock.lock();
+        try {
+            free.clear();
+        } finally {
+            lock.unlock();
+        }
+    }
+
     @Override
     public long size() {
         lock.lock();
diff --git a/server-common/src/main/java/org/apache/kafka/server/fault/ProcessExitingFaultHandler.java b/server-common/src/main/java/org/apache/kafka/server/fault/ProcessExitingFaultHandler.java
index b7c0d241a2a..b67bcf3fa7d 100644
--- a/server-common/src/main/java/org/apache/kafka/server/fault/ProcessExitingFaultHandler.java
+++ b/server-common/src/main/java/org/apache/kafka/server/fault/ProcessExitingFaultHandler.java
@@ -29,6 +29,16 @@ import org.apache.kafka.common.utils.Exit;
 public class ProcessExitingFaultHandler implements FaultHandler {
     private static final Logger log = 
LoggerFactory.getLogger(ProcessExitingFaultHandler.class);
 
+    private final Runnable action;
+
+    public ProcessExitingFaultHandler() {
+        this.action = () -> { };
+    }
+
+    public ProcessExitingFaultHandler(Runnable action) {
+        this.action = action;
+    }
+
     @Override
     public RuntimeException handleFault(String failureMessage, Throwable 
cause) {
         if (cause == null) {
@@ -36,6 +46,11 @@ public class ProcessExitingFaultHandler implements FaultHandler {
         } else {
             log.error("Encountered fatal fault: {}", failureMessage, cause);
         }
+        try {
+            action.run();
+        } catch (Throwable e) {
+            log.error("Failed to run ProcessExitingFaultHandler action.", e);
+        }
         Exit.exit(1);
         return null;
     }

Reply via email to