This is an automated email from the ASF dual-hosted git repository.

ijuma pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 69e591db3a7 MINOR: Rewrite/Move KafkaNetworkChannel to the `raft` module (#14559)
69e591db3a7 is described below

commit 69e591db3a7329a8bb984f068596d8658a8618b3
Author: Ismael Juma <[email protected]>
AuthorDate: Mon Oct 16 20:10:31 2023 -0700

    MINOR: Rewrite/Move KafkaNetworkChannel to the `raft` module (#14559)
    
    This is now possible since `InterBrokerSendThread` was moved from `core` to `server-common`.
    Also rewrite/move `KafkaNetworkChannelTest`.
    
    The Scala version of `KafkaNetworkChannelTest` passed with the changes here (before I
    deleted it).
    
    Reviewers: Justine Olshan <[email protected]>, José Armando García Sancio <[email protected]>
---
 checkstyle/import-control.xml                      |   1 +
 .../scala/kafka/raft/KafkaNetworkChannel.scala     | 191 ------------
 core/src/main/scala/kafka/raft/RaftManager.scala   |   2 +-
 .../unit/kafka/raft/KafkaNetworkChannelTest.scala  | 316 --------------------
 .../org/apache/kafka/raft/KafkaNetworkChannel.java | 183 ++++++++++++
 .../java/org/apache/kafka/raft/NetworkChannel.java |   6 +-
 .../apache/kafka/raft/KafkaNetworkChannelTest.java | 323 +++++++++++++++++++++
 7 files changed, 510 insertions(+), 512 deletions(-)

diff --git a/checkstyle/import-control.xml b/checkstyle/import-control.xml
index 888e8a41ae8..42488c3225f 100644
--- a/checkstyle/import-control.xml
+++ b/checkstyle/import-control.xml
@@ -406,6 +406,7 @@
     <allow pkg="org.apache.kafka.common.protocol" />
     <allow pkg="org.apache.kafka.server.common" />
     <allow pkg="org.apache.kafka.server.common.serialization" />
+    <allow pkg="org.apache.kafka.server.util" />
     <allow pkg="org.apache.kafka.test"/>
     <allow pkg="com.fasterxml.jackson" />
     <allow pkg="net.jqwik"/>
diff --git a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala 
b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
deleted file mode 100644
index 7c00961d1dc..00000000000
--- a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
+++ /dev/null
@@ -1,191 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.raft
-
-import kafka.utils.Logging
-import org.apache.kafka.clients.{ClientResponse, KafkaClient}
-import org.apache.kafka.common.Node
-import org.apache.kafka.common.message._
-import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
-import org.apache.kafka.common.requests._
-import org.apache.kafka.common.utils.Time
-import org.apache.kafka.raft.RaftConfig.InetAddressSpec
-import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, 
RaftUtil}
-import org.apache.kafka.server.util.{InterBrokerSendThread, 
RequestAndCompletionHandler}
-
-import java.util
-import java.util.concurrent.ConcurrentLinkedQueue
-import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.mutable
-
-object KafkaNetworkChannel {
-
-  private[raft] def buildRequest(requestData: ApiMessage): 
AbstractRequest.Builder[_ <: AbstractRequest] = {
-    requestData match {
-      case voteRequest: VoteRequestData =>
-        new VoteRequest.Builder(voteRequest)
-      case beginEpochRequest: BeginQuorumEpochRequestData =>
-        new BeginQuorumEpochRequest.Builder(beginEpochRequest)
-      case endEpochRequest: EndQuorumEpochRequestData =>
-        new EndQuorumEpochRequest.Builder(endEpochRequest)
-      case fetchRequest: FetchRequestData =>
-        new FetchRequest.SimpleBuilder(fetchRequest)
-      case fetchSnapshotRequest: FetchSnapshotRequestData =>
-        new FetchSnapshotRequest.Builder(fetchSnapshotRequest)
-      case _ =>
-        throw new IllegalArgumentException(s"Unexpected type for requestData: 
$requestData")
-    }
-  }
-
-}
-
-private[raft] class RaftSendThread(
-  name: String,
-  networkClient: KafkaClient,
-  requestTimeoutMs: Int,
-  time: Time,
-  isInterruptible: Boolean = true
-) extends InterBrokerSendThread(
-  name,
-  networkClient,
-  requestTimeoutMs,
-  time,
-  isInterruptible
-) {
-  private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]()
-
-  def generateRequests(): util.Collection[RequestAndCompletionHandler] = {
-    val list =  new util.ArrayList[RequestAndCompletionHandler]()
-    while (true) {
-      val request = queue.poll()
-      if (request == null) {
-        return list
-      } else {
-        list.add(request)
-      }
-    }
-    list
-  }
-
-  def sendRequest(request: RequestAndCompletionHandler): Unit = {
-    queue.add(request)
-    wakeup()
-  }
-
-}
-
-
-class KafkaNetworkChannel(
-  time: Time,
-  client: KafkaClient,
-  requestTimeoutMs: Int,
-  threadNamePrefix: String
-) extends NetworkChannel with Logging {
-  import KafkaNetworkChannel._
-
-  type ResponseHandler = AbstractResponse => Unit
-
-  private val correlationIdCounter = new AtomicInteger(0)
-  private val endpoints = mutable.HashMap.empty[Int, Node]
-
-  private val requestThread = new RaftSendThread(
-    name = threadNamePrefix + "-outbound-request-thread",
-    networkClient = client,
-    requestTimeoutMs = requestTimeoutMs,
-    time = time,
-    isInterruptible = false
-  )
-
-  override def send(request: RaftRequest.Outbound): Unit = {
-    def completeFuture(message: ApiMessage): Unit = {
-      val response = new RaftResponse.Inbound(
-        request.correlationId,
-        message,
-        request.destinationId
-      )
-      request.completion.complete(response)
-    }
-
-    def onComplete(clientResponse: ClientResponse): Unit = {
-      val response = if (clientResponse.versionMismatch != null) {
-        error(s"Request $request failed due to unsupported version error",
-          clientResponse.versionMismatch)
-        errorResponse(request.data, Errors.UNSUPPORTED_VERSION)
-      } else if (clientResponse.authenticationException != null) {
-        // For now we treat authentication errors as retriable. We use the
-        // `NETWORK_EXCEPTION` error code for lack of a good alternative.
-        // Note that `NodeToControllerChannelManager` will still log the
-        // authentication errors so that users have a chance to fix the 
problem.
-        error(s"Request $request failed due to authentication error",
-          clientResponse.authenticationException)
-        errorResponse(request.data, Errors.NETWORK_EXCEPTION)
-      } else if (clientResponse.wasDisconnected()) {
-        errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)
-      } else {
-        clientResponse.responseBody.data
-      }
-      completeFuture(response)
-    }
-
-    endpoints.get(request.destinationId) match {
-      case Some(node) =>
-        requestThread.sendRequest(new RequestAndCompletionHandler(
-          request.createdTimeMs,
-          node,
-          buildRequest(request.data),
-          onComplete
-        ))
-
-      case None =>
-        completeFuture(errorResponse(request.data, 
Errors.BROKER_NOT_AVAILABLE))
-    }
-  }
-
-  // Visible for testing
-  private[raft] def pollOnce(): Unit = {
-    requestThread.doWork()
-  }
-
-  override def newCorrelationId(): Int = {
-    correlationIdCounter.getAndIncrement()
-  }
-
-  private def errorResponse(
-    request: ApiMessage,
-    error: Errors
-  ): ApiMessage = {
-    val apiKey = ApiKeys.forId(request.apiKey)
-    RaftUtil.errorResponse(apiKey, error)
-  }
-
-  override def updateEndpoint(id: Int, spec: InetAddressSpec): Unit = {
-    val node = new Node(id, spec.address.getHostString, spec.address.getPort)
-    endpoints.put(id, node)
-  }
-
-  def start(): Unit = {
-    requestThread.start()
-  }
-
-  def initiateShutdown(): Unit = {
-    requestThread.initiateShutdown()
-  }
-
-  override def close(): Unit = {
-    requestThread.shutdown()
-  }
-}
diff --git a/core/src/main/scala/kafka/raft/RaftManager.scala 
b/core/src/main/scala/kafka/raft/RaftManager.scala
index 020477d5a42..f9311d20d95 100644
--- a/core/src/main/scala/kafka/raft/RaftManager.scala
+++ b/core/src/main/scala/kafka/raft/RaftManager.scala
@@ -42,7 +42,7 @@ import org.apache.kafka.common.security.JaasContext
 import org.apache.kafka.common.security.auth.SecurityProtocol
 import org.apache.kafka.common.utils.{LogContext, Time}
 import org.apache.kafka.raft.RaftConfig.{AddressSpec, InetAddressSpec, 
NON_ROUTABLE_ADDRESS, UnknownAddressSpec}
-import org.apache.kafka.raft.{FileBasedStateStore, KafkaRaftClient, 
LeaderAndEpoch, RaftClient, RaftConfig, RaftRequest, ReplicatedLog}
+import org.apache.kafka.raft.{FileBasedStateStore, KafkaNetworkChannel, 
KafkaRaftClient, LeaderAndEpoch, RaftClient, RaftConfig, RaftRequest, 
ReplicatedLog}
 import org.apache.kafka.server.common.serialization.RecordSerde
 import org.apache.kafka.server.util.{KafkaScheduler, ShutdownableThread}
 import org.apache.kafka.server.fault.FaultHandler
