This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch 4.1
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/4.1 by this push:
     new 173f4cf0262 MINOR: improve AssignmentInfo (#22465)
173f4cf0262 is described below

commit 173f4cf026286c533e55e5b529a4be60d4ea7bcf
Author: Matthias J. Sax <[email protected]>
AuthorDate: Fri Jun 5 08:57:48 2026 -0700

    MINOR: improve AssignmentInfo (#22465)
    
    Reviewers: Lucas Brutschy <[email protected]>
---
 .../internals/assignment/AssignmentInfo.java       | 107 +++++++++++++--------
 1 file changed, 69 insertions(+), 38 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java
index 350426edd8e..6dc60fb6773 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java
@@ -309,6 +309,7 @@ public class AssignmentInfo {
         // ensure we are at the beginning of the ByteBuffer
         data.rewind();
 
+        final int length = data.remaining();
         try (final DataInputStream in = new DataInputStream(new 
ByteBufferInputStream(data))) {
             final AssignmentInfo assignmentInfo;
 
@@ -317,45 +318,45 @@ public class AssignmentInfo {
             switch (usedVersion) {
                 case 1:
                     assignmentInfo = new AssignmentInfo(usedVersion, UNKNOWN);
-                    decodeActiveTasks(assignmentInfo, in);
-                    decodeStandbyTasks(assignmentInfo, in);
+                    decodeActiveTasks(assignmentInfo, in, length);
+                    decodeStandbyTasks(assignmentInfo, in, length);
                     assignmentInfo.partitionsByHost = new HashMap<>();
                     break;
                 case 2:
                     assignmentInfo = new AssignmentInfo(usedVersion, UNKNOWN);
-                    decodeActiveTasks(assignmentInfo, in);
-                    decodeStandbyTasks(assignmentInfo, in);
-                    decodePartitionsByHost(assignmentInfo, in);
+                    decodeActiveTasks(assignmentInfo, in, length);
+                    decodeStandbyTasks(assignmentInfo, in, length);
+                    decodePartitionsByHost(assignmentInfo, in, length);
                     break;
                 case 3:
                     commonlySupportedVersion = in.readInt();
                     assignmentInfo = new AssignmentInfo(usedVersion, 
commonlySupportedVersion);
-                    decodeActiveTasks(assignmentInfo, in);
-                    decodeStandbyTasks(assignmentInfo, in);
-                    decodePartitionsByHost(assignmentInfo, in);
+                    decodeActiveTasks(assignmentInfo, in, length);
+                    decodeStandbyTasks(assignmentInfo, in, length);
+                    decodePartitionsByHost(assignmentInfo, in, length);
                     break;
                 case 4:
                     commonlySupportedVersion = in.readInt();
                     assignmentInfo = new AssignmentInfo(usedVersion, 
commonlySupportedVersion);
-                    decodeActiveTasks(assignmentInfo, in);
-                    decodeStandbyTasks(assignmentInfo, in);
-                    decodePartitionsByHost(assignmentInfo, in);
+                    decodeActiveTasks(assignmentInfo, in, length);
+                    decodeStandbyTasks(assignmentInfo, in, length);
+                    decodePartitionsByHost(assignmentInfo, in, length);
                     assignmentInfo.errCode = in.readInt();
                     break;
                 case 5:
                     commonlySupportedVersion = in.readInt();
                     assignmentInfo = new AssignmentInfo(usedVersion, 
commonlySupportedVersion);
-                    decodeActiveTasks(assignmentInfo, in);
-                    decodeStandbyTasks(assignmentInfo, in);
-                    decodePartitionsByHostUsingDictionary(assignmentInfo, in);
+                    decodeActiveTasks(assignmentInfo, in, length);
+                    decodeStandbyTasks(assignmentInfo, in, length);
+                    decodePartitionsByHostUsingDictionary(assignmentInfo, in, 
length);
                     assignmentInfo.errCode = in.readInt();
                     break;
                 case 6:
                     commonlySupportedVersion = in.readInt();
                     assignmentInfo = new AssignmentInfo(usedVersion, 
commonlySupportedVersion);
-                    decodeActiveTasks(assignmentInfo, in);
-                    decodeStandbyTasks(assignmentInfo, in);
-                    decodeActiveAndStandbyHostPartitions(assignmentInfo, in);
+                    decodeActiveTasks(assignmentInfo, in, length);
+                    decodeStandbyTasks(assignmentInfo, in, length);
+                    decodeActiveAndStandbyHostPartitions(assignmentInfo, in, 
length);
                     assignmentInfo.errCode = in.readInt();
                     break;
                 case 7:
@@ -365,9 +366,9 @@ public class AssignmentInfo {
                 case 11:
                     commonlySupportedVersion = in.readInt();
                     assignmentInfo = new AssignmentInfo(usedVersion, 
commonlySupportedVersion);
-                    decodeActiveTasks(assignmentInfo, in);
-                    decodeStandbyTasks(assignmentInfo, in);
-                    decodeActiveAndStandbyHostPartitions(assignmentInfo, in);
+                    decodeActiveTasks(assignmentInfo, in, length);
+                    decodeStandbyTasks(assignmentInfo, in, length);
+                    decodeActiveAndStandbyHostPartitions(assignmentInfo, in, 
length);
                     assignmentInfo.errCode = in.readInt();
                     assignmentInfo.nextRebalanceMs = in.readLong();
                     break;
@@ -385,8 +386,12 @@ public class AssignmentInfo {
     }
 
     private static void decodeActiveTasks(final AssignmentInfo assignmentInfo,
-                                          final DataInputStream in) throws 
IOException {
+                                          final DataInputStream in,
+                                          final int length) throws IOException 
{
         final int count = in.readInt();
+        if (count < 0 || count > length) {
+            throw new TaskAssignmentException("Corrupted user data byte[].");
+        }
         assignmentInfo.activeTasks = new ArrayList<>(count);
         for (int i = 0; i < count; i++) {
             assignmentInfo.activeTasks.add(readTaskIdFrom(in, 
assignmentInfo.usedVersion));
@@ -394,27 +399,39 @@ public class AssignmentInfo {
     }
 
     private static void decodeStandbyTasks(final AssignmentInfo assignmentInfo,
-                                           final DataInputStream in) throws 
IOException {
+                                           final DataInputStream in,
+                                           final int length) throws 
IOException {
         final int count = in.readInt();
+        if (count < 0 || count > length) {
+            throw new TaskAssignmentException("Corrupted user data byte[].");
+        }
         assignmentInfo.standbyTasks = new HashMap<>(count);
         for (int i = 0; i < count; i++) {
             final TaskId id = readTaskIdFrom(in, assignmentInfo.usedVersion);
-            assignmentInfo.standbyTasks.put(id, readTopicPartitions(in));
+            assignmentInfo.standbyTasks.put(id, readTopicPartitions(in, 
length));
         }
     }
 
     private static void decodePartitionsByHost(final AssignmentInfo 
assignmentInfo,
-                                               final DataInputStream in) 
throws IOException {
+                                               final DataInputStream in,
+                                               final int length) throws 
IOException {
         assignmentInfo.partitionsByHost = new HashMap<>();
         final int numEntries = in.readInt();
+        if (numEntries < 0 || numEntries > length) {
+            throw new TaskAssignmentException("Corrupted user data byte[].");
+        }
         for (int i = 0; i < numEntries; i++) {
             final HostInfo hostInfo = new HostInfo(in.readUTF(), in.readInt());
-            assignmentInfo.partitionsByHost.put(hostInfo, 
readTopicPartitions(in));
+            assignmentInfo.partitionsByHost.put(hostInfo, 
readTopicPartitions(in, length));
         }
     }
 
-    private static Set<TopicPartition> readTopicPartitions(final 
DataInputStream in) throws IOException {
+    private static Set<TopicPartition> readTopicPartitions(final 
DataInputStream in,
+                                                           final int length) 
throws IOException {
         final int numPartitions = in.readInt();
+        if (numPartitions < 0 || numPartitions > length) {
+            throw new TaskAssignmentException("Corrupted user data byte[].");
+        }
         final Set<TopicPartition> partitions = new HashSet<>(numPartitions);
         for (int j = 0; j < numPartitions; j++) {
             partitions.add(new TopicPartition(in.readUTF(), in.readInt()));
@@ -422,8 +439,12 @@ public class AssignmentInfo {
         return partitions;
     }
 
-    private static Map<Integer, String> decodeTopicIndexAndGet(final 
DataInputStream in) throws IOException {
+    private static Map<Integer, String> decodeTopicIndexAndGet(final 
DataInputStream in,
+                                                               final int 
length) throws IOException {
         final int dictSize = in.readInt();
+        if (dictSize < 0 || dictSize > length) {
+            throw new TaskAssignmentException("Corrupted user data byte[].");
+        }
         final Map<Integer, String> topicIndexDict = new HashMap<>(dictSize);
         for (int i = 0; i < dictSize; i++) {
             topicIndexDict.put(in.readInt(), in.readUTF());
@@ -432,32 +453,42 @@ public class AssignmentInfo {
     }
 
     private static Map<HostInfo, Set<TopicPartition>> 
decodeHostPartitionMapUsingDictionary(final DataInputStream in,
-                                                                               
             final Map<Integer, String> topicIndexDict) throws IOException {
-        final Map<HostInfo, Set<TopicPartition>> hostPartitionMap = new 
HashMap<>();
+                                                                               
             final Map<Integer, String> topicIndexDict,
+                                                                               
             final int length) throws IOException {
         final int numEntries = in.readInt();
+        if (numEntries < 0 || numEntries > length) {
+            throw new TaskAssignmentException("Corrupted user data byte[].");
+        }
+        final Map<HostInfo, Set<TopicPartition>> hostPartitionMap = new 
HashMap<>(numEntries);
         for (int i = 0; i < numEntries; i++) {
             final HostInfo hostInfo = new HostInfo(in.readUTF(), in.readInt());
-            hostPartitionMap.put(hostInfo, readTopicPartitions(in, 
topicIndexDict));
+            hostPartitionMap.put(hostInfo, readTopicPartitions(in, 
topicIndexDict, length));
         }
         return hostPartitionMap;
     }
 
     private static void decodePartitionsByHostUsingDictionary(final 
AssignmentInfo assignmentInfo,
-                                                              final 
DataInputStream in) throws IOException {
-        final Map<Integer, String> topicIndexDict = decodeTopicIndexAndGet(in);
-        assignmentInfo.partitionsByHost = 
decodeHostPartitionMapUsingDictionary(in, topicIndexDict);
+                                                              final 
DataInputStream in,
+                                                              final int 
length) throws IOException {
+        final Map<Integer, String> topicIndexDict = decodeTopicIndexAndGet(in, 
length);
+        assignmentInfo.partitionsByHost = 
decodeHostPartitionMapUsingDictionary(in, topicIndexDict, length);
     }
 
     private static void decodeActiveAndStandbyHostPartitions(final 
AssignmentInfo assignmentInfo,
-                                                             final 
DataInputStream in) throws IOException {
-        final Map<Integer, String> topicIndexDict = decodeTopicIndexAndGet(in);
-        assignmentInfo.partitionsByHost = 
decodeHostPartitionMapUsingDictionary(in, topicIndexDict);
-        assignmentInfo.standbyPartitionsByHost = 
decodeHostPartitionMapUsingDictionary(in, topicIndexDict);
+                                                             final 
DataInputStream in,
+                                                             final int length) 
throws IOException {
+        final Map<Integer, String> topicIndexDict = decodeTopicIndexAndGet(in, 
length);
+        assignmentInfo.partitionsByHost = 
decodeHostPartitionMapUsingDictionary(in, topicIndexDict, length);
+        assignmentInfo.standbyPartitionsByHost = 
decodeHostPartitionMapUsingDictionary(in, topicIndexDict, length);
     }
 
     private static Set<TopicPartition> readTopicPartitions(final 
DataInputStream in,
-                                                           final Map<Integer, 
String> topicIndexDict) throws IOException {
+                                                           final Map<Integer, 
String> topicIndexDict,
+                                                           final int length) 
throws IOException {
         final int numPartitions = in.readInt();
+        if (numPartitions < 0 || numPartitions > length) {
+            throw new TaskAssignmentException("Corrupted user data byte[].");
+        }
         final Set<TopicPartition> partitions = new HashSet<>(numPartitions);
         for (int j = 0; j < numPartitions; j++) {
             partitions.add(new 
TopicPartition(topicIndexDict.get(in.readInt()), in.readInt()));

Reply via email to