dajac commented on code in PR #12845:
URL: https://github.com/apache/kafka/pull/12845#discussion_r1021208615


##########
core/src/test/scala/unit/kafka/server/KafkaApisTest.scala:
##########
@@ -2524,196 +2528,166 @@ class KafkaApisTest {
     assertEquals(MemoryRecords.EMPTY, 
FetchResponse.recordsOrFail(partitionData))
   }
 
-  @Test
-  def testJoinGroupProtocolsOrder(): Unit = {
-    val protocols = List(
-      ("first", "first".getBytes()),
-      ("second", "second".getBytes())
+  @ParameterizedTest
+  @ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP)
+  def testHandleJoinGroupRequest(version: Short): Unit = {
+    val joinGroupRequest = new JoinGroupRequestData()
+      .setGroupId("group")
+      .setMemberId("member")
+      .setProtocolType("consumer")
+      .setRebalanceTimeoutMs(1000)
+      .setSessionTimeoutMs(2000)
+
+    val requestChannelRequest = buildRequest(new 
JoinGroupRequest.Builder(joinGroupRequest).build(version))
+
+    val expectedRequestContext = new GroupCoordinatorRequestContext(
+      version,
+      requestChannelRequest.context.clientId,
+      requestChannelRequest.context.clientAddress,
+      RequestLocal.NoCaching.bufferSupplier
     )
 
-    val groupId = "group"
-    val memberId = "member1"
-    val protocolType = "consumer"
-    val rebalanceTimeoutMs = 10
-    val sessionTimeoutMs = 5
-    val capturedProtocols: ArgumentCaptor[List[(String, Array[Byte])]] = 
ArgumentCaptor.forClass(classOf[List[(String, Array[Byte])]])
+    val expectedJoinGroupRequest = new JoinGroupRequestData()
+      .setGroupId(joinGroupRequest.groupId)
+      .setMemberId(joinGroupRequest.memberId)
+      .setProtocolType(joinGroupRequest.protocolType)
+      .setRebalanceTimeoutMs(if (version >= 1) 
joinGroupRequest.rebalanceTimeoutMs else joinGroupRequest.sessionTimeoutMs)
+      .setSessionTimeoutMs(joinGroupRequest.sessionTimeoutMs)
 
-    createKafkaApis().handleJoinGroupRequest(
-      buildRequest(
-        new JoinGroupRequest.Builder(
-          new JoinGroupRequestData()
-            .setGroupId(groupId)
-            .setMemberId(memberId)
-            .setProtocolType(protocolType)
-            .setRebalanceTimeoutMs(rebalanceTimeoutMs)
-            .setSessionTimeoutMs(sessionTimeoutMs)
-            .setProtocols(new 
JoinGroupRequestData.JoinGroupRequestProtocolCollection(
-              protocols.map { case (name, protocol) => new 
JoinGroupRequestProtocol()
-                .setName(name).setMetadata(protocol)
-              }.iterator.asJava))
-        ).build()
-      ),
-      RequestLocal.withThreadConfinedCaching)
+    val future = new CompletableFuture[JoinGroupResponseData]()
+    when(newGroupCoordinator.joinGroup(
+      ArgumentMatchers.eq(expectedRequestContext),
+      ArgumentMatchers.eq(expectedJoinGroupRequest)
+    )).thenReturn(future)
 
-    verify(groupCoordinator).handleJoinGroup(
-      ArgumentMatchers.eq(groupId),
-      ArgumentMatchers.eq(memberId),
-      ArgumentMatchers.eq(None),
-      ArgumentMatchers.eq(true),
-      ArgumentMatchers.eq(true),
-      ArgumentMatchers.eq(clientId),
-      ArgumentMatchers.eq(InetAddress.getLocalHost.toString),
-      ArgumentMatchers.eq(rebalanceTimeoutMs),
-      ArgumentMatchers.eq(sessionTimeoutMs),
-      ArgumentMatchers.eq(protocolType),
-      capturedProtocols.capture(),
-      any(),
-      any(),
-      any()
+    createKafkaApis().handleJoinGroupRequest(
+      requestChannelRequest,
+      RequestLocal.NoCaching
     )
-    val capturedProtocolsList = capturedProtocols.getValue
-    assertEquals(protocols.size, capturedProtocolsList.size)
-    protocols.zip(capturedProtocolsList).foreach { case ((expectedName, 
expectedBytes), (name, bytes)) =>
-      assertEquals(expectedName, name)
-      assertArrayEquals(expectedBytes, bytes)
-    }
-  }
 
-  @Test
-  def testJoinGroupWhenAnErrorOccurs(): Unit = {
-    for (version <- ApiKeys.JOIN_GROUP.oldestVersion to 
ApiKeys.JOIN_GROUP.latestVersion) {
-      testJoinGroupWhenAnErrorOccurs(version.asInstanceOf[Short])
-    }
+    val expectedJoinGroupResponse = new JoinGroupResponseData()
+      .setMemberId("member")
+      .setGenerationId(0)
+      .setLeader("leader")
+      .setProtocolType("consumer")
+      .setProtocolName("range")
+
+    future.complete(expectedJoinGroupResponse)
+    val capturedResponse = verifyNoThrottling(requestChannelRequest)
+    val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse]
+    assertEquals(expectedJoinGroupResponse, response.data)
   }
 
-  def testJoinGroupWhenAnErrorOccurs(version: Short): Unit = {
-    reset(groupCoordinator, clientRequestQuotaManager, requestChannel, 
replicaManager)
+  @ParameterizedTest
+  @ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP)
+  def testJoinGroupProtocolTypeBackwardCompatibility(version: Short): Unit = {
+    val joinGroupRequest = new JoinGroupRequestData()
+      .setGroupId("group")
+      .setMemberId("member")
+      .setProtocolType("consumer")
+      .setRebalanceTimeoutMs(1000)
+      .setSessionTimeoutMs(2000)
+
+    val requestChannelRequest = buildRequest(new 
JoinGroupRequest.Builder(joinGroupRequest).build(version))
+
+    val expectedRequestContext = new GroupCoordinatorRequestContext(
+      version,
+      requestChannelRequest.context.clientId,
+      requestChannelRequest.context.clientAddress,
+      RequestLocal.NoCaching.bufferSupplier
+    )
 
-    val groupId = "group"
-    val memberId = "member1"
-    val protocolType = "consumer"
-    val rebalanceTimeoutMs = 10
-    val sessionTimeoutMs = 5
+    val expectedJoinGroupRequest = new JoinGroupRequestData()
+      .setGroupId(joinGroupRequest.groupId)
+      .setMemberId(joinGroupRequest.memberId)
+      .setProtocolType(joinGroupRequest.protocolType)
+      .setRebalanceTimeoutMs(if (version >= 1) 
joinGroupRequest.rebalanceTimeoutMs else joinGroupRequest.sessionTimeoutMs)
+      .setSessionTimeoutMs(joinGroupRequest.sessionTimeoutMs)
 
-    val capturedCallback: ArgumentCaptor[JoinGroupCallback] = 
ArgumentCaptor.forClass(classOf[JoinGroupCallback])
+    val future = new CompletableFuture[JoinGroupResponseData]()
+    when(newGroupCoordinator.joinGroup(
+      ArgumentMatchers.eq(expectedRequestContext),
+      ArgumentMatchers.eq(expectedJoinGroupRequest)
+    )).thenReturn(future)
 
-    val joinGroupRequest = new JoinGroupRequest.Builder(
-      new JoinGroupRequestData()
-        .setGroupId(groupId)
-        .setMemberId(memberId)
-        .setProtocolType(protocolType)
-        .setRebalanceTimeoutMs(rebalanceTimeoutMs)
-        .setSessionTimeoutMs(sessionTimeoutMs)
-    ).build(version)
+    createKafkaApis().handleJoinGroupRequest(
+      requestChannelRequest,
+      RequestLocal.NoCaching
+    )
 
-    val requestChannelRequest = buildRequest(joinGroupRequest)
+    val joinGroupResponse = new JoinGroupResponseData()
+      .setErrorCode(Errors.INCONSISTENT_GROUP_PROTOCOL.code)
+      .setMemberId("member")
 
-    createKafkaApis().handleJoinGroupRequest(requestChannelRequest, 
RequestLocal.withThreadConfinedCaching)
-
-    verify(groupCoordinator).handleJoinGroup(
-      ArgumentMatchers.eq(groupId),
-      ArgumentMatchers.eq(memberId),
-      ArgumentMatchers.eq(None),
-      ArgumentMatchers.eq(if (version >= 4) true else false),
-      ArgumentMatchers.eq(if (version >= 9) true else false),
-      ArgumentMatchers.eq(clientId),
-      ArgumentMatchers.eq(InetAddress.getLocalHost.toString),
-      ArgumentMatchers.eq(if (version >= 1) rebalanceTimeoutMs else 
sessionTimeoutMs),
-      ArgumentMatchers.eq(sessionTimeoutMs),
-      ArgumentMatchers.eq(protocolType),
-      ArgumentMatchers.eq(List.empty),
-      capturedCallback.capture(),
-      any(),
-      any()
-    )
-    capturedCallback.getValue.apply(JoinGroupResult(memberId, 
Errors.INCONSISTENT_GROUP_PROTOCOL))
+    val expectedJoinGroupResponse = new JoinGroupResponseData()
+      .setErrorCode(Errors.INCONSISTENT_GROUP_PROTOCOL.code)
+      .setMemberId("member")
+      .setProtocolType(if (version >= 7) null else GroupCoordinator.NoProtocol)
 
+    future.complete(joinGroupResponse)
     val capturedResponse = verifyNoThrottling(requestChannelRequest)
     val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse]
-
-    assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, response.error)
-    assertEquals(0, response.data.members.size)
-    assertEquals(memberId, response.data.memberId)
-    assertEquals(GroupCoordinator.NoGeneration, response.data.generationId)
-    assertEquals(GroupCoordinator.NoLeader, response.data.leader)
-    assertNull(response.data.protocolType)
-
-    if (version >= 7) {
-      assertNull(response.data.protocolName)
-    } else {
-      assertEquals(GroupCoordinator.NoProtocol, response.data.protocolName)
-    }
+    assertEquals(expectedJoinGroupResponse, response.data)
   }
 
   @Test
