Repository: kafka
Updated Branches:
  refs/heads/0.11.0 e07ef8ef5 -> 2d9dedca1


KAFKA-5031; Validate count of records and headers for new message format

https://issues.apache.org/jira/browse/KAFKA-5031

Implements an additional check for `DefaultRecordBatch` that compares the 
number of records declared in the batch header with the actual number of 
records read. A similar check is applied to the declared record header count.

Author: gosubpl <[email protected]>

Reviewers: Ismael Juma <[email protected]>, Jason Gustafson <[email protected]>

Closes #3156 from gosubpl/KAFKA-5031

(cherry picked from commit 48de613a90fe91ca56628803fe9f02fdfd99813a)
Signed-off-by: Jason Gustafson <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/2d9dedca
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/2d9dedca
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/2d9dedca

Branch: refs/heads/0.11.0
Commit: 2d9dedca15ef13b059912e5767bd1d7b3c7cc0b2
Parents: e07ef8e
Author: gosubpl <[email protected]>
Authored: Fri Jun 16 21:58:26 2017 -0700
Committer: Jason Gustafson <[email protected]>
Committed: Fri Jun 16 21:59:35 2017 -0700

----------------------------------------------------------------------
 .../kafka/common/record/DefaultRecord.java      | 80 ++++++++++++--------
 .../kafka/common/record/DefaultRecordBatch.java | 39 +++++++++-
 .../common/record/DefaultRecordBatchTest.java   | 59 +++++++++++++++
 .../kafka/common/record/DefaultRecordTest.java  | 58 ++++++++++++++
 .../src/main/scala/kafka/log/LogValidator.scala |  1 +
 5 files changed, 201 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/2d9dedca/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
----------------------------------------------------------------------
diff --git 
a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java 
b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
index 0143455..9b7f327 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
@@ -26,6 +26,7 @@ import org.apache.kafka.common.utils.Utils;
 import java.io.DataInput;
 import java.io.DataOutputStream;
 import java.io.IOException;
