hachikuji commented on a change in pull request #9732: URL: https://github.com/apache/kafka/pull/9732#discussion_r541219622
########## File path: core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala ########## @@ -68,179 +51,93 @@ object KafkaNetworkChannel { } } - private[raft] def responseData(response: AbstractResponse): ApiMessage = { - response match { - case voteResponse: VoteResponse => voteResponse.data - case beginEpochResponse: BeginQuorumEpochResponse => beginEpochResponse.data - case endEpochResponse: EndQuorumEpochResponse => endEpochResponse.data - case fetchResponse: FetchResponse[_] => fetchResponse.data - case _ => throw new IllegalArgumentException(s"Unexpected type for response: $response") - } - } - - private[raft] def requestData(request: AbstractRequest): ApiMessage = { - request match { - case voteRequest: VoteRequest => voteRequest.data - case beginEpochRequest: BeginQuorumEpochRequest => beginEpochRequest.data - case endEpochRequest: EndQuorumEpochRequest => endEpochRequest.data - case fetchRequest: FetchRequest => fetchRequest.data - case _ => throw new IllegalArgumentException(s"Unexpected type for request: $request") - } - } - } -class KafkaNetworkChannel(time: Time, - client: KafkaClient, - clientId: String, - retryBackoffMs: Int, - requestTimeoutMs: Int) extends NetworkChannel with Logging { +class KafkaNetworkChannel( + time: Time, + client: KafkaClient, + requestTimeoutMs: Int +) extends NetworkChannel with Logging { import KafkaNetworkChannel._ type ResponseHandler = AbstractResponse => Unit private val correlationIdCounter = new AtomicInteger(0) - private val pendingInbound = mutable.Map.empty[Long, ResponseHandler] - private val undelivered = new ArrayBlockingQueue[RaftMessage](10) - private val pendingOutbound = new ArrayBlockingQueue[RaftRequest.Outbound](10) private val endpoints = mutable.HashMap.empty[Int, Node] - override def newCorrelationId(): Int = correlationIdCounter.getAndIncrement() - - private def buildClientRequest(req: RaftRequest.Outbound): ClientRequest = { - val destination = req.destinationId.toString - val request = buildRequest(req.data) - val correlationId = req.correlationId - val createdTimeMs = req.createdTimeMs - new ClientRequest(destination, request, correlationId, clientId, createdTimeMs, true, - requestTimeoutMs, null) - } - - override def send(message: RaftMessage): Unit = { - message match { - case request: RaftRequest.Outbound => - if (!pendingOutbound.offer(request)) - throw new KafkaException("Pending outbound queue is full") - - case response: RaftResponse.Outbound => - pendingInbound.remove(response.correlationId).foreach { onResponseReceived: ResponseHandler => - onResponseReceived(buildResponse(response.data)) - } - case _ => - throw new IllegalArgumentException("Unhandled message type " + message) + private val requestThread = new InterBrokerSendThread( + name = "raft-outbound-request-thread", + networkClient = client, + requestTimeoutMs = requestTimeoutMs, + time = time, + isInterruptible = false + ) + + override def send(request: RaftRequest.Outbound): Unit = { + def completeFuture(message: ApiMessage): Unit = { + val response = new RaftResponse.Inbound( + request.correlationId, + message, + request.destinationId + ) + request.completion.complete(response) } - } - private def sendOutboundRequests(currentTimeMs: Long): Unit = { - while (!pendingOutbound.isEmpty) { - val request = pendingOutbound.peek() - endpoints.get(request.destinationId) match { - case Some(node) => - if (client.connectionFailed(node)) { - pendingOutbound.poll() - val apiKey = ApiKeys.forId(request.data.apiKey) - val disconnectResponse = RaftUtil.errorResponse(apiKey, Errors.BROKER_NOT_AVAILABLE) - val success = undelivered.offer(new RaftResponse.Inbound( - request.correlationId, disconnectResponse, request.destinationId)) - if (!success) { - throw new KafkaException("Undelivered queue is full") - } - - // Make sure to reset the connection state - client.ready(node, currentTimeMs) - } else if (client.ready(node, currentTimeMs)) { - pendingOutbound.poll() - val clientRequest = buildClientRequest(request) - client.send(clientRequest, currentTimeMs) - } else { - // We will retry this request on the next poll - return - } - - case None => - pendingOutbound.poll() - val apiKey = ApiKeys.forId(request.data.apiKey) - val responseData = RaftUtil.errorResponse(apiKey, Errors.BROKER_NOT_AVAILABLE) - val response = new RaftResponse.Inbound(request.correlationId, responseData, request.destinationId) - if (!undelivered.offer(response)) - throw new KafkaException("Undelivered queue is full") + def onComplete(clientResponse: ClientResponse): Unit = { + val response = if (clientResponse.authenticationException != null) { + errorResponse(request.data, Errors.CLUSTER_AUTHORIZATION_FAILED) + } else if (clientResponse.wasDisconnected()) { + errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE) + } else { + clientResponse.responseBody.data } + completeFuture(response) } - } - - def getConnectionInfo(nodeId: Int): Node = { - if (!endpoints.contains(nodeId)) - null - else - endpoints(nodeId) - } - - def allConnections(): Set[Node] = { - endpoints.values.toSet - } - private def buildInboundRaftResponse(response: ClientResponse): RaftResponse.Inbound = { - val header = response.requestHeader() - val data = if (response.authenticationException != null) { - RaftUtil.errorResponse(header.apiKey, Errors.CLUSTER_AUTHORIZATION_FAILED) - } else if (response.wasDisconnected) { - RaftUtil.errorResponse(header.apiKey, Errors.BROKER_NOT_AVAILABLE) - } else { - responseData(response.responseBody) - } - new RaftResponse.Inbound(header.correlationId, data, response.destination.toInt) - } + endpoints.get(request.destinationId) match { + case Some(node) => + requestThread.sendRequest(RequestAndCompletionHandler( + destination = node, + request = buildRequest(request.data), + handler = onComplete + )) - private def pollInboundResponses(timeoutMs: Long, inboundMessages: util.List[RaftMessage]): Unit = { - val responses = client.poll(timeoutMs, time.milliseconds()) - for (response <- responses.asScala) { - inboundMessages.add(buildInboundRaftResponse(response)) + case None => + completeFuture(errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)) } } - private def drainInboundRequests(inboundMessages: util.List[RaftMessage]): Unit = { - undelivered.drainTo(inboundMessages) + def pollOnce(): Unit = { Review comment: Yeah, let me make that clearer. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org