-  def testJoinGroupProtocolType(): Unit = {
-    for (version <- ApiKeys.JOIN_GROUP.oldestVersion to 
ApiKeys.JOIN_GROUP.latestVersion) {
-      testJoinGroupProtocolType(version.asInstanceOf[Short])
-    }
-  }
+  def testHandleJoinGroupRequestFutureFailed(): Unit = {
+    val joinGroupRequest = new JoinGroupRequestData()
+      .setGroupId("group")
+      .setMemberId("member")
+      .setProtocolType("consumer")
+      .setRebalanceTimeoutMs(1000)
+      .setSessionTimeoutMs(2000)
 
-  def testJoinGroupProtocolType(version: Short): Unit = {
-    reset(groupCoordinator, clientRequestQuotaManager, requestChannel, 
replicaManager)
+    val requestChannelRequest = buildRequest(new 
JoinGroupRequest.Builder(joinGroupRequest).build())
 
-    val groupId = "group"
-    val memberId = "member1"
-    val protocolType = "consumer"
-    val protocolName = "range"
-    val rebalanceTimeoutMs = 10
-    val sessionTimeoutMs = 5
+    val expectedRequestContext = new GroupCoordinatorRequestContext(
+      ApiKeys.JOIN_GROUP.latestVersion,
+      requestChannelRequest.context.clientId,
+      requestChannelRequest.context.clientAddress,
+      RequestLocal.NoCaching.bufferSupplier
+    )
 
-    val capturedCallback: ArgumentCaptor[JoinGroupCallback] = 
ArgumentCaptor.forClass(classOf[JoinGroupCallback])
+    val future = new CompletableFuture[JoinGroupResponseData]()
+    when(newGroupCoordinator.joinGroup(
+      ArgumentMatchers.eq(expectedRequestContext),
+      ArgumentMatchers.eq(joinGroupRequest)
+    )).thenReturn(future)
 
-    val joinGroupRequest = new JoinGroupRequest.Builder(
-      new JoinGroupRequestData()
-        .setGroupId(groupId)
-        .setMemberId(memberId)
-        .setProtocolType(protocolType)
-        .setRebalanceTimeoutMs(rebalanceTimeoutMs)
-        .setSessionTimeoutMs(sessionTimeoutMs)
-    ).build(version)
+    createKafkaApis().handleJoinGroupRequest(
+      requestChannelRequest,
+      RequestLocal.NoCaching
+    )
 
-    val requestChannelRequest = buildRequest(joinGroupRequest)
+    future.completeExceptionally(Errors.REQUEST_TIMED_OUT.exception)
+    val capturedResponse = verifyNoThrottling(requestChannelRequest)
+    val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse]
+    assertEquals(Errors.REQUEST_TIMED_OUT, response.error)
+  }
 
-    createKafkaApis().handleJoinGroupRequest(requestChannelRequest, 
RequestLocal.withThreadConfinedCaching)
+  @Test
+  def testHandleJoinGroupRequestAuthorizationFailed(): Unit = {

Review Comment:
   This is a new test. This code path was not previously tested in `KafkaApisTest`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to