squah-confluent commented on code in PR #22264:
URL: https://github.com/apache/kafka/pull/22264#discussion_r3264779900


##########
clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocol.java:
##########
@@ -116,8 +116,8 @@ public static Subscription deserializeSubscription(final ByteBuffer buffer, shor
                 ownedPartitions,
                 data.generationId(),
                 data.rackId() == null || data.rackId().isEmpty() ? Optional.empty() : Optional.of(data.rackId()));
-        } catch (BufferUnderflowException e) {
-            throw new SchemaException("Buffer underflow while parsing consumer protocol's subscription", e);
+        } catch (RuntimeException e) {
+            throw new SchemaException("Malformed consumer protocol's subscription", e);

Review Comment:
   typo
   ```suggestion
               throw new SchemaException("Malformed consumer protocol subscription", e);
   ```
   
   same for the other 3 SchemaExceptions thrown in this file
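The change widens the catch from `BufferUnderflowException` to `RuntimeException`, presumably because malformed input can also fail in ways other than running out of bytes, for instance an invalid length prefix. A minimal sketch of the wrap-and-rethrow pattern with the corrected message wording follows. `MalformedProtocolException` is a hypothetical stand-in for Kafka's `SchemaException`, and `readTopicCount` is a hypothetical toy decoder, not the generated `ConsumerProtocolSubscription` reader.

```java
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;

// Hypothetical stand-in for Kafka's SchemaException.
class MalformedProtocolException extends RuntimeException {
    MalformedProtocolException(String message, Throwable cause) {
        super(message, cause);
    }
}

class SubscriptionDecoderSketch {
    // Toy decoder (hypothetical): int16 version followed by an int32 array length.
    // A truncated buffer fails with BufferUnderflowException, but a negative
    // length fails with IllegalStateException, which a BufferUnderflowException
    // catch alone would let escape.
    static int readTopicCount(ByteBuffer buffer) {
        short version = buffer.getShort();
        int count = buffer.getInt();
        if (count < 0) {
            throw new IllegalStateException("non-nullable array serialized as null (version " + version + ")");
        }
        return count;
    }

    static int deserialize(ByteBuffer buffer) {
        try {
            return readTopicCount(buffer);
        } catch (RuntimeException e) {
            // Wrap every parse failure, not only BufferUnderflowException.
            throw new MalformedProtocolException("Malformed consumer protocol subscription", e);
        }
    }
}
```

One trade-off of catching `RuntimeException` is that an unrelated bug inside the parser would also surface as a schema error rather than as its original exception type.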



##########
clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocolTest.java:
##########
@@ -356,4 +358,84 @@ private ByteBuffer generateFutureSubscriptionVersionData() {
 
         return buffer;
     }
+
+    @Test
+    public void testDeserializeSubscriptionThrowsSchemaExceptionForEveryTruncation() {

Review Comment:
   The rest of the existing test methods in this class omit the "test" prefix



##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/GroupMetadataManagerTest.java:
##########
@@ -12172,6 +12172,140 @@ memberId2, new MemberAssignmentImpl(mkAssignment(
         }
     }
 
+    @Test
+    public void testUpgradeFailsOnMalformedClassicGroupProtocol() {

Review Comment:
   Trying to figure out what our existing naming convention is. It looks like we call the upgrade tests `testConsumerGroupHeartbeatWithClassicGroup/ToClassicGroup`?
   
   `testConsumerGroupHeartbeatToClassicGroupWithMalformedSubscriptionFails`?



##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/GroupMetadataManagerTest.java:
##########
@@ -12172,6 +12172,140 @@ memberId2, new MemberAssignmentImpl(mkAssignment(
         }
     }
 
+    @Test
+    public void testUpgradeFailsOnMalformedClassicGroupProtocol() {
+        String groupId = "group-id";
+        String memberId1 = "member-id-1";
+        String memberId2 = "member-id-2";
+        Uuid fooTopicId = Uuid.randomUuid();
+        String fooTopicName = "foo";
+
+        MockPartitionAssignor assignor = new MockPartitionAssignor("range");
+
+        CoordinatorMetadataImage metadataImage = new MetadataImageBuilder()
+            .addTopic(fooTopicId, fooTopicName, 1)
+            .addRacks()
+            .buildCoordinatorMetadataImage();
+
+        GroupMetadataManagerTestContext context = new GroupMetadataManagerTestContext.Builder()
+            .withConfig(GroupCoordinatorConfig.CONSUMER_GROUP_MIGRATION_POLICY_CONFIG, ConsumerGroupMigrationPolicy.UPGRADE.toString())
+            .withConfig(GroupCoordinatorConfig.CONSUMER_GROUP_ASSIGNORS_CONFIG, List.of(assignor))
+            .withMetadataImage(metadataImage)
+            .build();
+
+        // Throws RuntimeException when read
+        byte[] poisonMetadata = new byte[]{0, 1, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF};
+
+        JoinGroupRequestData.JoinGroupRequestProtocolCollection protocols = new JoinGroupRequestData.JoinGroupRequestProtocolCollection(1);
+        protocols.add(new JoinGroupRequestData.JoinGroupRequestProtocol()
+            .setName("range")
+            .setMetadata(poisonMetadata));
+
+        Map<String, byte[]> assignments = Map.of(
+            memberId1,
+            Utils.toArray(ConsumerProtocol.serializeAssignment(
+                new ConsumerPartitionAssignor.Assignment(List.of(new TopicPartition(fooTopicName, 0)))))
+        );
+
+        ClassicGroup group = context.createClassicGroup(groupId);
+        group.setProtocolName(Optional.of("range"));
+        group.add(
+            new ClassicGroupMember(
+                memberId1,
+                Optional.empty(),
+                "client-id",
+                "client-host",
+                10000,
+                5000,
+                "consumer",
+                protocols,
+                assignments.get(memberId1)
+            )
+        );
+
+        group.transitionTo(PREPARING_REBALANCE);
+        group.transitionTo(COMPLETING_REBALANCE);
+        group.transitionTo(STABLE);
+
+        context.replay(GroupCoordinatorRecordHelpers.newGroupMetadataRecord(group, assignments));
+        context.commit();
+
+        // A new member 2 with the new protocol joins the classic group, triggering the upgrade.
+        ConsumerGroupHeartbeatRequestData request = new ConsumerGroupHeartbeatRequestData()
+            .setGroupId(groupId)
+            .setMemberId(memberId2)
+            .setRebalanceTimeoutMs(5000)
+            .setServerAssignor("range")
+            .setSubscribedTopicNames(List.of(fooTopicName))
+            .setTopicPartitions(List.of());
+
+        Exception ex = assertThrows(GroupIdNotFoundException.class,
+            () -> context.consumerGroupHeartbeat(request));
+        assertEquals(
+            "Cannot upgrade classic group group-id to consumer group because the embedded consumer protocol is malformed.",
+            ex.getMessage()
+        );
+    }
+
+    @Test
+    public void testClassicJoinToConsumerGroupFailsOnMalformedSubscriptionMetadata() {

Review Comment:
   ```suggestion
       public void testClassicGroupJoinToConsumerGroupFailsOnMalformedSubscriptionMetadata() {
   ```



##########
clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocolTest.java:
##########
@@ -356,4 +358,84 @@ private ByteBuffer generateFutureSubscriptionVersionData() {
 
         return buffer;
     }
+
+    @Test
+    public void testDeserializeSubscriptionThrowsSchemaExceptionForEveryTruncation() {
+        Subscription subscription = new Subscription(
+            Arrays.asList("foo", "bar"),
+            ByteBuffer.wrap(new byte[]{0x01, 0x02}),
+            Arrays.asList(new TopicPartition("foo", 0), new TopicPartition("bar", 1)),
+            DEFAULT_GENERATION,
+            Optional.of("rack"));
+        byte[] serialized = toBytes(ConsumerProtocol.serializeSubscription(subscription,
+            ConsumerProtocolSubscription.HIGHEST_SUPPORTED_VERSION));
+
+        for (int len = 0; len < serialized.length; len++) {
+            byte[] truncated = Arrays.copyOf(serialized, len);
+            assertThrows(SchemaException.class,
+                () -> ConsumerProtocol.deserializeSubscription(ByteBuffer.wrap(truncated)),
+                "Expected SchemaException for subscription truncated to length " + len);
+        }
+        ConsumerProtocol.deserializeSubscription(ByteBuffer.wrap(serialized));
+    }
+
+    @Test
+    public void testDeserializeAssignmentThrowsSchemaExceptionForEveryTruncation() {
+        Assignment assignment = new Assignment(
+            Arrays.asList(new TopicPartition("foo", 0), new TopicPartition("bar", 1)),
+            ByteBuffer.wrap(new byte[]{0x01, 0x02}));
+        byte[] serialized = toBytes(ConsumerProtocol.serializeAssignment(assignment,
+            ConsumerProtocolAssignment.HIGHEST_SUPPORTED_VERSION));
+
+        for (int len = 0; len < serialized.length; len++) {
+            byte[] truncated = Arrays.copyOf(serialized, len);
+            assertThrows(SchemaException.class,
+                () -> ConsumerProtocol.deserializeAssignment(ByteBuffer.wrap(truncated)),
+                "Expected SchemaException for assignment truncated to length " + len);
+        }
+        ConsumerProtocol.deserializeAssignment(ByteBuffer.wrap(serialized));
+    }
+
+    @Test
+    public void testDeserializeConsumerProtocolSubscriptionThrowsSchemaExceptionForEveryTruncation() {
+        Subscription subscription = new Subscription(
+            Arrays.asList("foo", "bar"),
+            ByteBuffer.wrap(new byte[]{0x01, 0x02}),
+            Arrays.asList(new TopicPartition("foo", 0), new TopicPartition("bar", 1)),
+            DEFAULT_GENERATION,
+            Optional.of("rack"));
+        byte[] serialized = toBytes(ConsumerProtocol.serializeSubscription(subscription,
+            ConsumerProtocolSubscription.HIGHEST_SUPPORTED_VERSION));
+
+        for (int len = 0; len < serialized.length; len++) {
+            byte[] truncated = Arrays.copyOf(serialized, len);
+            assertThrows(SchemaException.class,
+                () -> ConsumerProtocol.deserializeConsumerProtocolSubscription(ByteBuffer.wrap(truncated)),
+                "Expected SchemaException for ConsumerProtocolSubscription truncated to length " + len);
+        }
+        ConsumerProtocol.deserializeConsumerProtocolSubscription(ByteBuffer.wrap(serialized));
+    }
+
+    @Test
+    public void testDeserializeConsumerProtocolAssignmentThrowsSchemaExceptionForEveryTruncation() {
+        Assignment assignment = new Assignment(
+            Arrays.asList(new TopicPartition("foo", 0), new TopicPartition("bar", 1)),
+            ByteBuffer.wrap(new byte[]{0x01, 0x02}));
+        byte[] serialized = toBytes(ConsumerProtocol.serializeAssignment(assignment,
+            ConsumerProtocolAssignment.HIGHEST_SUPPORTED_VERSION));
+
+        for (int len = 0; len < serialized.length; len++) {
+            byte[] truncated = Arrays.copyOf(serialized, len);
+            assertThrows(SchemaException.class,
+                () -> ConsumerProtocol.deserializeConsumerProtocolAssignment(ByteBuffer.wrap(truncated)),
+                "Expected SchemaException for ConsumerProtocolAssignment truncated to length " + len);
+        }
+        ConsumerProtocol.deserializeConsumerProtocolAssignment(ByteBuffer.wrap(serialized));
+    }
+
+    private static byte[] toBytes(ByteBuffer buffer) {
+        byte[] arr = new byte[buffer.remaining()];
+        buffer.duplicate().get(arr);
+        return arr;
+    }

Review Comment:
   I would inline this as it's two lines.
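Inlined, each call site would carry the helper's two lines directly. A sketch, assuming `buffer` is the `ByteBuffer` returned by `ConsumerProtocol.serializeSubscription(...)`; here a hypothetical `serializedBuffer()` built with `ByteBuffer.wrap` stands in so the snippet runs on its own. Reading through `duplicate()` leaves the source buffer's position untouched, which is presumably why the helper used it. The coordinator test in this PR already converts with `Utils.toArray(...)`, which may be another option.

```java
import java.nio.ByteBuffer;

class InlineToBytesSketch {
    // Hypothetical stand-in for ConsumerProtocol.serializeSubscription(subscription, version).
    static ByteBuffer serializedBuffer() {
        return ByteBuffer.wrap(new byte[]{0, 3, 0, 0, 0, 0});
    }

    public static void main(String[] args) {
        ByteBuffer buffer = serializedBuffer();

        // The helper's two lines, inlined at the call site.
        byte[] serialized = new byte[buffer.remaining()];
        buffer.duplicate().get(serialized);

        // duplicate() shares content but has its own position, so buffer is not advanced.
        System.out.println(serialized.length + " bytes, source position " + buffer.position());
        // prints "6 bytes, source position 0"
    }
}
```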



##########
clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocolTest.java:
##########
@@ -356,4 +358,84 @@ private ByteBuffer generateFutureSubscriptionVersionData() {
 
         return buffer;
     }
+
+    @Test
+    public void testDeserializeSubscriptionThrowsSchemaExceptionForEveryTruncation() {
+        Subscription subscription = new Subscription(
+            Arrays.asList("foo", "bar"),
+            ByteBuffer.wrap(new byte[]{0x01, 0x02}),
+            Arrays.asList(new TopicPartition("foo", 0), new TopicPartition("bar", 1)),
+            DEFAULT_GENERATION,
+            Optional.of("rack"));
+        byte[] serialized = toBytes(ConsumerProtocol.serializeSubscription(subscription,
+            ConsumerProtocolSubscription.HIGHEST_SUPPORTED_VERSION));
+
+        for (int len = 0; len < serialized.length; len++) {
+            byte[] truncated = Arrays.copyOf(serialized, len);
+            assertThrows(SchemaException.class,
+                () -> ConsumerProtocol.deserializeSubscription(ByteBuffer.wrap(truncated)),
+                "Expected SchemaException for subscription truncated to length " + len);
+        }
+        ConsumerProtocol.deserializeSubscription(ByteBuffer.wrap(serialized));

Review Comment:
   nit: There is assertDoesNotThrow which we can use here to signal our intent
   
   same for the other 3 new tests in this file
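With JUnit 5 the final call would read `assertDoesNotThrow(() -> ConsumerProtocol.deserializeSubscription(ByteBuffer.wrap(serialized)))`, so an unexpected exception on the full payload is reported as an assertion failure. The four tests share one loop shape; a sketch of that every-truncation pattern as a single helper follows. `assertEveryTruncationFails` and the toy `decode` are hypothetical, and the throw and does-not-throw checks are hand-rolled here only so the snippet runs without JUnit.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.function.Function;

class TruncationCheckSketch {
    // Hypothetical helper: every strict prefix of `serialized` must fail with
    // `expected`, and the full payload must decode without throwing.
    static <T> void assertEveryTruncationFails(byte[] serialized,
                                               Function<ByteBuffer, T> decoder,
                                               Class<? extends RuntimeException> expected) {
        for (int len = 0; len < serialized.length; len++) {
            byte[] truncated = Arrays.copyOf(serialized, len);
            try {
                decoder.apply(ByteBuffer.wrap(truncated));
                throw new AssertionError("Expected " + expected.getSimpleName() + " for length " + len);
            } catch (RuntimeException e) {
                if (!expected.isInstance(e)) {
                    throw new AssertionError("Unexpected " + e + " for length " + len, e);
                }
            }
        }
        // Plain-Java equivalent of assertDoesNotThrow on the untruncated payload.
        try {
            decoder.apply(ByteBuffer.wrap(serialized));
        } catch (RuntimeException e) {
            throw new AssertionError("Full payload failed to decode", e);
        }
    }

    // Toy decoder (hypothetical): int16 version, int32 count, then `count` int32 values.
    // Any parse failure is wrapped, mirroring the catch-and-wrap under review.
    static int[] decode(ByteBuffer buffer) {
        try {
            buffer.getShort();
            int count = buffer.getInt();
            if (count < 0) {
                throw new IllegalStateException("negative count");
            }
            int[] values = new int[count];
            for (int i = 0; i < count; i++) {
                values[i] = buffer.getInt();
            }
            return values;
        } catch (RuntimeException e) {
            throw new IllegalArgumentException("Malformed payload", e);
        }
    }
}
```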



##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/classic/ClassicGroupTest.java:
##########
@@ -1264,6 +1264,34 @@ public void testIsSubscribedToTopic() {
         assertTrue(group.isSubscribedToTopic("topic"));
     }
 
+    @Test
+    public void testComputeSubscribedTopicsHandlesMalformedMemberMetadata() {
+        ClassicGroup group = new ClassicGroup(logContext, "groupId", EMPTY, Time.SYSTEM);
+
+        JoinGroupRequestProtocolCollection protocols = new JoinGroupRequestProtocolCollection();
+        protocols.add(new JoinGroupRequestProtocol()
+            .setName("range")
+            .setMetadata(new byte[]{0, 1, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF})); // Throws RuntimeException

Review Comment:
   Could we reformat this to explain what the interpretation is?
   
   eg.
   ```
   new byte[]{
       // field1 = 0
       0,
       // field2 = 1
       1,
       // field3 = -1
       (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF
   }
   ```
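For the poison payload in these tests, one plausible annotation follows. It assumes that `deserializeSubscription` first reads an int16 version and that the first field of `ConsumerProtocolSubscription` is the `Topics` array with an int32 length prefix. Under that assumption the bytes decode as version 1 and a topic count of -1, the null marker, which is invalid for a non-nullable array; that would explain why reading it throws a `RuntimeException` rather than a `BufferUnderflowException`. The snippet decodes the bytes with a plain `ByteBuffer` to confirm the two values.

```java
import java.nio.ByteBuffer;

class PoisonMetadataSketch {
    // Annotated poison payload; field names assume the ConsumerProtocolSubscription layout.
    static final byte[] POISON_METADATA = new byte[]{
        // version = 1 (int16)
        0, 1,
        // topics array length = -1 (int32), i.e. null for a non-nullable array
        (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF
    };

    public static void main(String[] args) {
        ByteBuffer buffer = ByteBuffer.wrap(POISON_METADATA);
        // ByteBuffer is big-endian by default, matching Kafka's wire format.
        short version = buffer.getShort();
        int topicCount = buffer.getInt();
        System.out.println("version=" + version + " topicCount=" + topicCount);
        // prints "version=1 topicCount=-1"
    }
}
```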



##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/classic/ClassicGroupTest.java:
##########
@@ -1264,6 +1264,34 @@ public void testIsSubscribedToTopic() {
         assertTrue(group.isSubscribedToTopic("topic"));
     }
 
+    @Test
+    public void testComputeSubscribedTopicsHandlesMalformedMemberMetadata() {
+        ClassicGroup group = new ClassicGroup(logContext, "groupId", EMPTY, Time.SYSTEM);
+
+        JoinGroupRequestProtocolCollection protocols = new JoinGroupRequestProtocolCollection();
+        protocols.add(new JoinGroupRequestProtocol()
+            .setName("range")
+            .setMetadata(new byte[]{0, 1, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF})); // Throws RuntimeException
+
+        ClassicGroupMember poisonMember = new ClassicGroupMember(
+            "poisonMember",
+            Optional.empty(),
+            clientId,
+            clientHost,
+            rebalanceTimeoutMs,
+            sessionTimeoutMs,
+            "consumer",
+            protocols
+        );
+
+        group.add(poisonMember);
+        group.transitionTo(PREPARING_REBALANCE);
+        group.initNextGeneration();
+
+        // Should not propagate; falls through to Optional.empty().

Review Comment:
   ```suggestion
           // RuntimeException should not propagate; falls through to Optional.empty().
   ```
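The behaviour this comment describes, where a parse failure in any member's metadata makes the subscribed-topics computation return `Optional.empty()` instead of throwing, can be sketched as below. `subscribedTopics` and its toy decoder (int16 version, int32 topic count, then int16-length-prefixed UTF-8 topic names) are hypothetical simplifications, not `ClassicGroup`'s actual code.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

class SubscribedTopicsSketch {
    // Toy decoder (hypothetical): int16 version, int32 topic count,
    // then each topic as an int16 length followed by UTF-8 bytes.
    static List<String> decodeTopics(byte[] metadata) {
        ByteBuffer buffer = ByteBuffer.wrap(metadata);
        buffer.getShort(); // version, unused in this sketch
        int count = buffer.getInt();
        // Each topic needs at least two bytes, so count can never exceed remaining().
        if (count < 0 || count > buffer.remaining()) {
            throw new IllegalStateException("invalid topics array length " + count);
        }
        String[] topics = new String[count];
        for (int i = 0; i < count; i++) {
            byte[] name = new byte[buffer.getShort()];
            buffer.get(name);
            topics[i] = new String(name, StandardCharsets.UTF_8);
        }
        return List.of(topics);
    }

    // Union of all members' topics, or Optional.empty() if any metadata is malformed.
    static Optional<Set<String>> subscribedTopics(List<byte[]> memberMetadata) {
        Set<String> topics = new HashSet<>();
        try {
            for (byte[] metadata : memberMetadata) {
                topics.addAll(decodeTopics(metadata));
            }
        } catch (RuntimeException e) {
            // The exception does not propagate; the caller sees "unknown" instead.
            return Optional.empty();
        }
        return Optional.of(topics);
    }
}
```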



##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/GroupMetadataManagerTest.java:
##########
@@ -12172,6 +12172,140 @@ memberId2, new MemberAssignmentImpl(mkAssignment(
         }
     }
 
+    @Test
+    public void testUpgradeFailsOnMalformedClassicGroupProtocol() {
+        String groupId = "group-id";
+        String memberId1 = "member-id-1";
+        String memberId2 = "member-id-2";
+        Uuid fooTopicId = Uuid.randomUuid();
+        String fooTopicName = "foo";
+
+        MockPartitionAssignor assignor = new MockPartitionAssignor("range");
+
+        CoordinatorMetadataImage metadataImage = new MetadataImageBuilder()
+            .addTopic(fooTopicId, fooTopicName, 1)
+            .addRacks()
+            .buildCoordinatorMetadataImage();
+
+        GroupMetadataManagerTestContext context = new GroupMetadataManagerTestContext.Builder()
+            .withConfig(GroupCoordinatorConfig.CONSUMER_GROUP_MIGRATION_POLICY_CONFIG, ConsumerGroupMigrationPolicy.UPGRADE.toString())
+            .withConfig(GroupCoordinatorConfig.CONSUMER_GROUP_ASSIGNORS_CONFIG, List.of(assignor))
+            .withMetadataImage(metadataImage)
+            .build();
+
+        // Throws RuntimeException when read
+        byte[] poisonMetadata = new byte[]{0, 1, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF};
+
+        JoinGroupRequestData.JoinGroupRequestProtocolCollection protocols = new JoinGroupRequestData.JoinGroupRequestProtocolCollection(1);
+        protocols.add(new JoinGroupRequestData.JoinGroupRequestProtocol()
+            .setName("range")
+            .setMetadata(poisonMetadata));
+
+        Map<String, byte[]> assignments = Map.of(
+            memberId1,
+            Utils.toArray(ConsumerProtocol.serializeAssignment(
+                new ConsumerPartitionAssignor.Assignment(List.of(new TopicPartition(fooTopicName, 0)))))
+        );
+
+        ClassicGroup group = context.createClassicGroup(groupId);
+        group.setProtocolName(Optional.of("range"));
+        group.add(
+            new ClassicGroupMember(
+                memberId1,
+                Optional.empty(),
+                "client-id",
+                "client-host",
+                10000,
+                5000,
+                "consumer",
+                protocols,
+                assignments.get(memberId1)
+            )
+        );
+
+        group.transitionTo(PREPARING_REBALANCE);
+        group.transitionTo(COMPLETING_REBALANCE);
+        group.transitionTo(STABLE);
+
+        context.replay(GroupCoordinatorRecordHelpers.newGroupMetadataRecord(group, assignments));
+        context.commit();
+
+        // A new member 2 with the new protocol joins the classic group, triggering the upgrade.
+        ConsumerGroupHeartbeatRequestData request = new ConsumerGroupHeartbeatRequestData()
+            .setGroupId(groupId)
+            .setMemberId(memberId2)
+            .setRebalanceTimeoutMs(5000)
+            .setServerAssignor("range")
+            .setSubscribedTopicNames(List.of(fooTopicName))
+            .setTopicPartitions(List.of());
+
+        Exception ex = assertThrows(GroupIdNotFoundException.class,
+            () -> context.consumerGroupHeartbeat(request));
+        assertEquals(
+            "Cannot upgrade classic group group-id to consumer group because the embedded consumer protocol is malformed.",
+            ex.getMessage()
+        );
+    }
+
+    @Test
+    public void testClassicJoinToConsumerGroupFailsOnMalformedSubscriptionMetadata() {
+        String groupId = "group-id";
+        String existingMemberId = Uuid.randomUuid().toString();
+        String newMemberId = Uuid.randomUuid().toString();
+        Uuid fooTopicId = Uuid.randomUuid();
+        String fooTopicName = "foo";
+
+        CoordinatorMetadataImage metadataImage = new MetadataImageBuilder()
+            .addTopic(fooTopicId, fooTopicName, 1)
+            .addRacks()
+            .buildCoordinatorMetadataImage();
+
+        ConsumerGroupMember existingMember = new ConsumerGroupMember.Builder(existingMemberId)
+            .setState(MemberState.STABLE)
+            .setMemberEpoch(10)
+            .setPreviousMemberEpoch(9)
+            .setClientId(DEFAULT_CLIENT_ID)
+            .setClientHost(DEFAULT_CLIENT_ADDRESS.toString())
+            .setSubscribedTopicNames(List.of(fooTopicName))
+            .setServerAssignorName("range")
+            .setRebalanceTimeoutMs(45000)
+            .setAssignedPartitions(toAssignmentWithEpochs(mkAssignment(
+                mkTopicAssignment(fooTopicId, 0)), 10))
+            .build();
+
+        GroupMetadataManagerTestContext context = new GroupMetadataManagerTestContext.Builder()
+            .withConfig(GroupCoordinatorConfig.CONSUMER_GROUP_MIGRATION_POLICY_CONFIG, ConsumerGroupMigrationPolicy.UPGRADE.toString())
+            .withConfig(GroupCoordinatorConfig.CONSUMER_GROUP_ASSIGNORS_CONFIG, List.of(new MockPartitionAssignor("range")))
+            .withMetadataImage(metadataImage)
+            .withConsumerGroup(new ConsumerGroupBuilder(groupId, 10)
+                .withMember(existingMember)
+                .withAssignment(existingMemberId, mkAssignment(mkTopicAssignment(fooTopicId, 0)))
+                .withAssignmentEpoch(10)
+                .withMetadataHash(computeGroupHash(Map.of(
+                    fooTopicName, computeTopicHash(fooTopicName, metadataImage)
+                ))))
+            .build();
+
+        // Throws RuntimeException when read.
+        byte[] poisonMetadata = new byte[]{0, 1, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF};
+        JoinGroupRequestData.JoinGroupRequestProtocolCollection protocols = new JoinGroupRequestData.JoinGroupRequestProtocolCollection(1);
+        protocols.add(new JoinGroupRequestData.JoinGroupRequestProtocol()
+            .setName("range")
+            .setMetadata(poisonMetadata));
+
+        JoinGroupRequestData joinRequest = new JoinGroupRequestData()
+            .setGroupId(groupId)
+            .setMemberId(newMemberId)
+            .setProtocolType(ConsumerProtocol.PROTOCOL_TYPE)
+            .setProtocols(protocols)
+            .setSessionTimeoutMs(5000)
+            .setRebalanceTimeoutMs(45000);
+
+        IllegalStateException ex = assertThrows(IllegalStateException.class,
+            () -> context.sendClassicGroupJoin(joinRequest));
+        assertEquals("Malformed embedded consumer protocol in subscription deserialization.", ex.getMessage());

Review Comment:
   Out of interest, what error code do we report to clients in this case?



##########
group-coordinator/src/test/java/org/apache/kafka/coordinator/group/GroupMetadataManagerTest.java:
##########
@@ -12172,6 +12172,140 @@ memberId2, new MemberAssignmentImpl(mkAssignment(
         }
     }
 
+    @Test
+    public void testUpgradeFailsOnMalformedClassicGroupProtocol() {
+        String groupId = "group-id";
+        String memberId1 = "member-id-1";
+        String memberId2 = "member-id-2";
+        Uuid fooTopicId = Uuid.randomUuid();
+        String fooTopicName = "foo";
+
+        MockPartitionAssignor assignor = new MockPartitionAssignor("range");
+
+        CoordinatorMetadataImage metadataImage = new MetadataImageBuilder()
+            .addTopic(fooTopicId, fooTopicName, 1)
+            .addRacks()
+            .buildCoordinatorMetadataImage();
+
+        GroupMetadataManagerTestContext context = new GroupMetadataManagerTestContext.Builder()
+            .withConfig(GroupCoordinatorConfig.CONSUMER_GROUP_MIGRATION_POLICY_CONFIG, ConsumerGroupMigrationPolicy.UPGRADE.toString())
+            .withConfig(GroupCoordinatorConfig.CONSUMER_GROUP_ASSIGNORS_CONFIG, List.of(assignor))
+            .withMetadataImage(metadataImage)
+            .build();
+
+        // Throws RuntimeException when read
+        byte[] poisonMetadata = new byte[]{0, 1, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF};

Review Comment:
   Could we reformat this as suggested above? Same for the next test method too.