diff --git a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala 
b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
deleted file mode 100644
index af230f66553..00000000000
--- a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
+++ /dev/null
@@ -1,316 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.raft
-
-import java.net.InetSocketAddress
-import java.util
-import java.util.Collections
-import org.apache.kafka.clients.MockClient.MockMetadataUpdater
-import org.apache.kafka.clients.{MockClient, NodeApiVersions}
-import org.apache.kafka.common.message.FetchRequestData.ReplicaState
-import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, 
EndQuorumEpochResponseData, FetchResponseData, VoteResponseData}
-import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
-import org.apache.kafka.common.requests.{AbstractResponse, 
ApiVersionsResponse, BeginQuorumEpochRequest, BeginQuorumEpochResponse, 
EndQuorumEpochRequest, EndQuorumEpochResponse, FetchRequest, FetchResponse, 
VoteRequest, VoteResponse}
-import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource
-import org.apache.kafka.common.utils.{MockTime, Time}
-import org.apache.kafka.common.{Node, TopicPartition, Uuid}
-import org.apache.kafka.raft.RaftConfig.InetAddressSpec
-import org.apache.kafka.raft.{RaftRequest, RaftUtil}
-import org.junit.jupiter.api.Assertions._
-import org.junit.jupiter.api.{BeforeEach, Test}
-import org.junit.jupiter.params.ParameterizedTest
-
-import scala.jdk.CollectionConverters._
-
-class KafkaNetworkChannelTest {
-  import KafkaNetworkChannelTest._
-
-  private val clusterId = "clusterId"
-  private val requestTimeoutMs = 30000
-  private val time = new MockTime()
-  private val client = new MockClient(time, new StubMetadataUpdater)
-  private val topicPartition = new TopicPartition("topic", 0)
-  private val topicId = Uuid.randomUuid()
-  private val channel = new KafkaNetworkChannel(time, client, 
requestTimeoutMs, threadNamePrefix = "test-raft")
-
-  @BeforeEach
-  def setupSupportedApis(): Unit = {
-    val supportedApis = RaftApis.map(ApiVersionsResponse.toApiVersion)
-    client.setNodeApiVersions(NodeApiVersions.create(supportedApis.asJava))
-  }
-
-  @Test
-  def testSendToUnknownDestination(): Unit = {
-    val destinationId = 2
-    assertBrokerNotAvailable(destinationId)
-  }
-
-  @Test
-  def testSendToBlackedOutDestination(): Unit = {
-    val destinationId = 2
-    val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
-    channel.updateEndpoint(destinationId, new InetAddressSpec(
-      new InetSocketAddress(destinationNode.host, destinationNode.port)))
-    client.backoff(destinationNode, 500)
-    assertBrokerNotAvailable(destinationId)
-  }
-
-  @Test
-  def testWakeupClientOnSend(): Unit = {
-    val destinationId = 2
-    val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
-    channel.updateEndpoint(destinationId, new InetAddressSpec(
-      new InetSocketAddress(destinationNode.host, destinationNode.port)))
-
-    client.enableBlockingUntilWakeup(1)
-
-    val ioThread = new Thread() {
-      override def run(): Unit = {
-        // Block in poll until we get the expected wakeup
-        channel.pollOnce()
-
-        // Poll a second time to send request and receive response
-        channel.pollOnce()
-      }
-    }
-
-    val response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, 
Errors.INVALID_REQUEST))
-    client.prepareResponseFrom(response, destinationNode, false)
-
-    ioThread.start()
-    val request = sendTestRequest(ApiKeys.FETCH, destinationId)
-
-    ioThread.join()
-    assertResponseCompleted(request, Errors.INVALID_REQUEST)
-  }
-
-  @Test
-  def testSendAndDisconnect(): Unit = {
-    val destinationId = 2
-    val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
-    channel.updateEndpoint(destinationId, new InetAddressSpec(
-      new InetSocketAddress(destinationNode.host, destinationNode.port)))
-
-    for (apiKey <- RaftApis) {
-      val response = buildResponse(buildTestErrorResponse(apiKey, 
Errors.INVALID_REQUEST))
-      client.prepareResponseFrom(response, destinationNode, true)
-      sendAndAssertErrorResponse(apiKey, destinationId, 
Errors.BROKER_NOT_AVAILABLE)
-    }
-  }
-
-  @Test
-  def testSendAndFailAuthentication(): Unit = {
-    val destinationId = 2
-    val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
-    channel.updateEndpoint(destinationId, new InetAddressSpec(
-      new InetSocketAddress(destinationNode.host, destinationNode.port)))
-
-    for (apiKey <- RaftApis) {
-      client.createPendingAuthenticationError(destinationNode, 100)
-      sendAndAssertErrorResponse(apiKey, destinationId, 
Errors.NETWORK_EXCEPTION)
-
-      // reset to clear backoff time
-      client.reset()
-    }
-  }
-
-  private def assertBrokerNotAvailable(destinationId: Int): Unit = {
-    for (apiKey <- RaftApis) {
-      sendAndAssertErrorResponse(apiKey, destinationId, 
Errors.BROKER_NOT_AVAILABLE)
-    }
-  }
-
-  @Test
-  def testSendAndReceiveOutboundRequest(): Unit = {
-    val destinationId = 2
-    val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
-    channel.updateEndpoint(destinationId, new InetAddressSpec(
-      new InetSocketAddress(destinationNode.host, destinationNode.port)))
-
-    for (apiKey <- RaftApis) {
-      val expectedError = Errors.INVALID_REQUEST
-      val response = buildResponse(buildTestErrorResponse(apiKey, 
expectedError))
-      client.prepareResponseFrom(response, destinationNode)
-      sendAndAssertErrorResponse(apiKey, destinationId, expectedError)
-    }
-  }
-
-  @Test
-  def testUnsupportedVersionError(): Unit = {
-    val destinationId = 2
-    val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
-    channel.updateEndpoint(destinationId, new InetAddressSpec(
-      new InetSocketAddress(destinationNode.host, destinationNode.port)))
-
-    for (apiKey <- RaftApis) {
-      client.prepareUnsupportedVersionResponse(request => request.apiKey == 
apiKey)
-      sendAndAssertErrorResponse(apiKey, destinationId, 
Errors.UNSUPPORTED_VERSION)
-    }
-  }
-
-  @ParameterizedTest
-  @ApiKeyVersionsSource(apiKey = ApiKeys.FETCH)
-  def testFetchRequestDowngrade(version: Short): Unit = {
-    val destinationId = 2
-    val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
-    channel.updateEndpoint(destinationId, new InetAddressSpec(
-      new InetSocketAddress(destinationNode.host, destinationNode.port)))
-    sendTestRequest(ApiKeys.FETCH, destinationId)
-    channel.pollOnce()
-
-    assertEquals(1, client.requests.size)
-    val request = client.requests.peek.requestBuilder.build(version)
-
-    if (version < 15) {
-      assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == 1)
-      
assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == -1)
-    } else {
-      assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == -1)
-      
assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == 1)
-    }
-  }
-
-  private def sendTestRequest(
-    apiKey: ApiKeys,
-    destinationId: Int,
-  ): RaftRequest.Outbound = {
-    val correlationId = channel.newCorrelationId()
-    val createdTimeMs = time.milliseconds()
-    val apiRequest = buildTestRequest(apiKey)
-    val request = new RaftRequest.Outbound(correlationId, apiRequest, 
destinationId, createdTimeMs)
-    channel.send(request)
-    request
-  }
-
-  private def assertResponseCompleted(
-    request: RaftRequest.Outbound,
-    expectedError: Errors
-  ): Unit = {
-    assertTrue(request.completion.isDone)
-
-    val response = request.completion.get()
-    assertEquals(request.destinationId, response.sourceId)
-    assertEquals(request.correlationId, response.correlationId)
-    assertEquals(request.data.apiKey, response.data.apiKey)
-    assertEquals(expectedError, extractError(response.data))
-  }
-
-  private def sendAndAssertErrorResponse(
-    apiKey: ApiKeys,
-    destinationId: Int,
-    error: Errors
-  ): Unit = {
-    val request = sendTestRequest(apiKey, destinationId)
-    channel.pollOnce()
-    assertResponseCompleted(request, error)
-  }
-
-  private def buildTestRequest(key: ApiKeys): ApiMessage = {
-    val leaderEpoch = 5
-    val leaderId = 1
-    key match {
-      case ApiKeys.BEGIN_QUORUM_EPOCH =>
-        BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, 
leaderEpoch, leaderId)
-
-      case ApiKeys.END_QUORUM_EPOCH =>
-        EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, 
leaderId,
-          leaderEpoch, Collections.singletonList(2))
-
-      case ApiKeys.VOTE =>
-        val lastEpoch = 4
-        VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, 
leaderId, lastEpoch, 329)
-
-      case ApiKeys.FETCH =>
-        val request = RaftUtil.singletonFetchRequest(topicPartition, topicId, 
fetchPartition => {
-          fetchPartition
-            .setCurrentLeaderEpoch(5)
-            .setFetchOffset(333)
-            .setLastFetchedEpoch(5)
-        })
-        request.setReplicaState(new ReplicaState().setReplicaId(1))
-
-      case _ =>
-        throw new AssertionError(s"Unexpected api $key")
-    }
-  }
-
-  private def buildTestErrorResponse(key: ApiKeys, error: Errors): ApiMessage 
= {
-    key match {
-      case ApiKeys.BEGIN_QUORUM_EPOCH =>
-        new BeginQuorumEpochResponseData()
-          .setErrorCode(error.code)
-
-      case ApiKeys.END_QUORUM_EPOCH =>
-        new EndQuorumEpochResponseData()
-          .setErrorCode(error.code)
-
-      case ApiKeys.VOTE =>
-        VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 
5, false);
-
-      case ApiKeys.FETCH =>
-        new FetchResponseData()
-          .setErrorCode(error.code)
-
-      case _ =>
-        throw new AssertionError(s"Unexpected api $key")
-    }
-  }
-
-  private def extractError(response: ApiMessage): Errors = {
-    val code = (response: @unchecked) match {
-      case res: BeginQuorumEpochResponseData => res.errorCode
-      case res: EndQuorumEpochResponseData => res.errorCode
-      case res: FetchResponseData => res.errorCode
-      case res: VoteResponseData => res.errorCode
-    }
-    Errors.forCode(code)
-  }
-
-
-  def buildResponse(responseData: ApiMessage): AbstractResponse = {
-    responseData match {
-      case voteResponse: VoteResponseData =>
-        new VoteResponse(voteResponse)
-      case beginEpochResponse: BeginQuorumEpochResponseData =>
-        new BeginQuorumEpochResponse(beginEpochResponse)
-      case endEpochResponse: EndQuorumEpochResponseData =>
-        new EndQuorumEpochResponse(endEpochResponse)
-      case fetchResponse: FetchResponseData =>
-        new FetchResponse(fetchResponse)
-      case _ =>
-        throw new IllegalArgumentException(s"Unexpected type for responseData: 
$responseData")
-    }
-  }
-
-}
-
-object KafkaNetworkChannelTest {
-  val RaftApis = Seq(
-    ApiKeys.VOTE,
-    ApiKeys.BEGIN_QUORUM_EPOCH,
-    ApiKeys.END_QUORUM_EPOCH,
-    ApiKeys.FETCH,
-  )
-
-  private class StubMetadataUpdater extends MockMetadataUpdater {
-    override def fetchNodes(): util.List[Node] = Collections.emptyList()
-
-    override def isUpdateNeeded: Boolean = false
-
-    override def update(time: Time, update: MockClient.MetadataUpdate): Unit = 
{}
-  }
-}
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java 
b/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java
new file mode 100644
index 00000000000..2c0dd25d439
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import org.apache.kafka.clients.ClientResponse;
+import org.apache.kafka.clients.KafkaClient;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.message.BeginQuorumEpochRequestData;
+import org.apache.kafka.common.message.EndQuorumEpochRequestData;
+import org.apache.kafka.common.message.FetchRequestData;
+import org.apache.kafka.common.message.FetchSnapshotRequestData;
+import org.apache.kafka.common.message.VoteRequestData;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ApiMessage;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.AbstractRequest;
+import org.apache.kafka.common.requests.BeginQuorumEpochRequest;
+import org.apache.kafka.common.requests.EndQuorumEpochRequest;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchSnapshotRequest;
+import org.apache.kafka.common.requests.VoteRequest;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.server.util.InterBrokerSendThread;
+import org.apache.kafka.server.util.RequestAndCompletionHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class KafkaNetworkChannel implements NetworkChannel {
+
+    static class SendThread extends InterBrokerSendThread {
+
+        private Queue<RequestAndCompletionHandler> queue = new 
ConcurrentLinkedQueue<>();
+
+        public SendThread(String name, KafkaClient networkClient, int 
requestTimeoutMs, Time time, boolean isInterruptible) {
+            super(name, networkClient, requestTimeoutMs, time, 
isInterruptible);
+        }
+
+        @Override
+        public Collection<RequestAndCompletionHandler> generateRequests() {
+            List<RequestAndCompletionHandler> list =  new ArrayList<>();
+            while (true) {
+                RequestAndCompletionHandler request = queue.poll();
+                if (request == null) {
+                    return list;
+                } else {
+                    list.add(request);
+                }
+            }
+        }
+
+        public void sendRequest(RequestAndCompletionHandler request) {
+            queue.add(request);
+            wakeup();
+        }
+    }
+
+    private static final Logger log = 
LoggerFactory.getLogger(KafkaNetworkChannel.class);
+
+    private final SendThread requestThread;
+
+    private final AtomicInteger correlationIdCounter = new AtomicInteger(0);
+    private final Map<Integer, Node> endpoints = new HashMap<>();
+
+    public KafkaNetworkChannel(Time time, KafkaClient client, int 
requestTimeoutMs, String threadNamePrefix) {
+        this.requestThread = new SendThread(
+            threadNamePrefix + "-outbound-request-thread",
+            client,
+            requestTimeoutMs,
+            time,
+            false
+        );
+    }
+
+    @Override
+    public int newCorrelationId() {
+        return correlationIdCounter.getAndIncrement();
+    }
+
+    @Override
+    public void send(RaftRequest.Outbound request) {
+        Node node = endpoints.get(request.destinationId());
+        if (node != null) {
+            requestThread.sendRequest(new RequestAndCompletionHandler(
+                request.createdTimeMs,
+                node,
+                buildRequest(request.data),
+                response -> sendOnComplete(request, response)
+            ));
+        } else
+            sendCompleteFuture(request, errorResponse(request.data, 
Errors.BROKER_NOT_AVAILABLE));
+    }
+
+    private void sendCompleteFuture(RaftRequest.Outbound request, ApiMessage 
message) {
+        RaftResponse.Inbound response = new RaftResponse.Inbound(
+                request.correlationId,
+                message,
+                request.destinationId()
+        );
+        request.completion.complete(response);
+    }
+
+    private void sendOnComplete(RaftRequest.Outbound request, ClientResponse 
clientResponse) {
+        ApiMessage response;
+        if (clientResponse.versionMismatch() != null) {
+            log.error("Request {} failed due to unsupported version error", 
request, clientResponse.versionMismatch());
+            response = errorResponse(request.data, Errors.UNSUPPORTED_VERSION);
+        } else if (clientResponse.authenticationException() != null) {
+            // For now we treat authentication errors as retriable. We use the
+            // `NETWORK_EXCEPTION` error code for lack of a good alternative.
+            // Note that `NodeToControllerChannelManager` will still log the
+            // authentication errors so that users have a chance to fix the 
problem.
+            log.error("Request {} failed due to authentication error", 
request, clientResponse.authenticationException());
+            response = errorResponse(request.data, Errors.NETWORK_EXCEPTION);
+        } else if (clientResponse.wasDisconnected()) {
+            response = errorResponse(request.data, 
Errors.BROKER_NOT_AVAILABLE);
+        } else {
+            response = clientResponse.responseBody().data();
+        }
+        sendCompleteFuture(request, response);
+    }
+
+    private ApiMessage errorResponse(ApiMessage request, Errors error) {
+        ApiKeys apiKey = ApiKeys.forId(request.apiKey());
+        return RaftUtil.errorResponse(apiKey, error);
+    }
+
+    @Override
+    public void updateEndpoint(int id, RaftConfig.InetAddressSpec spec) {
+        Node node = new Node(id, spec.address.getHostString(), 
spec.address.getPort());
+        endpoints.put(id, node);
+    }
+
+    public void start() {
+        requestThread.start();
+    }
+
+    @Override
+    public void close() throws InterruptedException {
+        requestThread.shutdown();
+    }
+
+    // Visible for testing
+    public void pollOnce() {
+        requestThread.doWork();
+    }
+
+    static AbstractRequest.Builder<? extends AbstractRequest> 
buildRequest(ApiMessage requestData) {
+        if (requestData instanceof VoteRequestData)
+            return new VoteRequest.Builder((VoteRequestData) requestData);
+        if (requestData instanceof BeginQuorumEpochRequestData)
+            return new 
BeginQuorumEpochRequest.Builder((BeginQuorumEpochRequestData) requestData);
+        if (requestData instanceof EndQuorumEpochRequestData)
+            return new 
EndQuorumEpochRequest.Builder((EndQuorumEpochRequestData) requestData);
+        if (requestData instanceof FetchRequestData)
+            return new FetchRequest.SimpleBuilder((FetchRequestData) 
requestData);
+        if (requestData instanceof FetchSnapshotRequestData)
+            return new FetchSnapshotRequest.Builder((FetchSnapshotRequestData) 
requestData);
+        throw new IllegalArgumentException("Unexpected type for requestData: " 
+ requestData);
+    }
+}
diff --git a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java 
b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
index e3482e56751..e527adf6f9b 100644
--- a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
+++ b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
@@ -16,13 +16,11 @@
  */
 package org.apache.kafka.raft;
 