+import java.nio.BufferUnderflowException;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.zip.Checksum;
@@ -314,41 +315,53 @@ public class DefaultRecord implements Record {
                                           long baseTimestamp,
                                           int baseSequence,
                                           Long logAppendTime) {
-        byte attributes = buffer.get();
-        long timestampDelta = ByteUtils.readVarlong(buffer);
-        long timestamp = baseTimestamp + timestampDelta;
-        if (logAppendTime != null)
-            timestamp = logAppendTime;
-
-        int offsetDelta = ByteUtils.readVarint(buffer);
-        long offset = baseOffset + offsetDelta;
-        int sequence = baseSequence >= 0 ?
-                DefaultRecordBatch.incrementSequence(baseSequence, 
offsetDelta) :
-                RecordBatch.NO_SEQUENCE;
-
-        ByteBuffer key = null;
-        int keySize = ByteUtils.readVarint(buffer);
-        if (keySize >= 0) {
-            key = buffer.slice();
-            key.limit(keySize);
-            buffer.position(buffer.position() + keySize);
-        }
+        try {
+            int recordStart = buffer.position();
+            byte attributes = buffer.get();
+            long timestampDelta = ByteUtils.readVarlong(buffer);
+            long timestamp = baseTimestamp + timestampDelta;
+            if (logAppendTime != null)
+                timestamp = logAppendTime;
+
+            int offsetDelta = ByteUtils.readVarint(buffer);
+            long offset = baseOffset + offsetDelta;
+            int sequence = baseSequence >= 0 ?
+                    DefaultRecordBatch.incrementSequence(baseSequence, 
offsetDelta) :
+                    RecordBatch.NO_SEQUENCE;
+
+            ByteBuffer key = null;
+            int keySize = ByteUtils.readVarint(buffer);
+            if (keySize >= 0) {
+                key = buffer.slice();
+                key.limit(keySize);
+                buffer.position(buffer.position() + keySize);
+            }
 
-        ByteBuffer value = null;
-        int valueSize = ByteUtils.readVarint(buffer);
-        if (valueSize >= 0) {
-            value = buffer.slice();
-            value.limit(valueSize);
-            buffer.position(buffer.position() + valueSize);
-        }
+            ByteBuffer value = null;
+            int valueSize = ByteUtils.readVarint(buffer);
+            if (valueSize >= 0) {
+                value = buffer.slice();
+                value.limit(valueSize);
+                buffer.position(buffer.position() + valueSize);
+            }
+
+            int numHeaders = ByteUtils.readVarint(buffer);
+            if (numHeaders < 0)
+                throw new InvalidRecordException("Found invalid number of 
record headers " + numHeaders);
 
-        int numHeaders = ByteUtils.readVarint(buffer);
-        if (numHeaders < 0)
-            throw new InvalidRecordException("Found invalid number of record 
headers " + numHeaders);
+            if (numHeaders == 0)
+                return new DefaultRecord(sizeInBytes, attributes, offset, 
timestamp, sequence, key, value, Record.EMPTY_HEADERS);
 
-        if (numHeaders == 0)
-            return new DefaultRecord(sizeInBytes, attributes, offset, 
timestamp, sequence, key, value, Record.EMPTY_HEADERS);
+            Header[] headers = readHeaders(buffer, numHeaders, recordStart, 
sizeInBytes);
 
+            return new DefaultRecord(sizeInBytes, attributes, offset, 
timestamp, sequence, key, value, headers);
+        } catch (BufferUnderflowException | IllegalArgumentException e) {
+            throw new InvalidRecordException("Invalid header data or number of 
headers declared for the record, reason for failure was "
+                    + e.getMessage());
+        }
+    }
+
+    private static Header[] readHeaders(ByteBuffer buffer, int numHeaders, int 
recordStart, int sizeInBytes) {
         Header[] headers = new Header[numHeaders];
         for (int i = 0; i < numHeaders; i++) {
             int headerKeySize = ByteUtils.readVarint(buffer);
@@ -369,7 +382,10 @@ public class DefaultRecord implements Record {
             headers[i] = new RecordHeader(headerKey, headerValue);
         }
 
-        return new DefaultRecord(sizeInBytes, attributes, offset, timestamp, 
sequence, key, value, headers);
+        // validate whether we have read all header bytes in the current record
+        if (buffer.position() - recordStart != sizeInBytes - 
ByteUtils.sizeOfVarint(sizeInBytes))
+            throw new InvalidRecordException("Invalid header data or number of 
headers declared for the record");
+        return headers;
     }
 
     public static int sizeInBytes(int offsetDelta,

http://git-wip-us.apache.org/repos/asf/kafka/blob/2d9dedca/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
----------------------------------------------------------------------
diff --git 
a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java 
b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
index 353eb6a..5a7e27a 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
@@ -24,7 +24,9 @@ import org.apache.kafka.common.utils.CloseableIterator;
 import org.apache.kafka.common.utils.Crc32C;
 
 import java.io.DataInputStream;
+import java.io.EOFException;
 import java.io.IOException;
+import java.nio.BufferUnderflowException;
 import java.nio.ByteBuffer;
 import java.nio.channels.FileChannel;
 import java.util.ArrayList;
@@ -124,7 +126,7 @@ public class DefaultRecordBatch extends AbstractRecordBatch 
implements MutableRe
     public void ensureValid() {
         if (sizeInBytes() < RECORD_BATCH_OVERHEAD)
             throw new InvalidRecordException("Record batch is corrupt (the 
size " + sizeInBytes() +
-                    "is smaller than the minimum allowed overhead " + 
RECORD_BATCH_OVERHEAD + ")");
+                    " is smaller than the minimum allowed overhead " + 
RECORD_BATCH_OVERHEAD + ")");
 
         if (!isValid())
             throw new InvalidRecordException("Record is corrupt (stored crc = 
" + checksum()
@@ -235,7 +237,7 @@ public class DefaultRecordBatch extends AbstractRecordBatch 
implements MutableRe
     }
 
     private CloseableIterator<Record> compressedIterator(BufferSupplier 
bufferSupplier) {
-        ByteBuffer buffer = this.buffer.duplicate();
+        final ByteBuffer buffer = this.buffer.duplicate();
         buffer.position(RECORDS_OFFSET);
         final DataInputStream inputStream = new 
DataInputStream(compressionType().wrapForInput(buffer, magic(),
                 bufferSupplier));
@@ -245,12 +247,23 @@ public class DefaultRecordBatch extends 
AbstractRecordBatch implements MutableRe
             protected Record readNext(long baseOffset, long baseTimestamp, int 
baseSequence, Long logAppendTime) {
                 try {
                     return DefaultRecord.readFrom(inputStream, baseOffset, 
baseTimestamp, baseSequence, logAppendTime);
+                } catch (EOFException e) {
+                    throw new InvalidRecordException("Incorrect declared batch 
size, premature EOF reached");
                 } catch (IOException e) {
                     throw new KafkaException("Failed to decompress record 
stream", e);
                 }
             }
 
             @Override
+            protected boolean ensureNoneRemaining() {
+                try {
+                    return inputStream.read() == -1;
+                } catch (IOException e) {
+                    return false;
+                }
+            }
+
+            @Override
             public void close() {
                 try {
                     inputStream.close();
@@ -267,7 +280,15 @@ public class DefaultRecordBatch extends 
AbstractRecordBatch implements MutableRe
         return new RecordIterator() {
             @Override
             protected Record readNext(long baseOffset, long baseTimestamp, int 
baseSequence, Long logAppendTime) {
-                return DefaultRecord.readFrom(buffer, baseOffset, 
baseTimestamp, baseSequence, logAppendTime);
+                try {
+                    return DefaultRecord.readFrom(buffer, baseOffset, 
baseTimestamp, baseSequence, logAppendTime);
+                } catch (BufferUnderflowException e) {
+                    throw new InvalidRecordException("Incorrect declared batch 
size, premature EOF reached");
+                }
+            }
+            @Override
+            protected boolean ensureNoneRemaining() {
+                return !buffer.hasRemaining();
             }
             @Override
             public void close() {}
@@ -502,11 +523,21 @@ public class DefaultRecordBatch extends 
AbstractRecordBatch implements MutableRe
                 throw new NoSuchElementException();
 
             readRecords++;
-            return readNext(baseOffset, baseTimestamp, baseSequence, 
logAppendTime);
+            Record rec = readNext(baseOffset, baseTimestamp, baseSequence, 
logAppendTime);
+            if (readRecords == numRecords) {
+                // Validate that the actual size of the batch is equal to 
declared size
+                // by checking that after reading the declared number of items, 
there are no items left
+                // (overflow case, i.e. reading past buffer end is checked 
elsewhere).
+                if (!ensureNoneRemaining())
+                    throw new InvalidRecordException("Incorrect declared batch 
size, records still remaining in file");
+            }
+            return rec;
         }
 
         protected abstract Record readNext(long baseOffset, long 
baseTimestamp, int baseSequence, Long logAppendTime);
 
+        protected abstract boolean ensureNoneRemaining();
+
         @Override
         public void remove() {
             throw new UnsupportedOperationException();

http://git-wip-us.apache.org/repos/asf/kafka/blob/2d9dedca/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
----------------------------------------------------------------------
diff --git 
a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
 
b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
index 726b619..a5ede9c 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
@@ -27,6 +27,7 @@ import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.List;
 
+import static 
org.apache.kafka.common.record.DefaultRecordBatch.RECORDS_COUNT_OFFSET;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
@@ -153,6 +154,50 @@ public class DefaultRecordBatchTest {
     }
 
     @Test(expected = InvalidRecordException.class)
+    public void testInvalidRecordCountTooManyNonCompressedV2() {
+        long now = System.currentTimeMillis();
+        DefaultRecordBatch batch = 
recordsWithInvalidRecordCount(RecordBatch.MAGIC_VALUE_V2, now, 
CompressionType.NONE, 5);
+        // force iteration through the batch to execute validation
+        // batch validation is a part of normal workflow for 
LogValidator.validateMessagesAndAssignOffsets
+        for (Record record: batch) {
+            record.isValid();
+        }
+    }
+
+    @Test(expected = InvalidRecordException.class)
+    public void testInvalidRecordCountTooLittleNonCompressedV2() {
+        long now = System.currentTimeMillis();
+        DefaultRecordBatch batch = 
recordsWithInvalidRecordCount(RecordBatch.MAGIC_VALUE_V2, now, 
CompressionType.NONE, 2);
+        // force iteration through the batch to execute validation
+        // batch validation is a part of normal workflow for 
LogValidator.validateMessagesAndAssignOffsets
+        for (Record record: batch) {
+            record.isValid();
+        }
+    }
+
+    @Test(expected = InvalidRecordException.class)
+    public void testInvalidRecordCountTooManyCompressedV2() {
+        long now = System.currentTimeMillis();
+        DefaultRecordBatch batch = 
recordsWithInvalidRecordCount(RecordBatch.MAGIC_VALUE_V2, now, 
CompressionType.GZIP, 5);
+        // force iteration through the batch to execute validation
+        // batch validation is a part of normal workflow for 
LogValidator.validateMessagesAndAssignOffsets
+        for (Record record: batch) {
+            record.isValid();
+        }
+    }
+
+    @Test(expected = InvalidRecordException.class)
+    public void testInvalidRecordCountTooLittleCompressedV2() {
+        long now = System.currentTimeMillis();
+        DefaultRecordBatch batch = 
recordsWithInvalidRecordCount(RecordBatch.MAGIC_VALUE_V2, now, 
CompressionType.GZIP, 2);
+        // force iteration through the batch to execute validation
+        // batch validation is a part of normal workflow for 
LogValidator.validateMessagesAndAssignOffsets
+        for (Record record: batch) {
+            record.isValid();
+        }
+    }
+
+    @Test(expected = InvalidRecordException.class)
     public void testInvalidCrc() {
         MemoryRecords records = 
MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L,
                 CompressionType.NONE, TimestampType.CREATE_TIME,
@@ -301,4 +346,18 @@ public class DefaultRecordBatchTest {
         assertEquals(4, DefaultRecordBatch.incrementSequence(Integer.MAX_VALUE 
- 5, 10));
     }
 
+    private static DefaultRecordBatch recordsWithInvalidRecordCount(Byte 
magicValue, long timestamp,
+                                              CompressionType codec, int 
invalidCount) {
+        ByteBuffer buf = ByteBuffer.allocate(512);
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buf, magicValue, 
codec, TimestampType.CREATE_TIME, 0L);
+        builder.appendWithOffset(0, timestamp, null, "hello".getBytes());
+        builder.appendWithOffset(1, timestamp, null, "there".getBytes());
+        builder.appendWithOffset(2, timestamp, null, "beautiful".getBytes());
+        MemoryRecords records = builder.build();
+        ByteBuffer buffer = records.buffer();
+        buffer.position(0);
+        buffer.putInt(RECORDS_COUNT_OFFSET, invalidCount);
+        buffer.position(0);
+        return new DefaultRecordBatch(buffer);
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/2d9dedca/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java
----------------------------------------------------------------------
diff --git 
a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java 
b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java
index 2c0ef05..b9c923d 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java
@@ -73,6 +73,64 @@ public class DefaultRecordTest {
         }
     }
 
+    @Test(expected = InvalidRecordException.class)
+    public void testBasicSerdeInvalidHeaderCountTooHigh() throws IOException {
+        Header[] headers = new Header[] {
+            new RecordHeader("foo", "value".getBytes()),
+            new RecordHeader("bar", (byte[]) null),
+            new RecordHeader("\"A\\u00ea\\u00f1\\u00fcC\"", "value".getBytes())
+        };
+
+        SimpleRecord record = new SimpleRecord(15L, "hi".getBytes(), 
"there".getBytes(), headers);
+
+        int baseSequence = 723;
+        long baseOffset = 37;
+        int offsetDelta = 10;
+        long baseTimestamp = System.currentTimeMillis();
+        long timestampDelta = 323;
+
+        ByteBufferOutputStream out = new ByteBufferOutputStream(1024);
+        DefaultRecord.writeTo(new DataOutputStream(out), offsetDelta, 
timestampDelta, record.key(), record.value(),
+                record.headers());
+        ByteBuffer buffer = out.buffer();
+        buffer.flip();
+        buffer.put(14, (byte) 8);
+
+        DefaultRecord logRecord = DefaultRecord.readFrom(buffer, baseOffset, 
baseTimestamp, baseSequence, null);
+        // force iteration through the record to validate the number of headers
+        assertEquals(DefaultRecord.sizeInBytes(offsetDelta, timestampDelta, 
record.key(), record.value(),
+                record.headers()), logRecord.sizeInBytes());
+    }
+
+    @Test(expected = InvalidRecordException.class)
+    public void testBasicSerdeInvalidHeaderCountTooLow() throws IOException {
+        Header[] headers = new Header[] {
+            new RecordHeader("foo", "value".getBytes()),
+            new RecordHeader("bar", (byte[]) null),
+            new RecordHeader("\"A\\u00ea\\u00f1\\u00fcC\"", "value".getBytes())
+        };
+
+        SimpleRecord record = new SimpleRecord(15L, "hi".getBytes(), 
"there".getBytes(), headers);
+
+        int baseSequence = 723;
+        long baseOffset = 37;
+        int offsetDelta = 10;
+        long baseTimestamp = System.currentTimeMillis();
+        long timestampDelta = 323;
+
+        ByteBufferOutputStream out = new ByteBufferOutputStream(1024);
+        DefaultRecord.writeTo(new DataOutputStream(out), offsetDelta, 
timestampDelta, record.key(), record.value(),
+                record.headers());
+        ByteBuffer buffer = out.buffer();
+        buffer.flip();
+        buffer.put(14, (byte) 4);
+
+        DefaultRecord logRecord = DefaultRecord.readFrom(buffer, baseOffset, 
baseTimestamp, baseSequence, null);
+        // force iteration through the record to validate the number of headers
+        assertEquals(DefaultRecord.sizeInBytes(offsetDelta, timestampDelta, 
record.key(), record.value(),
+                record.headers()), logRecord.sizeInBytes());
+    }
+
     @Test
     public void testSerdeNoSequence() throws IOException {
         ByteBuffer key = ByteBuffer.wrap("hi".getBytes());

http://git-wip-us.apache.org/repos/asf/kafka/blob/2d9dedca/core/src/main/scala/kafka/log/LogValidator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/LogValidator.scala 
b/core/src/main/scala/kafka/log/LogValidator.scala
index ee5cb58..e7d7963 100644
--- a/core/src/main/scala/kafka/log/LogValidator.scala
+++ b/core/src/main/scala/kafka/log/LogValidator.scala
@@ -35,6 +35,7 @@ private[kafka] object LogValidator extends Logging {
    * 2. When magic value >= 1, inner messages of a compressed message set must 
have monotonically increasing offsets
    *    starting from 0.
    * 3. When magic value >= 1, validate and maybe overwrite timestamps of 
messages.
+   * 4. Declared count of records in DefaultRecordBatch must match number of 
valid records contained therein.
    *
    * This method will convert messages as necessary to the topic's configured 
message format version. If no format
    * conversion or value overwriting is required for messages, this method 
will perform in-place operations to

Reply via email to