This is an automated email from the ASF dual-hosted git repository. cmccabe pushed a commit to branch 3.3 in repository https://gitbox.apache.org/repos/asf/kafka.git
commit 480e97914e1146ba79ae883c4e987f5de20702cb Author: David Arthur <[email protected]> AuthorDate: Tue Jul 26 19:08:59 2022 -0400 KAFKA-13166 Fix missing ControllerApis error handling (#12403) Makes all ControllerApis request handlers return a `CompletableFuture[Unit]`. Also adds an additional completion stage which ensures we capture errors thrown during response building. Reviewed-by: Colin P. McCabe <[email protected]> --- core/src/main/scala/kafka/server/AclApis.scala | 32 ++++-- .../main/scala/kafka/server/ControllerApis.scala | 121 ++++++++++++--------- .../unit/kafka/server/ControllerApisTest.scala | 30 +++++ 3 files changed, 124 insertions(+), 59 deletions(-) diff --git a/core/src/main/scala/kafka/server/AclApis.scala b/core/src/main/scala/kafka/server/AclApis.scala index 97b685bc0aa..485cafeca20 100644 --- a/core/src/main/scala/kafka/server/AclApis.scala +++ b/core/src/main/scala/kafka/server/AclApis.scala @@ -24,14 +24,16 @@ import org.apache.kafka.common.acl.AclOperation._ import org.apache.kafka.common.acl.AclBinding import org.apache.kafka.common.errors._ import org.apache.kafka.common.message.CreateAclsResponseData.AclCreationResult +import org.apache.kafka.common.message.DeleteAclsResponseData.DeleteAclsFilterResult import org.apache.kafka.common.message._ import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.requests._ import org.apache.kafka.common.resource.Resource.CLUSTER_NAME import org.apache.kafka.common.resource.ResourceType import org.apache.kafka.server.authorizer._ -import java.util +import java.util +import java.util.concurrent.CompletableFuture import scala.collection.mutable.ArrayBuffer import scala.collection.mutable import scala.compat.java8.OptionConverters._ @@ -53,7 +55,7 @@ class AclApis(authHelper: AuthHelper, def close(): Unit = alterAclsPurgatory.shutdown() - def handleDescribeAcls(request: RequestChannel.Request): Unit = { + def handleDescribeAcls(request: RequestChannel.Request): 
CompletableFuture[Unit] = { authHelper.authorizeClusterOperation(request, DESCRIBE) val describeAclsRequest = request.body[DescribeAclsRequest] authorizer match { @@ -74,9 +76,10 @@ class AclApis(authHelper: AuthHelper, .setResources(DescribeAclsResponse.aclsResources(returnedAcls)), describeAclsRequest.version)) } + CompletableFuture.completedFuture[Unit](()) } - def handleCreateAcls(request: RequestChannel.Request): Unit = { + def handleCreateAcls(request: RequestChannel.Request): CompletableFuture[Unit] = { authHelper.authorizeClusterOperation(request, ALTER) val createAclsRequest = request.body[CreateAclsRequest] @@ -84,6 +87,7 @@ class AclApis(authHelper: AuthHelper, case None => requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => createAclsRequest.getErrorResponse(requestThrottleMs, new SecurityDisabledException("No Authorizer is configured."))) + CompletableFuture.completedFuture[Unit](()) case Some(auth) => val allBindings = createAclsRequest.aclCreations.asScala.map(CreateAclsRequest.aclBinding) val errorResults = mutable.Map[AclBinding, AclCreateResult]() @@ -103,6 +107,7 @@ class AclApis(authHelper: AuthHelper, validBindings += acl } + val future = new CompletableFuture[util.List[AclCreationResult]]() val createResults = auth.createAcls(request.context, validBindings.asJava).asScala.map(_.toCompletableFuture) def sendResponseCallback(): Unit = { @@ -117,17 +122,20 @@ class AclApis(authHelper: AuthHelper, } creationResult } + future.complete(aclCreationResults.asJava) + } + alterAclsPurgatory.tryCompleteElseWatch(config.connectionsMaxIdleMs, createResults, sendResponseCallback) + + future.thenApply[Unit] { aclCreationResults => requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => new CreateAclsResponse(new CreateAclsResponseData() .setThrottleTimeMs(requestThrottleMs) - .setResults(aclCreationResults.asJava))) + .setResults(aclCreationResults))) } - - alterAclsPurgatory.tryCompleteElseWatch(config.connectionsMaxIdleMs, 
createResults, sendResponseCallback) } } - def handleDeleteAcls(request: RequestChannel.Request): Unit = { + def handleDeleteAcls(request: RequestChannel.Request): CompletableFuture[Unit] = { authHelper.authorizeClusterOperation(request, ALTER) val deleteAclsRequest = request.body[DeleteAclsRequest] authorizer match { @@ -135,13 +143,20 @@ class AclApis(authHelper: AuthHelper, requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => deleteAclsRequest.getErrorResponse(requestThrottleMs, new SecurityDisabledException("No Authorizer is configured."))) + CompletableFuture.completedFuture[Unit](()) case Some(auth) => + val future = new CompletableFuture[util.List[DeleteAclsFilterResult]]() val deleteResults = auth.deleteAcls(request.context, deleteAclsRequest.filters) .asScala.map(_.toCompletableFuture).toList def sendResponseCallback(): Unit = { val filterResults = deleteResults.map(_.get).map(DeleteAclsResponse.filterResult).asJava + future.complete(filterResults) + } + + alterAclsPurgatory.tryCompleteElseWatch(config.connectionsMaxIdleMs, deleteResults, sendResponseCallback) + future.thenApply[Unit] { filterResults => requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => new DeleteAclsResponse( new DeleteAclsResponseData() @@ -149,7 +164,6 @@ class AclApis(authHelper: AuthHelper, .setFilterResults(filterResults), deleteAclsRequest.version)) } - alterAclsPurgatory.tryCompleteElseWatch(config.connectionsMaxIdleMs, deleteResults, sendResponseCallback) } } -} + } diff --git a/core/src/main/scala/kafka/server/ControllerApis.scala b/core/src/main/scala/kafka/server/ControllerApis.scala index 74bc4dd4067..efb6a36c3db 100644 --- a/core/src/main/scala/kafka/server/ControllerApis.scala +++ b/core/src/main/scala/kafka/server/ControllerApis.scala @@ -20,7 +20,7 @@ package kafka.server import java.util import java.util.{Collections, OptionalLong} import java.util.Map.Entry -import java.util.concurrent.{CompletableFuture, ExecutionException} +import 
java.util.concurrent.{CompletableFuture, CompletionException} import kafka.network.RequestChannel import kafka.raft.RaftManager import kafka.server.QuotaFactory.QuotaManagers @@ -78,7 +78,7 @@ class ControllerApis(val requestChannel: RequestChannel, override def handle(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { try { - request.header.apiKey match { + val handlerFuture: CompletableFuture[Unit] = request.header.apiKey match { case ApiKeys.FETCH => handleFetch(request) case ApiKeys.FETCH_SNAPSHOT => handleFetchSnapshot(request) case ApiKeys.CREATE_TOPICS => handleCreateTopics(request) @@ -109,10 +109,24 @@ class ControllerApis(val requestChannel: RequestChannel, case ApiKeys.UPDATE_FEATURES => handleUpdateFeatures(request) case _ => throw new ApiException(s"Unsupported ApiKey ${request.context.header.apiKey}") } + + // This catches exceptions in the future and subsequent completion stages returned by the request handlers. + handlerFuture.whenComplete { (_, exception) => + if (exception != null) { + // CompletionException does not include the stack frames in its "cause" exception, so we need to + // log the original exception here + error(s"Unexpected error handling request ${request.requestDesc(true)} " + + s"with context ${request.context}", exception) + + // For building the correct error response, we do need to send the "cause" exception + val actualException = if (exception.isInstanceOf[CompletionException]) exception.getCause else exception + requestHelper.handleError(request, actualException) + } + } } catch { case e: FatalExitError => throw e - case e: Throwable => { - val t = if (e.isInstanceOf[ExecutionException]) e.getCause else e + case t: Throwable => { + // This catches exceptions in the blocking parts of the request handlers error(s"Unexpected error handling request ${request.requestDesc(true)} " + s"with context ${request.context}", t) requestHelper.handleError(request, t) @@ -125,38 +139,41 @@ class ControllerApis(val 
requestChannel: RequestChannel, } } - def handleEnvelopeRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + def handleEnvelopeRequest(request: RequestChannel.Request, requestLocal: RequestLocal): CompletableFuture[Unit] = { if (!authHelper.authorize(request.context, CLUSTER_ACTION, CLUSTER, CLUSTER_NAME)) { requestHelper.sendErrorResponseMaybeThrottle(request, new ClusterAuthorizationException( s"Principal ${request.context.principal} does not have required CLUSTER_ACTION for envelope")) } else { EnvelopeUtils.handleEnvelopeRequest(request, requestChannel.metrics, handle(_, requestLocal)) } + CompletableFuture.completedFuture[Unit](()) } - def handleSaslHandshakeRequest(request: RequestChannel.Request): Unit = { + def handleSaslHandshakeRequest(request: RequestChannel.Request): CompletableFuture[Unit] = { val responseData = new SaslHandshakeResponseData().setErrorCode(ILLEGAL_SASL_STATE.code) requestHelper.sendResponseMaybeThrottle(request, _ => new SaslHandshakeResponse(responseData)) + CompletableFuture.completedFuture[Unit](()) } - def handleSaslAuthenticateRequest(request: RequestChannel.Request): Unit = { + def handleSaslAuthenticateRequest(request: RequestChannel.Request): CompletableFuture[Unit] = { val responseData = new SaslAuthenticateResponseData() .setErrorCode(ILLEGAL_SASL_STATE.code) .setErrorMessage("SaslAuthenticate request received after successful authentication") requestHelper.sendResponseMaybeThrottle(request, _ => new SaslAuthenticateResponse(responseData)) + CompletableFuture.completedFuture[Unit](()) } - def handleFetch(request: RequestChannel.Request): Unit = { + def handleFetch(request: RequestChannel.Request): CompletableFuture[Unit] = { authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) handleRaftRequest(request, response => new FetchResponse(response.asInstanceOf[FetchResponseData])) } - def handleFetchSnapshot(request: RequestChannel.Request): Unit = { + def handleFetchSnapshot(request: 
RequestChannel.Request): CompletableFuture[Unit] = { authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) handleRaftRequest(request, response => new FetchSnapshotResponse(response.asInstanceOf[FetchSnapshotResponseData])) } - def handleDeleteTopics(request: RequestChannel.Request): Unit = { + def handleDeleteTopics(request: RequestChannel.Request): CompletableFuture[Unit] = { val deleteTopicsRequest = request.body[DeleteTopicsRequest] val context = new ControllerRequestContext(request.context.header.data, request.context.principal, requestTimeoutMsToDeadlineNs(time, deleteTopicsRequest.data.timeoutMs)) @@ -166,7 +183,7 @@ class ControllerApis(val requestChannel: RequestChannel, authHelper.authorize(request.context, DELETE, CLUSTER, CLUSTER_NAME, logIfDenied = false), names => authHelper.filterByAuthorized(request.context, DESCRIBE, TOPIC, names)(n => n), names => authHelper.filterByAuthorized(request.context, DELETE, TOPIC, names)(n => n)) - future.whenComplete { (results, exception) => + future.handle[Unit] { (results, exception) => requestHelper.sendResponseMaybeThrottle(request, throttleTimeMs => { if (exception != null) { deleteTopicsRequest.getErrorResponse(throttleTimeMs, exception) @@ -320,7 +337,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleCreateTopics(request: RequestChannel.Request): Unit = { + def handleCreateTopics(request: RequestChannel.Request): CompletableFuture[Unit] = { val createTopicsRequest = request.body[CreateTopicsRequest] val context = new ControllerRequestContext(request.context.header.data, request.context.principal, requestTimeoutMsToDeadlineNs(time, createTopicsRequest.data.timeoutMs)) @@ -330,7 +347,7 @@ class ControllerApis(val requestChannel: RequestChannel, names => authHelper.filterByAuthorized(request.context, CREATE, TOPIC, names)(identity), names => authHelper.filterByAuthorized(request.context, DESCRIBE_CONFIGS, TOPIC, names, logIfDenied = false)(identity)) - future.whenComplete { 
(result, exception) => + future.handle[Unit] { (result, exception) => requestHelper.sendResponseMaybeThrottle(request, throttleTimeMs => { if (exception != null) { createTopicsRequest.getErrorResponse(throttleTimeMs, exception) @@ -392,7 +409,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleApiVersionsRequest(request: RequestChannel.Request): Unit = { + def handleApiVersionsRequest(request: RequestChannel.Request): CompletableFuture[Unit] = { // Note that broker returns its full list of supported ApiKeys and versions regardless of current // authentication state (e.g., before SASL authentication on an SASL listener, do note that no // Kafka protocol requests may take place on an SSL listener before the SSL handshake is finished). @@ -410,6 +427,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } requestHelper.sendResponseMaybeThrottle(request, createResponseCallback) + CompletableFuture.completedFuture[Unit](()) } def authorizeAlterResource(requestContext: RequestContext, @@ -431,7 +449,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleLegacyAlterConfigs(request: RequestChannel.Request): Unit = { + def handleLegacyAlterConfigs(request: RequestChannel.Request): CompletableFuture[Unit] = { val response = new AlterConfigsResponseData() val alterConfigsRequest = request.body[AlterConfigsRequest] val context = new ControllerRequestContext(request.context.header.data, request.context.principal, OptionalLong.empty()) @@ -474,7 +492,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } controller.legacyAlterConfigs(context, configChanges, alterConfigsRequest.data.validateOnly) - .whenComplete { (controllerResults, exception) => + .handle[Unit] { (controllerResults, exception) => if (exception != null) { requestHelper.handleError(request, exception) } else { @@ -490,33 +508,33 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleVote(request: RequestChannel.Request): 
Unit = { + def handleVote(request: RequestChannel.Request): CompletableFuture[Unit] = { authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) handleRaftRequest(request, response => new VoteResponse(response.asInstanceOf[VoteResponseData])) } - def handleBeginQuorumEpoch(request: RequestChannel.Request): Unit = { + def handleBeginQuorumEpoch(request: RequestChannel.Request): CompletableFuture[Unit] = { authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) handleRaftRequest(request, response => new BeginQuorumEpochResponse(response.asInstanceOf[BeginQuorumEpochResponseData])) } - def handleEndQuorumEpoch(request: RequestChannel.Request): Unit = { + def handleEndQuorumEpoch(request: RequestChannel.Request): CompletableFuture[Unit] = { authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) handleRaftRequest(request, response => new EndQuorumEpochResponse(response.asInstanceOf[EndQuorumEpochResponseData])) } - def handleDescribeQuorum(request: RequestChannel.Request): Unit = { + def handleDescribeQuorum(request: RequestChannel.Request): CompletableFuture[Unit] = { authHelper.authorizeClusterOperation(request, DESCRIBE) handleRaftRequest(request, response => new DescribeQuorumResponse(response.asInstanceOf[DescribeQuorumResponseData])) } - def handleElectLeaders(request: RequestChannel.Request): Unit = { + def handleElectLeaders(request: RequestChannel.Request): CompletableFuture[Unit] = { authHelper.authorizeClusterOperation(request, ALTER) val electLeadersRequest = request.body[ElectLeadersRequest] val context = new ControllerRequestContext(request.context.header.data, request.context.principal, requestTimeoutMsToDeadlineNs(time, electLeadersRequest.data.timeoutMs)) val future = controller.electLeaders(context, electLeadersRequest.data) - future.whenComplete { (responseData, exception) => + future.handle[Unit] { (responseData, exception) => if (exception != null) { requestHelper.sendResponseMaybeThrottle(request, throttleMs => { 
electLeadersRequest.getErrorResponse(throttleMs, exception) @@ -529,13 +547,13 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleAlterPartitionRequest(request: RequestChannel.Request): Unit = { + def handleAlterPartitionRequest(request: RequestChannel.Request): CompletableFuture[Unit] = { val alterPartitionRequest = request.body[AlterPartitionRequest] val context = new ControllerRequestContext(request.context.header.data, request.context.principal, OptionalLong.empty()) authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) val future = controller.alterPartition(context, alterPartitionRequest.data) - future.whenComplete { (result, exception) => + future.handle[Unit] { (result, exception) => val response = if (exception != null) { alterPartitionRequest.getErrorResponse(exception) } else { @@ -545,7 +563,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleBrokerHeartBeatRequest(request: RequestChannel.Request): Unit = { + def handleBrokerHeartBeatRequest(request: RequestChannel.Request): CompletableFuture[Unit] = { val heartbeatRequest = request.body[BrokerHeartbeatRequest] authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) val context = new ControllerRequestContext(request.context.header.data, request.context.principal, @@ -572,7 +590,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleUnregisterBroker(request: RequestChannel.Request): Unit = { + def handleUnregisterBroker(request: RequestChannel.Request): CompletableFuture[Unit] = { val decommissionRequest = request.body[UnregisterBrokerRequest] authHelper.authorizeClusterOperation(request, ALTER) val context = new ControllerRequestContext(request.context.header.data, request.context.principal, @@ -595,7 +613,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleBrokerRegistration(request: RequestChannel.Request): Unit = { + def handleBrokerRegistration(request: RequestChannel.Request): 
CompletableFuture[Unit] = { val registrationRequest = request.body[BrokerRegistrationRequest] authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) val context = new ControllerRequestContext(request.context.header.data, request.context.principal, @@ -622,11 +640,10 @@ class ControllerApis(val requestChannel: RequestChannel, } private def handleRaftRequest(request: RequestChannel.Request, - buildResponse: ApiMessage => AbstractResponse): Unit = { + buildResponse: ApiMessage => AbstractResponse): CompletableFuture[Unit] = { val requestBody = request.body[AbstractRequest] val future = raftManager.handleRequest(request.header, requestBody.data, time.milliseconds()) - - future.whenComplete { (responseData, exception) => + future.handle[Unit] { (responseData, exception) => val response = if (exception != null) { requestBody.getErrorResponse(exception) } else { @@ -636,13 +653,13 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleAlterClientQuotas(request: RequestChannel.Request): Unit = { + def handleAlterClientQuotas(request: RequestChannel.Request): CompletableFuture[Unit] = { val quotaRequest = request.body[AlterClientQuotasRequest] authHelper.authorizeClusterOperation(request, ALTER_CONFIGS) val context = new ControllerRequestContext(request.context.header.data, request.context.principal, OptionalLong.empty()) controller.alterClientQuotas(context, quotaRequest.entries, quotaRequest.validateOnly) - .whenComplete { (results, exception) => + .handle[Unit] { (results, exception) => if (exception != null) { requestHelper.handleError(request, exception) } else { @@ -652,7 +669,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleIncrementalAlterConfigs(request: RequestChannel.Request): Unit = { + def handleIncrementalAlterConfigs(request: RequestChannel.Request): CompletableFuture[Unit] = { val response = new IncrementalAlterConfigsResponseData() val alterConfigsRequest = request.body[IncrementalAlterConfigsRequest] 
val context = new ControllerRequestContext(request.context.header.data, request.context.principal, @@ -700,7 +717,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } controller.incrementalAlterConfigs(context, configChanges, alterConfigsRequest.data.validateOnly) - .whenComplete { (controllerResults, exception) => + .handle[Unit] { (controllerResults, exception) => if (exception != null) { requestHelper.handleError(request, exception) } else { @@ -716,7 +733,7 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleCreatePartitions(request: RequestChannel.Request): Unit = { + def handleCreatePartitions(request: RequestChannel.Request): CompletableFuture[Unit] = { def filterAlterAuthorizedTopics(topics: Iterable[String]): Set[String] = { authHelper.filterByAuthorized(request.context, ALTER, TOPIC, topics)(n => n) } @@ -726,7 +743,7 @@ class ControllerApis(val requestChannel: RequestChannel, val future = createPartitions(context, createPartitionsRequest.data(), filterAlterAuthorizedTopics) - future.whenComplete { (responses, exception) => + future.handle[Unit] { (responses, exception) => if (exception != null) { requestHelper.handleError(request, exception) } else { @@ -778,33 +795,37 @@ class ControllerApis(val requestChannel: RequestChannel, } } - def handleAlterPartitionReassignments(request: RequestChannel.Request): Unit = { + def handleAlterPartitionReassignments(request: RequestChannel.Request): CompletableFuture[Unit] = { val alterRequest = request.body[AlterPartitionReassignmentsRequest] authHelper.authorizeClusterOperation(request, ALTER) val context = new ControllerRequestContext(request.context.header.data, request.context.principal, requestTimeoutMsToDeadlineNs(time, alterRequest.data.timeoutMs)) - val response = controller.alterPartitionReassignments(context, alterRequest.data).get() - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - new 
AlterPartitionReassignmentsResponse(response.setThrottleTimeMs(requestThrottleMs))) + controller.alterPartitionReassignments(context, alterRequest.data) + .thenApply[Unit] { response => + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new AlterPartitionReassignmentsResponse(response.setThrottleTimeMs(requestThrottleMs))) + } } - def handleListPartitionReassignments(request: RequestChannel.Request): Unit = { + def handleListPartitionReassignments(request: RequestChannel.Request): CompletableFuture[Unit] = { val listRequest = request.body[ListPartitionReassignmentsRequest] authHelper.authorizeClusterOperation(request, DESCRIBE) val context = new ControllerRequestContext(request.context.header.data, request.context.principal, OptionalLong.empty()) - val response = controller.listPartitionReassignments(context, listRequest.data).get() - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - new ListPartitionReassignmentsResponse(response.setThrottleTimeMs(requestThrottleMs))) + controller.listPartitionReassignments(context, listRequest.data) + .thenApply[Unit] { response => + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new ListPartitionReassignmentsResponse(response.setThrottleTimeMs(requestThrottleMs))) + } } - def handleAllocateProducerIdsRequest(request: RequestChannel.Request): Unit = { + def handleAllocateProducerIdsRequest(request: RequestChannel.Request): CompletableFuture[Unit] = { val allocatedProducerIdsRequest = request.body[AllocateProducerIdsRequest] authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) val context = new ControllerRequestContext(request.context.header.data, request.context.principal, OptionalLong.empty()) controller.allocateProducerIds(context, allocatedProducerIdsRequest.data) - .whenComplete((results, exception) => { + .handle[Unit] { (results, exception) => if (exception != null) { requestHelper.handleError(request, exception) } else { @@ -813,22 +834,22 @@ 
class ControllerApis(val requestChannel: RequestChannel, new AllocateProducerIdsResponse(results) }) } - }) + } } - def handleUpdateFeatures(request: RequestChannel.Request): Unit = { + def handleUpdateFeatures(request: RequestChannel.Request): CompletableFuture[Unit] = { val updateFeaturesRequest = request.body[UpdateFeaturesRequest] authHelper.authorizeClusterOperation(request, ALTER) val context = new ControllerRequestContext(request.context.header.data, request.context.principal, OptionalLong.empty()) controller.updateFeatures(context, updateFeaturesRequest.data) - .whenComplete((response, exception) => { + .handle[Unit] { (response, exception) => if (exception != null) { requestHelper.handleError(request, exception) } else { requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => new UpdateFeaturesResponse(response.setThrottleTimeMs(requestThrottleMs))) } - }) + } } } diff --git a/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala b/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala index 86a0c854705..0fc96114527 100644 --- a/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala +++ b/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala @@ -63,6 +63,7 @@ import java.net.InetAddress import java.util import java.util.Collections.singletonList import java.util.concurrent.{CompletableFuture, ExecutionException, TimeUnit} +import java.util.concurrent.atomic.AtomicReference import java.util.{Collections, Properties} import scala.annotation.nowarn import scala.jdk.CollectionConverters._ @@ -902,6 +903,35 @@ class ControllerApisTest { } } + @Test + def testCompletableFutureExceptions(): Unit = { + // This test simulates an error in a completable future as we return from the controller. We need to ensure + // that any exception thrown in the completion phase is properly captured and translated to an error response. 
+ val request = buildRequest(new FetchRequest(new FetchRequestData(), 12)) + val response = new FetchResponseData() + val responseFuture = new CompletableFuture[ApiMessage]() + val errorResponseFuture = new AtomicReference[AbstractResponse]() + when(raftManager.handleRequest(any(), any(), any())).thenReturn(responseFuture) + when(requestChannel.sendResponse(any(), any(), any())).thenAnswer { _ => + // Simulate an encoding failure in the initial fetch response + throw new UnsupportedVersionException("Something went wrong") + }.thenAnswer { invocation => + val resp = invocation.getArgument(1, classOf[AbstractResponse]) + errorResponseFuture.set(resp) + } + + // Calling handle does not block since we do not call get() in ControllerApis + createControllerApis(None, + new MockController.Builder().build()).handle(request, null) + + // When we complete this future, the completion stages will fire (including the error handler in ControllerApis#handle) + responseFuture.complete(response) + + // Now we should get an error response with UNSUPPORTED_VERSION + val errorResponse = errorResponseFuture.get() + assertEquals(1, errorResponse.errorCounts().getOrDefault(Errors.UNSUPPORTED_VERSION, 0)) + } + @AfterEach def tearDown(): Unit = { quotas.shutdown()