-import java.io.Closeable;
-
 /**
  * A simple network interface with few assumptions. We do not assume ordering
  * of requests or even that every outbound request will receive a response.
  */
-public interface NetworkChannel extends Closeable {
+public interface NetworkChannel extends AutoCloseable {
 
     /**
      * Generate a new and unique correlationId for a new request to be sent.
@@ -41,6 +39,6 @@ public interface NetworkChannel extends Closeable {
      */
     void updateEndpoint(int id, RaftConfig.InetAddressSpec address);
 
-    default void close() {}
+    default void close() throws InterruptedException {}
 
 }
diff --git 
a/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java 
b/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java
new file mode 100644
index 00000000000..3a1d097fc7a
--- /dev/null
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import org.apache.kafka.clients.MockClient;
+import org.apache.kafka.clients.NodeApiVersions;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.message.ApiVersionsResponseData;
+import org.apache.kafka.common.message.BeginQuorumEpochResponseData;
+import org.apache.kafka.common.message.EndQuorumEpochResponseData;
+import org.apache.kafka.common.message.FetchRequestData;
+import org.apache.kafka.common.message.FetchResponseData;
+import org.apache.kafka.common.message.VoteResponseData;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ApiMessage;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.AbstractRequest;
+import org.apache.kafka.common.requests.AbstractResponse;
+import org.apache.kafka.common.requests.ApiVersionsResponse;
+import org.apache.kafka.common.requests.BeginQuorumEpochRequest;
+import org.apache.kafka.common.requests.BeginQuorumEpochResponse;
+import org.apache.kafka.common.requests.EndQuorumEpochRequest;
+import org.apache.kafka.common.requests.EndQuorumEpochResponse;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchResponse;
+import org.apache.kafka.common.requests.VoteRequest;
+import org.apache.kafka.common.requests.VoteResponse;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+
+import java.net.InetSocketAddress;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+import static java.util.Arrays.asList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Tests for {@link KafkaNetworkChannel}. A {@link MockClient} stands in for the network so that
+ * responses, disconnects, authentication failures and unsupported versions can be injected for
+ * each of the Raft APIs in {@link #RAFT_APIS}.
+ */
+public class KafkaNetworkChannelTest {
+
+    /**
+     * Metadata updater that reports no nodes and never asks for an update.
+     */
+    private static class StubMetadataUpdater implements MockClient.MockMetadataUpdater {
+
+        @Override
+        public List<Node> fetchNodes() {
+            return Collections.emptyList();
+        }
+
+        @Override
+        public boolean isUpdateNeeded() {
+            return false;
+        }
+
+        @Override
+        public void update(Time time, MockClient.MetadataUpdate update) { }
+    }
+
+    /** Raft APIs advertised as supported by the mock client and exercised by the looping tests. */
+    private static final List<ApiKeys> RAFT_APIS = asList(
+        ApiKeys.VOTE,
+        ApiKeys.BEGIN_QUORUM_EPOCH,
+        ApiKeys.END_QUORUM_EPOCH,
+        ApiKeys.FETCH
+    );
+
+    /** Upper bound on how long a test waits for its background I/O thread to finish. */
+    private static final long IO_THREAD_JOIN_TIMEOUT_MS = 30_000L;
+
+    /** Replica id placed in the replica state of every FETCH request built by this test. */
+    private static final int FETCH_REPLICA_ID = 1;
+
+    private final String clusterId = "clusterId";
+    private final int requestTimeoutMs = 30000;
+    private final Time time = new MockTime();
+    private final MockClient client = new MockClient(time, new StubMetadataUpdater());
+    private final TopicPartition topicPartition = new TopicPartition("topic", 0);
+    private final Uuid topicId = Uuid.randomUuid();
+    private final KafkaNetworkChannel channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, "test-raft");
+
+    /** Advertises every API in {@link #RAFT_APIS} as supported by the mock client. */
+    @BeforeEach
+    public void setupSupportedApis() {
+        List<ApiVersionsResponseData.ApiVersion> supportedApis = RAFT_APIS.stream()
+            .map(ApiVersionsResponse::toApiVersion)
+            .collect(Collectors.toList());
+        client.setNodeApiVersions(NodeApiVersions.create(supportedApis));
+    }
+
+    @Test
+    public void testSendToUnknownDestination() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        assertBrokerNotAvailable(destinationId);
+    }
+
+    @Test
+    public void testSendToBlackedOutDestination() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        Node destinationNode = registerDestination(destinationId);
+        client.backoff(destinationNode, 500);
+        assertBrokerNotAvailable(destinationId);
+    }
+
+    @Test
+    public void testWakeupClientOnSend() throws InterruptedException, ExecutionException {
+        int destinationId = 2;
+        Node destinationNode = registerDestination(destinationId);
+
+        client.enableBlockingUntilWakeup(1);
+
+        AtomicReference<Throwable> ioThreadFailure = new AtomicReference<>();
+        Thread ioThread = new Thread(() -> {
+            // Block in poll until we get the expected wakeup
+            channel.pollOnce();
+
+            // Poll a second time to send request and receive response
+            channel.pollOnce();
+        });
+        // Capture failures so they fail this test instead of only being printed by the default handler.
+        ioThread.setUncaughtExceptionHandler((thread, error) -> ioThreadFailure.set(error));
+        // A thread stuck in poll must not keep the JVM alive after the test gives up on it.
+        ioThread.setDaemon(true);
+
+        AbstractResponse response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST));
+        client.prepareResponseFrom(response, destinationNode, false);
+
+        ioThread.start();
+        RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationId);
+
+        // Bounded join so that a stuck I/O thread fails the test instead of hanging it.
+        ioThread.join(IO_THREAD_JOIN_TIMEOUT_MS);
+        assertFalse(ioThread.isAlive(), "I/O thread did not finish within " + IO_THREAD_JOIN_TIMEOUT_MS + " ms");
+        assertNull(ioThreadFailure.get(), "I/O thread failed: " + ioThreadFailure.get());
+        assertResponseCompleted(request, Errors.INVALID_REQUEST);
+    }
+
+    @Test
+    public void testSendAndDisconnect() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        Node destinationNode = registerDestination(destinationId);
+
+        for (ApiKeys apiKey : RAFT_APIS) {
+            AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST));
+            // Prepared as a disconnect: the caller must see BROKER_NOT_AVAILABLE, not the prepared error.
+            client.prepareResponseFrom(response, destinationNode, true);
+            sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE);
+        }
+    }
+
+    @Test
+    public void testSendAndFailAuthentication() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        Node destinationNode = registerDestination(destinationId);
+
+        for (ApiKeys apiKey : RAFT_APIS) {
+            client.createPendingAuthenticationError(destinationNode, 100);
+            sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION);
+
+            // reset to clear backoff time
+            client.reset();
+        }
+    }
+
+    @Test
+    public void testSendAndReceiveOutboundRequest() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        Node destinationNode = registerDestination(destinationId);
+
+        for (ApiKeys apiKey : RAFT_APIS) {
+            Errors expectedError = Errors.INVALID_REQUEST;
+            AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, expectedError));
+            client.prepareResponseFrom(response, destinationNode);
+            sendAndAssertErrorResponse(apiKey, destinationId, expectedError);
+        }
+    }
+
+    @Test
+    public void testUnsupportedVersionError() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        registerDestination(destinationId);
+
+        for (ApiKeys apiKey : RAFT_APIS) {
+            client.prepareUnsupportedVersionResponse(request -> request.apiKey() == apiKey);
+            sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION);
+        }
+    }
+
+    /**
+     * Builds the queued FETCH request at each version supplied by {@link ApiKeyVersionsSource} and checks
+     * where the replica id lands: the top-level {@code replicaId} before version 15, and
+     * {@code replicaState.replicaId} from version 15 on. The other field must hold -1.
+     */
+    @ParameterizedTest
+    @ApiKeyVersionsSource(apiKey = ApiKeys.FETCH)
+    public void testFetchRequestDowngrade(short version) {
+        int destinationId = 2;
+        registerDestination(destinationId);
+        sendTestRequest(ApiKeys.FETCH, destinationId);
+        channel.pollOnce();
+
+        assertEquals(1, client.requests().size());
+        AbstractRequest request = client.requests().peek().requestBuilder().build(version);
+        FetchRequestData data = ((FetchRequest) request).data();
+
+        if (version < 15) {
+            assertEquals(FETCH_REPLICA_ID, data.replicaId());
+            assertEquals(-1, data.replicaState().replicaId());
+        } else {
+            assertEquals(-1, data.replicaId());
+            assertEquals(FETCH_REPLICA_ID, data.replicaState().replicaId());
+        }
+    }
+
+    /**
+     * Registers {@code destinationId} with the channel at 127.0.0.1:9092.
+     *
+     * @return the node for that address, for use with the mock client
+     */
+    private Node registerDestination(int destinationId) {
+        Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+        channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+            new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+        return destinationNode;
+    }
+
+    /** Sends one request per API in {@link #RAFT_APIS} and expects BROKER_NOT_AVAILABLE for each. */
+    private void assertBrokerNotAvailable(int destinationId) throws ExecutionException, InterruptedException {
+        for (ApiKeys apiKey : RAFT_APIS) {
+            sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE);
+        }
+    }
+
+    /** Builds a test request for {@code apiKey}, hands it to the channel without polling, and returns it. */
+    private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, int destinationId) {
+        int correlationId = channel.newCorrelationId();
+        long createdTimeMs = time.milliseconds();
+        ApiMessage apiRequest = buildTestRequest(apiKey);
+        RaftRequest.Outbound request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs);
+        channel.send(request);
+        return request;
+    }
+
+    /**
+     * Asserts that {@code request} has completed with a response from its destination that echoes its
+     * correlation id and API key and carries {@code expectedError}.
+     */
+    private void assertResponseCompleted(RaftRequest.Outbound request, Errors expectedError) throws ExecutionException, InterruptedException {
+        assertTrue(request.completion.isDone());
+
+        RaftResponse.Inbound response = request.completion.get();
+        assertEquals(request.destinationId(), response.sourceId());
+        assertEquals(request.correlationId, response.correlationId);
+        assertEquals(request.data.apiKey(), response.data.apiKey());
+        assertEquals(expectedError, extractError(response.data));
+    }
+
+    /** Sends a test request, runs a single poll of the channel and checks the resulting error. */
+    private void sendAndAssertErrorResponse(ApiKeys apiKey, int destinationId, Errors error) throws ExecutionException, InterruptedException {
+        RaftRequest.Outbound request = sendTestRequest(apiKey, destinationId);
+        channel.pollOnce();
+        assertResponseCompleted(request, error);
+    }
+
+    /** Builds a request body for {@code key}; keys outside {@link #RAFT_APIS} throw {@link AssertionError}. */
+    private ApiMessage buildTestRequest(ApiKeys key) {
+        int leaderEpoch = 5;
+        int leaderId = 1;
+        switch (key) {
+            case BEGIN_QUORUM_EPOCH:
+                return BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId);
+            case END_QUORUM_EPOCH:
+                return EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, leaderEpoch,
+                    Collections.singletonList(2));
+            case VOTE:
+                int lastEpoch = 4;
+                return VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329);
+            case FETCH:
+                FetchRequestData request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition -> {
+                    fetchPartition
+                        .setCurrentLeaderEpoch(5)
+                        .setFetchOffset(333)
+                        .setLastFetchedEpoch(5);
+                });
+                request.setReplicaState(new FetchRequestData.ReplicaState().setReplicaId(FETCH_REPLICA_ID));
+                return request;
+            default:
+                throw new AssertionError("Unexpected api " + key);
+        }
+    }
+
+    /** Builds a response body for {@code key} whose top-level error code is {@code error}. */
+    private ApiMessage buildTestErrorResponse(ApiKeys key, Errors error) {
+        switch (key) {
+            case BEGIN_QUORUM_EPOCH:
+                return new BeginQuorumEpochResponseData().setErrorCode(error.code());
+            case END_QUORUM_EPOCH:
+                return new EndQuorumEpochResponseData().setErrorCode(error.code());
+            case VOTE:
+                return VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false);
+            case FETCH:
+                return new FetchResponseData().setErrorCode(error.code());
+            default:
+                throw new AssertionError("Unexpected api " + key);
+        }
+    }
+
+    /** Reads the top-level error code of a Raft response body. */
+    private Errors extractError(ApiMessage response) {
+        short code;
+        if (response instanceof BeginQuorumEpochResponseData)
+            code = ((BeginQuorumEpochResponseData) response).errorCode();
+        else if (response instanceof EndQuorumEpochResponseData)
+            code = ((EndQuorumEpochResponseData) response).errorCode();
+        else if (response instanceof FetchResponseData)
+            code = ((FetchResponseData) response).errorCode();
+        else if (response instanceof VoteResponseData)
+            code = ((VoteResponseData) response).errorCode();
+        else
+            throw new IllegalArgumentException("Unexpected type for responseData: " + response);
+        return Errors.forCode(code);
+    }
+
+    /** Wraps a Raft response body in its matching {@link AbstractResponse} subclass. */
+    private AbstractResponse buildResponse(ApiMessage responseData) {
+        if (responseData instanceof VoteResponseData)
+            return new VoteResponse((VoteResponseData) responseData);
+        if (responseData instanceof BeginQuorumEpochResponseData)
+            return new BeginQuorumEpochResponse((BeginQuorumEpochResponseData) responseData);
+        if (responseData instanceof EndQuorumEpochResponseData)
+            return new EndQuorumEpochResponse((EndQuorumEpochResponseData) responseData);
+        if (responseData instanceof FetchResponseData)
+            return new FetchResponse((FetchResponseData) responseData);
+        throw new IllegalArgumentException("Unexpected type for responseData: " + responseData);
+    }
+}


Reply via email to