[runtime] Add spillable input deserializer
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/be8e1f18 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/be8e1f18 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/be8e1f18 Branch: refs/heads/master Commit: be8e1f181825f0f84a276dc9ec335c364a4b3e38 Parents: bc69b0b Author: Stephan Ewen <[email protected]> Authored: Sun Dec 14 15:39:28 2014 +0100 Committer: Ufuk Celebi <[email protected]> Committed: Wed Jan 21 12:01:35 2015 +0100 ---------------------------------------------------------------------- .../memory/InputViewDataInputStreamWrapper.java | 11 +- .../serialization/SpanningRecordSerializer.java | 1 + ...llingAdaptiveSpanningRecordDeserializer.java | 615 +++++++++++++++++++ .../runtime/util/DataInputDeserializer.java | 2 +- .../SpanningRecordSerializationTest.java | 68 +- .../SpanningRecordSerializerTest.java | 46 +- .../network/serialization/LargeRecordsTest.java | 97 +++ 7 files changed, 794 insertions(+), 46 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-core/src/main/java/org/apache/flink/core/memory/InputViewDataInputStreamWrapper.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/core/memory/InputViewDataInputStreamWrapper.java b/flink-core/src/main/java/org/apache/flink/core/memory/InputViewDataInputStreamWrapper.java index 289d66e..7de1d71 100644 --- a/flink-core/src/main/java/org/apache/flink/core/memory/InputViewDataInputStreamWrapper.java +++ b/flink-core/src/main/java/org/apache/flink/core/memory/InputViewDataInputStreamWrapper.java @@ -18,16 +18,25 @@ package org.apache.flink.core.memory; +import java.io.Closeable; import java.io.DataInputStream; import java.io.EOFException; import java.io.IOException; -public class InputViewDataInputStreamWrapper implements DataInputView { +public class InputViewDataInputStreamWrapper implements DataInputView, Closeable { + private final DataInputStream in; public InputViewDataInputStreamWrapper(DataInputStream in){ this.in = in; } + + @Override + public void close() throws IOException { + in.close(); + } + + // -------------------------------------------------------------------------------------------- @Override public void skipBytesToRead(int numBytes) throws IOException { http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java index ab6fe75..4446dbc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java @@ -104,6 +104,7 @@ public class SpanningRecordSerializer<T extends IOReadableWritable> implements R // make sure we don't hold onto the large buffers for too long if (result.isFullRecord()) { + this.serializationBuffer.clear(); this.serializationBuffer.pruneBuffer(); this.dataBuffer = this.serializationBuffer.wrapAsByteBuffer(); } http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/serialization/SpillingAdaptiveSpanningRecordDeserializer.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/serialization/SpillingAdaptiveSpanningRecordDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/serialization/SpillingAdaptiveSpanningRecordDeserializer.java new file mode 100644 index 0000000..371ba0a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/serialization/SpillingAdaptiveSpanningRecordDeserializer.java @@ -0,0 +1,615 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.serialization; + +import java.io.BufferedInputStream; +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.io.UTFDataFormatException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.util.Random; + +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.GlobalConfiguration; +import org.apache.flink.core.io.IOReadableWritable; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.InputViewDataInputStreamWrapper; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.util.StringUtils; + +/** + * @param <T> The type of the record to be deserialized. + */ +public class SpillingAdaptiveSpanningRecordDeserializer<T extends IOReadableWritable> implements RecordDeserializer<T> { + + private static final int THRESHOLD_FOR_SPILLING = 5 * 1024 * 1024; // 5 MiBytes + + private final NonSpanningWrapper nonSpanningWrapper; + + private final SpanningWrapper spanningWrapper; + + public SpillingAdaptiveSpanningRecordDeserializer() { + + String tempDirString = GlobalConfiguration.getString( + ConfigConstants.TASK_MANAGER_TMP_DIR_KEY, + ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH); + String[] directories = tempDirString.split(",|" + File.pathSeparator); + + this.nonSpanningWrapper = new NonSpanningWrapper(); + this.spanningWrapper = new SpanningWrapper(directories); + } + + @Override + public void setNextMemorySegment(MemorySegment segment, int numBytes) throws IOException { + // check if some spanning record deserialization is pending + if (this.spanningWrapper.getNumGatheredBytes() > 0) { + this.spanningWrapper.addNextChunkFromMemorySegment(segment, numBytes); + } + else { + this.nonSpanningWrapper.initializeFromMemorySegment(segment, 0, numBytes); + } + } + + @Override + public DeserializationResult getNextRecord(T target) throws IOException { + // always check the non-spanning wrapper first. + // this should be the majority of the cases for small records + // for large records, this portion of the work is very small in comparison anyways + + int nonSpanningRemaining = this.nonSpanningWrapper.remaining(); + + // check if we can get a full length; + if (nonSpanningRemaining >= 4) { + int len = this.nonSpanningWrapper.readInt(); + + if (len <= nonSpanningRemaining - 4) { + // we can get a full record from here + target.read(this.nonSpanningWrapper); + + return (this.nonSpanningWrapper.remaining() == 0) ? + DeserializationResult.LAST_RECORD_FROM_BUFFER : + DeserializationResult.INTERMEDIATE_RECORD_FROM_BUFFER; + } else { + // we got the length, but we need the rest from the spanning deserializer + // and need to wait for more buffers + this.spanningWrapper.initializeWithPartialRecord(this.nonSpanningWrapper, len); + this.nonSpanningWrapper.clear(); + return DeserializationResult.PARTIAL_RECORD; + } + } else if (nonSpanningRemaining > 0) { + // we have an incomplete length + // add our part of the length to the length buffer + this.spanningWrapper.initializeWithPartialLength(this.nonSpanningWrapper); + this.nonSpanningWrapper.clear(); + return DeserializationResult.PARTIAL_RECORD; + } + + // spanning record case + if (this.spanningWrapper.hasFullRecord()) { + // get the full record + target.read(this.spanningWrapper.getInputView()); + + // move the remainder to the non-spanning wrapper + // this does not copy it, only sets the memory segment + this.spanningWrapper.moveRemainderToNonSpanningDeserializer(this.nonSpanningWrapper); + this.spanningWrapper.clear(); + + return (this.nonSpanningWrapper.remaining() == 0) ? + DeserializationResult.LAST_RECORD_FROM_BUFFER : + DeserializationResult.INTERMEDIATE_RECORD_FROM_BUFFER; + } else { + return DeserializationResult.PARTIAL_RECORD; + } + } + + @Override + public void clear() { + this.nonSpanningWrapper.clear(); + this.spanningWrapper.clear(); + } + + @Override + public boolean hasUnfinishedData() { + return this.nonSpanningWrapper.remaining() > 0 || this.spanningWrapper.getNumGatheredBytes() > 0; + } + + // ----------------------------------------------------------------------------------------------------------------- + + private static final class NonSpanningWrapper implements DataInputView { + + private MemorySegment segment; + + private int limit; + + private int position; + + private byte[] utfByteBuffer; // reusable byte buffer for utf-8 decoding + private char[] utfCharBuffer; // reusable char buffer for utf-8 decoding + + int remaining() { + return this.limit - this.position; + } + + void clear() { + this.segment = null; + this.limit = 0; + this.position = 0; + } + + void initializeFromMemorySegment(MemorySegment seg, int position, int leftOverLimit) { + this.segment = seg; + this.position = position; + this.limit = leftOverLimit; + } + + // ------------------------------------------------------------------------------------------------------------- + // DataInput specific methods + // ------------------------------------------------------------------------------------------------------------- + + @Override + public final void readFully(byte[] b) throws IOException { + readFully(b, 0, b.length); + } + + @Override + public final void readFully(byte[] b, int off, int len) throws IOException { + if (off < 0 || len < 0 || off + len > b.length) { + throw new IndexOutOfBoundsException(); + } + + this.segment.get(this.position, b, off, len); + this.position += len; + } + + @Override + public final boolean readBoolean() throws IOException { + return readByte() == 1; + } + + @Override + public final byte readByte() throws IOException { + return this.segment.get(this.position++); + } + + @Override + public final int readUnsignedByte() throws IOException { + return readByte() & 0xff; + } + + @Override + public final short readShort() throws IOException { + final short v = this.segment.getShort(this.position); + this.position += 2; + return v; + } + + @Override + public final int readUnsignedShort() throws IOException { + final int v = this.segment.getShort(this.position) & 0xffff; + this.position += 2; + return v; + } + + @Override + public final char readChar() throws IOException { + final char v = this.segment.getChar(this.position); + this.position += 2; + return v; + } + + @Override + public final int readInt() throws IOException { + final int v = this.segment.getIntBigEndian(this.position); + this.position += 4; + return v; + } + + @Override + public final long readLong() throws IOException { + final long v = this.segment.getLongBigEndian(this.position); + this.position += 8; + return v; + } + + @Override + public final float readFloat() throws IOException { + return Float.intBitsToFloat(readInt()); + } + + @Override + public final double readDouble() throws IOException { + return Double.longBitsToDouble(readLong()); + } + + @Override + public final String readLine() throws IOException { + final StringBuilder bld = new StringBuilder(32); + + try { + int b; + while ((b = readUnsignedByte()) != '\n') { + if (b != '\r') { + bld.append((char) b); + } + } + } + catch (EOFException eofex) {} + + if (bld.length() == 0) { + return null; + } + + // trim a trailing carriage return + int len = bld.length(); + if (len > 0 && bld.charAt(len - 1) == '\r') { + bld.setLength(len - 1); + } + return bld.toString(); + } + + @Override + public final String readUTF() throws IOException { + final int utflen = readUnsignedShort(); + + final byte[] bytearr; + final char[] chararr; + + if (this.utfByteBuffer == null || this.utfByteBuffer.length < utflen) { + bytearr = new byte[utflen]; + this.utfByteBuffer = bytearr; + } else { + bytearr = this.utfByteBuffer; + } + if (this.utfCharBuffer == null || this.utfCharBuffer.length < utflen) { + chararr = new char[utflen]; + this.utfCharBuffer = chararr; + } else { + chararr = this.utfCharBuffer; + } + + int c, char2, char3; + int count = 0; + int chararr_count = 0; + + readFully(bytearr, 0, utflen); + + while (count < utflen) { + c = (int) bytearr[count] & 0xff; + if (c > 127) { + break; + } + count++; + chararr[chararr_count++] = (char) c; + } + + while (count < utflen) { + c = (int) bytearr[count] & 0xff; + switch (c >> 4) { + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + count++; + chararr[chararr_count++] = (char) c; + break; + case 12: + case 13: + count += 2; + if (count > utflen) { + throw new UTFDataFormatException("malformed input: partial character at end"); + } + char2 = (int) bytearr[count - 1]; + if ((char2 & 0xC0) != 0x80) { + throw new UTFDataFormatException("malformed input around byte " + count); + } + chararr[chararr_count++] = (char) (((c & 0x1F) << 6) | (char2 & 0x3F)); + break; + case 14: + count += 3; + if (count > utflen) { + throw new UTFDataFormatException("malformed input: partial character at end"); + } + char2 = (int) bytearr[count - 2]; + char3 = (int) bytearr[count - 1]; + if (((char2 & 0xC0) != 0x80) || ((char3 & 0xC0) != 0x80)) { + throw new UTFDataFormatException("malformed input around byte " + (count - 1)); + } + chararr[chararr_count++] = (char) (((c & 0x0F) << 12) | ((char2 & 0x3F) << 6) | ((char3 & 0x3F) << 0)); + break; + default: + throw new UTFDataFormatException("malformed input around byte " + count); + } + } + // The number of chars produced may be less than utflen + return new String(chararr, 0, chararr_count); + } + + @Override + public final int skipBytes(int n) throws IOException { + if (n < 0) { + throw new IllegalArgumentException(); + } + + int toSkip = Math.min(n, remaining()); + this.position += toSkip; + return toSkip; + } + + @Override + public void skipBytesToRead(int numBytes) throws IOException { + int skippedBytes = skipBytes(numBytes); + + if(skippedBytes < numBytes){ + throw new EOFException("Could not skip " + numBytes + " bytes."); + } + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if(b == null){ + throw new NullPointerException("Byte array b cannot be null."); + } + + if(off < 0){ + throw new IllegalArgumentException("The offset off cannot be negative."); + } + + if(len < 0){ + throw new IllegalArgumentException("The length len cannot be negative."); + } + + int toRead = Math.min(len, remaining()); + this.segment.get(this.position,b,off, toRead); + this.position += toRead; + + return toRead; + } + + @Override + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } + } + + // ----------------------------------------------------------------------------------------------------------------- + + private static final class SpanningWrapper { + + private final byte[] initialBuffer = new byte[1024]; + + private final String[] tempDirs; + + private final Random rnd = new Random(); + + private final DataInputDeserializer serializationReadBuffer; + + private final ByteBuffer lengthBuffer; + + private FileChannel spillingChannel; + + private byte[] buffer; + + private int recordLength; + + private int accumulatedRecordBytes; + + private MemorySegment leftOverData; + + private int leftOverStart; + + private int leftOverLimit; + + private File spillFile; + + private InputViewDataInputStreamWrapper spillFileReader; + + public SpanningWrapper(String[] tempDirs) { + this.tempDirs = tempDirs; + + this.lengthBuffer = ByteBuffer.allocate(4); + this.lengthBuffer.order(ByteOrder.BIG_ENDIAN); + + this.recordLength = -1; + + this.serializationReadBuffer = new DataInputDeserializer(); + this.buffer = initialBuffer; + } + + private void initializeWithPartialRecord(NonSpanningWrapper partial, int nextRecordLength) throws IOException { + // set the length and copy what is available to the buffer + this.recordLength = nextRecordLength; + + final int numBytesChunk = partial.remaining(); + + if (nextRecordLength > THRESHOLD_FOR_SPILLING) { + // create a spilling channel and put the data there + this.spillingChannel = createSpillingChannel(); + + ByteBuffer toWrite = partial.segment.wrap(partial.position, numBytesChunk); + this.spillingChannel.write(toWrite); + } + else { + // collect in memory + ensureBufferCapacity(numBytesChunk); + partial.segment.get(partial.position, buffer, 0, numBytesChunk); + } + + this.accumulatedRecordBytes = numBytesChunk; + } + + private void initializeWithPartialLength(NonSpanningWrapper partial) throws IOException { + // copy what we have to the length buffer + partial.segment.get(partial.position, this.lengthBuffer, partial.remaining()); + } + + private void addNextChunkFromMemorySegment(MemorySegment segment, int numBytesInSegment) throws IOException { + int segmentPosition = 0; + + // check where to go. if we have a partial length, we need to complete it first + if (this.lengthBuffer.position() > 0) { + int toPut = Math.min(this.lengthBuffer.remaining(), numBytesInSegment); + segment.get(0, this.lengthBuffer, toPut); + + // did we complete the length? + if (this.lengthBuffer.hasRemaining()) { + return; + } else { + this.recordLength = this.lengthBuffer.getInt(0); + this.lengthBuffer.clear(); + segmentPosition = toPut; + + if (this.recordLength > THRESHOLD_FOR_SPILLING) { + this.spillingChannel = createSpillingChannel(); + } + } + } + + // copy as much as we need or can for this next spanning record + int needed = this.recordLength - this.accumulatedRecordBytes; + int available = numBytesInSegment - segmentPosition; + int toCopy = Math.min(needed, available); + + if (spillingChannel != null) { + // spill to file + ByteBuffer toWrite = segment.wrap(segmentPosition, toCopy); + this.spillingChannel.write(toWrite); + } + else { + ensureBufferCapacity(accumulatedRecordBytes + toCopy); + segment.get(segmentPosition, buffer, this.accumulatedRecordBytes, toCopy); + } + + this.accumulatedRecordBytes += toCopy; + + if (toCopy < available) { + // there is more data in the segment + this.leftOverData = segment; + this.leftOverStart = segmentPosition + toCopy; + this.leftOverLimit = numBytesInSegment; + } + + if (accumulatedRecordBytes == recordLength) { + // we have the full record + if (spillingChannel == null) { + this.serializationReadBuffer.setBuffer(buffer, 0, recordLength); + } + else { + spillingChannel.close(); + + DataInputStream inStream = new DataInputStream(new BufferedInputStream(new FileInputStream(spillFile), 2 * 1024 * 1024)); + this.spillFileReader = new InputViewDataInputStreamWrapper(inStream); + } + } + } + + private void moveRemainderToNonSpanningDeserializer(NonSpanningWrapper deserializer) { + deserializer.clear(); + + if (leftOverData != null) { + deserializer.initializeFromMemorySegment(leftOverData, leftOverStart, leftOverLimit); + } + } + + private boolean hasFullRecord() { + return this.recordLength >= 0 && this.accumulatedRecordBytes >= this.recordLength; + } + + private int getNumGatheredBytes() { + return this.accumulatedRecordBytes + (this.recordLength >= 0 ? 4 : lengthBuffer.position()); + } + + public void clear() { + this.buffer = initialBuffer; + this.serializationReadBuffer.releaseArrays(); + + this.recordLength = -1; + this.lengthBuffer.clear(); + this.leftOverData = null; + this.accumulatedRecordBytes = 0; + + if (spillingChannel != null) { + try { + spillingChannel.close(); + } + catch (Throwable t) { + // ignore + } + spillingChannel = null; + } + if (spillFileReader != null) { + try { + spillFileReader.close(); + } + catch (Throwable t) { + // ignore + } + spillFileReader = null; + } + if (spillFile != null) { + spillFile.delete(); + spillFile = null; + } + } + + public DataInputView getInputView() { + if (spillFileReader == null) { + return serializationReadBuffer; + } + else { + return spillFileReader; + } + } + + private void ensureBufferCapacity(int minLength) { + if (buffer.length < minLength) { + byte[] newBuffer = new byte[Math.max(minLength, buffer.length * 2)]; + System.arraycopy(buffer, 0, newBuffer, 0, accumulatedRecordBytes); + buffer = newBuffer; + } + } + + @SuppressWarnings("resource") + private FileChannel createSpillingChannel() throws IOException { + if (spillFile != null) { + throw new IllegalStateException("Spilling file already exists."); + } + + String directory = tempDirs[rnd.nextInt(tempDirs.length)]; + spillFile = new File(directory, randomString(rnd) + ".inputchannel"); + + return new RandomAccessFile(spillFile, "rw").getChannel(); + } + + private static String randomString(Random random) { + final byte[] bytes = new byte[20]; + random.nextBytes(bytes); + return StringUtils.byteToHexString(bytes); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java index 793a2e3..35b6f0d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java @@ -71,7 +71,7 @@ public class DataInputDeserializer implements DataInputView { throw new NullPointerException(); } - if (start < 0 || len < 0 || start + len >= buffer.length) { + if (start < 0 || len < 0 || start + len > buffer.length) { throw new IllegalArgumentException(); } http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java index 1e59463..6ceb05a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java @@ -16,20 +16,21 @@ * limitations under the License. */ +package org.apache.flink.runtime.io.network.serialization; -package org.apache.flink.runtime.io.network.api.serialization; - -import org.apache.flink.runtime.io.network.buffer.BufferRecycler; import org.junit.Assert; import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult; -import org.apache.flink.runtime.io.network.api.serialization.types.SerializationTestType; -import org.apache.flink.runtime.io.network.api.serialization.types.SerializationTestTypeFactory; -import org.apache.flink.runtime.io.network.api.serialization.types.Util; +import org.apache.flink.runtime.io.network.Buffer; +import org.apache.flink.runtime.io.network.serialization.AdaptiveSpanningRecordDeserializer; +import org.apache.flink.runtime.io.network.serialization.RecordDeserializer; +import org.apache.flink.runtime.io.network.serialization.RecordSerializer; +import org.apache.flink.runtime.io.network.serialization.SpanningRecordSerializer; +import org.apache.flink.runtime.io.network.serialization.RecordDeserializer.DeserializationResult; +import org.apache.flink.runtime.io.network.serialization.types.SerializationTestType; +import org.apache.flink.runtime.io.network.serialization.types.SerializationTestTypeFactory; +import org.apache.flink.runtime.io.network.serialization.types.Util; import org.junit.Test; -import org.mockito.Mockito; import java.util.ArrayDeque; @@ -41,8 +42,10 @@ public class SpanningRecordSerializationTest { final int NUM_VALUES = 10; try { - test(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); - } catch (Exception e) { + testNonSpillingDeserializer(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); + testSpillingDeserializer(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); + } + catch (Exception e) { e.printStackTrace(); Assert.fail("Test encountered an unexpected exception."); } @@ -54,8 +57,10 @@ public class SpanningRecordSerializationTest { final int NUM_VALUES = 64; try { - test(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); - } catch (Exception e) { + testNonSpillingDeserializer(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); + testSpillingDeserializer(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); + } + catch (Exception e) { e.printStackTrace(); Assert.fail("Test encountered an unexpected exception."); } @@ -67,8 +72,10 @@ public class SpanningRecordSerializationTest { final int NUM_VALUES = 248; try { - test(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); - } catch (Exception e) { + testNonSpillingDeserializer(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); + testSpillingDeserializer(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); + } + catch (Exception e) { e.printStackTrace(); Assert.fail("Test encountered an unexpected exception."); } @@ -80,8 +87,10 @@ public class SpanningRecordSerializationTest { final int NUM_VALUES = 10000; try { - test(Util.randomRecords(NUM_VALUES), SEGMENT_SIZE); - } catch (Exception e) { + testNonSpillingDeserializer(Util.randomRecords(NUM_VALUES), SEGMENT_SIZE); + testSpillingDeserializer(Util.randomRecords(NUM_VALUES), SEGMENT_SIZE); + } + catch (Exception e) { e.printStackTrace(); Assert.fail("Test encountered an unexpected exception."); } @@ -89,6 +98,20 @@ public class SpanningRecordSerializationTest { // ----------------------------------------------------------------------------------------------------------------- + private void testNonSpillingDeserializer(Util.MockRecords records, int segmentSize) throws Exception { + RecordSerializer<SerializationTestType> serializer = new SpanningRecordSerializer<SerializationTestType>(); + RecordDeserializer<SerializationTestType> deserializer = new AdaptiveSpanningRecordDeserializer<SerializationTestType>(); + + test(records, segmentSize, serializer, deserializer); + } + + private void testSpillingDeserializer(Util.MockRecords records, int segmentSize) throws Exception { + RecordSerializer<SerializationTestType> serializer = new SpanningRecordSerializer<SerializationTestType>(); + RecordDeserializer<SerializationTestType> deserializer = new SpillingAdaptiveSpanningRecordDeserializer<SerializationTestType>(); + + test(records, segmentSize, serializer, deserializer); + } + /** * Iterates over the provided records and tests whether {@link SpanningRecordSerializer} and {@link AdaptiveSpanningRecordDeserializer} * interact as expected. @@ -98,13 +121,14 @@ public class SpanningRecordSerializationTest { * @param records records to test * @param segmentSize size for the {@link MemorySegment} */ - private void test (Util.MockRecords records, int segmentSize) throws Exception { + private void test(Util.MockRecords records, int segmentSize, + RecordSerializer<SerializationTestType> serializer, + RecordDeserializer<SerializationTestType> deserializer) + throws Exception + { final int SERIALIZATION_OVERHEAD = 4; // length encoding - final RecordSerializer<SerializationTestType> serializer = new SpanningRecordSerializer<SerializationTestType>(); - final RecordDeserializer<SerializationTestType> deserializer = new AdaptiveSpanningRecordDeserializer<SerializationTestType>(); - - final Buffer buffer = new Buffer(new MemorySegment(new byte[segmentSize]), Mockito.mock(BufferRecycler.class)); + final Buffer buffer = new Buffer(new MemorySegment(new byte[segmentSize]), segmentSize, null); final ArrayDeque<SerializationTestType> serializedRecords = new ArrayDeque<SerializationTestType>(); http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java index 568a599..920d683 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java @@ -16,22 +16,19 @@ * limitations under the License. */ +package org.apache.flink.runtime.io.network.serialization; -package org.apache.flink.runtime.io.network.api.serialization; - -import org.apache.flink.runtime.io.network.buffer.BufferRecycler; import org.junit.Assert; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer.SerializationResult; -import org.apache.flink.runtime.io.network.api.serialization.types.SerializationTestType; -import org.apache.flink.runtime.io.network.api.serialization.types.SerializationTestTypeFactory; -import org.apache.flink.runtime.io.network.api.serialization.types.Util; +import org.apache.flink.runtime.io.network.Buffer; +import org.apache.flink.runtime.io.network.serialization.RecordSerializer.SerializationResult; +import org.apache.flink.runtime.io.network.serialization.types.SerializationTestType; +import org.apache.flink.runtime.io.network.serialization.types.SerializationTestTypeFactory; +import org.apache.flink.runtime.io.network.serialization.types.Util; import org.junit.Test; -import org.mockito.Mockito; import java.io.IOException; import java.util.Random; @@ -43,7 +40,7 @@ public class SpanningRecordSerializerTest { final int SEGMENT_SIZE = 16; final SpanningRecordSerializer<SerializationTestType> serializer = new SpanningRecordSerializer<SerializationTestType>(); - final Buffer buffer = new Buffer(new MemorySegment(new byte[SEGMENT_SIZE]), Mockito.mock(BufferRecycler.class)); + final Buffer buffer = new Buffer(new MemorySegment(new byte[SEGMENT_SIZE]), SEGMENT_SIZE, null); final SerializationTestType randomIntRecord = Util.randomRecord(SerializationTestTypeFactory.INT); Assert.assertFalse(serializer.hasData()); @@ -65,10 +62,11 @@ public class SpanningRecordSerializerTest { serializer.addRecord(randomIntRecord); Assert.assertTrue(serializer.hasData()); - } catch (IOException e) { + } + catch (Exception e) { e.printStackTrace(); + Assert.fail(e.getMessage()); } - } @Test @@ -76,7 +74,7 @@ public class SpanningRecordSerializerTest { final int SEGMENT_SIZE = 11; final SpanningRecordSerializer<SerializationTestType> serializer = new SpanningRecordSerializer<SerializationTestType>(); - final Buffer buffer = new Buffer(new MemorySegment(new byte[SEGMENT_SIZE]), Mockito.mock(BufferRecycler.class)); + final Buffer buffer = new Buffer(new MemorySegment(new byte[SEGMENT_SIZE]), SEGMENT_SIZE, null); try { Assert.assertEquals(SerializationResult.FULL_RECORD, serializer.setNextBuffer(buffer)); @@ -97,12 +95,10 @@ public class SpanningRecordSerializerTest { } @Override - public void write(DataOutputView out) throws IOException { - } + public void write(DataOutputView out) {} @Override - public void read(DataInputView in) throws IOException { - } + public void read(DataInputView in) {} @Override public int hashCode() { @@ -126,8 +122,10 @@ public class SpanningRecordSerializerTest { result = serializer.setNextBuffer(buffer); Assert.assertEquals(SerializationResult.FULL_RECORD, result); - } catch (IOException e) { + } + catch (Exception e) { e.printStackTrace(); + Assert.fail(e.getMessage()); } } @@ -138,7 +136,8 @@ public class SpanningRecordSerializerTest { try { test(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); - } catch (Exception e) { + } + catch (Exception e) { e.printStackTrace(); Assert.fail("Test encountered an unexpected exception."); } @@ -151,7 +150,8 @@ public class SpanningRecordSerializerTest { try { test(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); - } catch (Exception e) { + } + catch (Exception e) { e.printStackTrace(); Assert.fail("Test encountered an unexpected exception."); } @@ -164,7 +164,8 @@ public class SpanningRecordSerializerTest { try { test(Util.randomRecords(NUM_VALUES, SerializationTestTypeFactory.INT), SEGMENT_SIZE); - } catch (Exception e) { + } + catch (Exception e) { e.printStackTrace(); Assert.fail("Test encountered an unexpected exception."); } @@ -177,7 +178,8 @@ public class SpanningRecordSerializerTest { try { test(Util.randomRecords(NUM_VALUES), SEGMENT_SIZE); - } catch (Exception e) { + } + catch (Exception e) { e.printStackTrace(); Assert.fail("Test encountered an unexpected exception."); } http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/serialization/LargeRecordsTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/serialization/LargeRecordsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/serialization/LargeRecordsTest.java index 6000fee..6c1fd64 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/serialization/LargeRecordsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/serialization/LargeRecordsTest.java @@ -129,4 +129,101 @@ public class LargeRecordsTest { fail(e.getMessage()); } } + + @Test + public void testHandleMixedLargeRecordsSpillingAdaptiveSerializer() { + try { + final int NUM_RECORDS = 99; + final int SEGMENT_SIZE = 32 * 1024; + + final RecordSerializer<SerializationTestType> serializer = new SpanningRecordSerializer<SerializationTestType>(); + final RecordDeserializer<SerializationTestType> deserializer = new SpillingAdaptiveSpanningRecordDeserializer<SerializationTestType>(); + + final Buffer buffer = new Buffer(new MemorySegment(new byte[SEGMENT_SIZE]), SEGMENT_SIZE, null); + + List<SerializationTestType> originalRecords = new ArrayList<SerializationTestType>(); + List<SerializationTestType> deserializedRecords = new ArrayList<SerializationTestType>(); + + LargeObjectType genLarge = new LargeObjectType(); + + Random rnd = new Random(); + + for (int i = 0; i < NUM_RECORDS; i++) { + if (i % 2 == 0) { + originalRecords.add(new IntType(42)); + deserializedRecords.add(new IntType()); + } else { + originalRecords.add(genLarge.getRandom(rnd)); + deserializedRecords.add(new LargeObjectType()); + } + } + + // ------------------------------------------------------------------------------------------------------------- + + serializer.setNextBuffer(buffer); + + int numRecordsDeserialized = 0; + + for (SerializationTestType record : originalRecords) { + + // serialize record + if (serializer.addRecord(record).isFullBuffer()) { + + // buffer is full => move to deserializer + deserializer.setNextMemorySegment(serializer.getCurrentBuffer().getMemorySegment(), SEGMENT_SIZE); + + // deserialize records, as many complete as there are + while (numRecordsDeserialized < deserializedRecords.size()) { + SerializationTestType next = deserializedRecords.get(numRecordsDeserialized); + + if (deserializer.getNextRecord(next).isFullRecord()) { + assertEquals(originalRecords.get(numRecordsDeserialized), next); + numRecordsDeserialized++; + } else { + break; + } + } + + // move buffers as long as necessary (for long records) + while (serializer.setNextBuffer(buffer).isFullBuffer()) { + deserializer.setNextMemorySegment(serializer.getCurrentBuffer().getMemorySegment(), SEGMENT_SIZE); + } + + // deserialize records, as many as there are in the last buffer + while (numRecordsDeserialized < deserializedRecords.size()) { + SerializationTestType next = deserializedRecords.get(numRecordsDeserialized); + + if (deserializer.getNextRecord(next).isFullRecord()) { + assertEquals(originalRecords.get(numRecordsDeserialized), next); + numRecordsDeserialized++; + } else { + break; + } + } + } + } + + // move the last (incomplete buffer) + Buffer last = serializer.getCurrentBuffer(); + deserializer.setNextMemorySegment(last.getMemorySegment(), last.size()); + serializer.clear(); + + // deserialize records, as many as there are in the last buffer + while (numRecordsDeserialized < deserializedRecords.size()) { + SerializationTestType next = deserializedRecords.get(numRecordsDeserialized); + + assertTrue(deserializer.getNextRecord(next).isFullRecord()); + assertEquals(originalRecords.get(numRecordsDeserialized), next); + numRecordsDeserialized++; + } + + // might be that the last big records has not yet been fully moved, and a small one is missing + assertFalse(serializer.hasData()); + assertFalse(deserializer.hasUnfinishedData()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } }
