[runtime] Add spillable input deserializer

Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/be8e1f18
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/be8e1f18
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/be8e1f18

Branch: refs/heads/master
Commit: be8e1f181825f0f84a276dc9ec335c364a4b3e38
Parents: bc69b0b
Author: Stephan Ewen <[email protected]>
Authored: Sun Dec 14 15:39:28 2014 +0100
Committer: Ufuk Celebi <[email protected]>
Committed: Wed Jan 21 12:01:35 2015 +0100

----------------------------------------------------------------------
 .../memory/InputViewDataInputStreamWrapper.java |  11 +-
 .../serialization/SpanningRecordSerializer.java |   1 +
 ...llingAdaptiveSpanningRecordDeserializer.java | 615 +++++++++++++++++++
 .../runtime/util/DataInputDeserializer.java     |   2 +-
 .../SpanningRecordSerializationTest.java        |  68 +-
 .../SpanningRecordSerializerTest.java           |  46 +-
 .../network/serialization/LargeRecordsTest.java |  97 +++
 7 files changed, 794 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-core/src/main/java/org/apache/flink/core/memory/InputViewDataInputStreamWrapper.java
----------------------------------------------------------------------
diff --git 
a/flink-core/src/main/java/org/apache/flink/core/memory/InputViewDataInputStreamWrapper.java
 
b/flink-core/src/main/java/org/apache/flink/core/memory/InputViewDataInputStreamWrapper.java
index 289d66e..7de1d71 100644
--- 
a/flink-core/src/main/java/org/apache/flink/core/memory/InputViewDataInputStreamWrapper.java
+++ 
b/flink-core/src/main/java/org/apache/flink/core/memory/InputViewDataInputStreamWrapper.java
@@ -18,16 +18,25 @@
 
 package org.apache.flink.core.memory;
 
+import java.io.Closeable;
 import java.io.DataInputStream;
 import java.io.EOFException;
 import java.io.IOException;
 
-public class InputViewDataInputStreamWrapper implements DataInputView {
+public class InputViewDataInputStreamWrapper implements DataInputView, 
Closeable {
+       
        private final DataInputStream in;
 
        public InputViewDataInputStreamWrapper(DataInputStream in){
                this.in = in;
        }
+       
+       @Override
+       public void close() throws IOException {
+               in.close();
+       }
+       
+       // 
--------------------------------------------------------------------------------------------
 
        @Override
        public void skipBytesToRead(int numBytes) throws IOException {

http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java
index ab6fe75..4446dbc 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java
@@ -104,6 +104,7 @@ public class SpanningRecordSerializer<T extends 
IOReadableWritable> implements R
                
                // make sure we don't hold onto the large buffers for too long
                if (result.isFullRecord()) {
+                       this.serializationBuffer.clear();
                        this.serializationBuffer.pruneBuffer();
                        this.dataBuffer = 
this.serializationBuffer.wrapAsByteBuffer();
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/serialization/SpillingAdaptiveSpanningRecordDeserializer.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/serialization/SpillingAdaptiveSpanningRecordDeserializer.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/serialization/SpillingAdaptiveSpanningRecordDeserializer.java
new file mode 100644
index 0000000..371ba0a
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/serialization/SpillingAdaptiveSpanningRecordDeserializer.java
@@ -0,0 +1,615 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.serialization;
+
+import java.io.BufferedInputStream;
+import java.io.DataInputStream;
+import java.io.EOFException;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.RandomAccessFile;
+import java.io.UTFDataFormatException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
+import java.util.Random;
+
+import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.configuration.GlobalConfiguration;
+import org.apache.flink.core.io.IOReadableWritable;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.InputViewDataInputStreamWrapper;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.util.StringUtils;
+
+/**
+ * @param <T> The type of the record to be deserialized.
+ */
+public class SpillingAdaptiveSpanningRecordDeserializer<T extends 
IOReadableWritable> implements RecordDeserializer<T> {
+       
+       private static final int THRESHOLD_FOR_SPILLING = 5 * 1024 * 1024; // 5 
MiBytes
+       
+       private final NonSpanningWrapper nonSpanningWrapper;
+       
+       private final SpanningWrapper spanningWrapper;
+
+       public SpillingAdaptiveSpanningRecordDeserializer() {
+               
+               String tempDirString = GlobalConfiguration.getString(
+                               ConfigConstants.TASK_MANAGER_TMP_DIR_KEY,
+                               ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH);
+               String[] directories = tempDirString.split(",|" + 
File.pathSeparator);
+               
+               this.nonSpanningWrapper = new NonSpanningWrapper();
+               this.spanningWrapper = new SpanningWrapper(directories);
+       }
+       
+       @Override
+       public void setNextMemorySegment(MemorySegment segment, int numBytes) 
throws IOException {
+               // check if some spanning record deserialization is pending
+               if (this.spanningWrapper.getNumGatheredBytes() > 0) {
+                       
this.spanningWrapper.addNextChunkFromMemorySegment(segment, numBytes);
+               }
+               else {
+                       
this.nonSpanningWrapper.initializeFromMemorySegment(segment, 0, numBytes);
+               }
+       }
+       
+       @Override
+       public DeserializationResult getNextRecord(T target) throws IOException 
{
+               // always check the non-spanning wrapper first.
+               // this should be the majority of the cases for small records
+               // for large records, this portion of the work is very small in 
comparison anyways
+               
+               int nonSpanningRemaining = this.nonSpanningWrapper.remaining();
+               
+               // check if we can get a full length;
+               if (nonSpanningRemaining >= 4) {
+                       int len = this.nonSpanningWrapper.readInt();
+
+                       if (len <= nonSpanningRemaining - 4) {
+                               // we can get a full record from here
+                               target.read(this.nonSpanningWrapper);
+                               
+                               return (this.nonSpanningWrapper.remaining() == 
0) ?
+                                       
DeserializationResult.LAST_RECORD_FROM_BUFFER :
+                                       
DeserializationResult.INTERMEDIATE_RECORD_FROM_BUFFER;
+                       } else {
+                               // we got the length, but we need the rest from 
the spanning deserializer
+                               // and need to wait for more buffers
+                               
this.spanningWrapper.initializeWithPartialRecord(this.nonSpanningWrapper, len);
+                               this.nonSpanningWrapper.clear();
+                               return DeserializationResult.PARTIAL_RECORD;
+                       }
+               } else if (nonSpanningRemaining > 0) {
+                       // we have an incomplete length
+                       // add our part of the length to the length buffer
+                       
this.spanningWrapper.initializeWithPartialLength(this.nonSpanningWrapper);
+                       this.nonSpanningWrapper.clear();
+                       return DeserializationResult.PARTIAL_RECORD;
+               }
+               
+               // spanning record case
+               if (this.spanningWrapper.hasFullRecord()) {
+                       // get the full record
+                       target.read(this.spanningWrapper.getInputView());
+                       
+                       // move the remainder to the non-spanning wrapper
+                       // this does not copy it, only sets the memory segment
+                       
this.spanningWrapper.moveRemainderToNonSpanningDeserializer(this.nonSpanningWrapper);
+                       this.spanningWrapper.clear();
+                       
+                       return (this.nonSpanningWrapper.remaining() == 0) ?
+                               DeserializationResult.LAST_RECORD_FROM_BUFFER :
+                               
DeserializationResult.INTERMEDIATE_RECORD_FROM_BUFFER;
+               } else {
+                       return DeserializationResult.PARTIAL_RECORD;
+               }
+       }
+
+       @Override
+       public void clear() {
+               this.nonSpanningWrapper.clear();
+               this.spanningWrapper.clear();
+       }
+
+       @Override
+       public boolean hasUnfinishedData() {
+               return this.nonSpanningWrapper.remaining() > 0 || 
this.spanningWrapper.getNumGatheredBytes() > 0;
+       }
+
+       // 
-----------------------------------------------------------------------------------------------------------------
+       
+       private static final class NonSpanningWrapper implements DataInputView {
+               
+               private MemorySegment segment;
+               
+               private int limit;
+               
+               private int position;
+               
+               private byte[] utfByteBuffer; // reusable byte buffer for utf-8 
decoding
+               private char[] utfCharBuffer; // reusable char buffer for utf-8 
decoding
+               
+               int remaining() {
+                       return this.limit - this.position;
+               }
+               
+               void clear() {
+                       this.segment = null;
+                       this.limit = 0;
+                       this.position = 0;
+               }
+               
+               void initializeFromMemorySegment(MemorySegment seg, int 
position, int leftOverLimit) {
+                       this.segment = seg;
+                       this.position = position;
+                       this.limit = leftOverLimit;
+               }
+               
+               // 
-------------------------------------------------------------------------------------------------------------
+               //                                       DataInput specific 
methods
+               // 
-------------------------------------------------------------------------------------------------------------
+               
+               @Override
+               public final void readFully(byte[] b) throws IOException {
+                       readFully(b, 0, b.length);
+               }
+
+               @Override
+               public final void readFully(byte[] b, int off, int len) throws 
IOException {
+                       if (off < 0 || len < 0 || off + len > b.length) {
+                               throw new IndexOutOfBoundsException();
+                       }
+                       
+                       this.segment.get(this.position, b, off, len);
+                       this.position += len;
+               }
+
+               @Override
+               public final boolean readBoolean() throws IOException {
+                       return readByte() == 1;
+               }
+
+               @Override
+               public final byte readByte() throws IOException {
+                       return this.segment.get(this.position++);
+               }
+
+               @Override
+               public final int readUnsignedByte() throws IOException {
+                       return readByte() & 0xff;
+               }
+
+               @Override
+               public final short readShort() throws IOException {
+                       final short v = this.segment.getShort(this.position);
+                       this.position += 2;
+                       return v;
+               }
+
+               @Override
+               public final int readUnsignedShort() throws IOException {
+                       final int v = this.segment.getShort(this.position) & 
0xffff;
+                       this.position += 2;
+                       return v;
+               }
+
+               @Override
+               public final char readChar() throws IOException  {
+                       final char v = this.segment.getChar(this.position);
+                       this.position += 2;
+                       return v;
+               }
+
+               @Override
+               public final int readInt() throws IOException {
+                       final int v = 
this.segment.getIntBigEndian(this.position);
+                       this.position += 4;
+                       return v;
+               }
+
+               @Override
+               public final long readLong() throws IOException {
+                       final long v = 
this.segment.getLongBigEndian(this.position);
+                       this.position += 8;
+                       return v;
+               }
+
+               @Override
+               public final float readFloat() throws IOException {
+                       return Float.intBitsToFloat(readInt());
+               }
+
+               @Override
+               public final double readDouble() throws IOException {
+                       return Double.longBitsToDouble(readLong());
+               }
+
+               @Override
+               public final String readLine() throws IOException {
+                       final StringBuilder bld = new StringBuilder(32);
+                       
+                       try {
+                               int b;
+                               while ((b = readUnsignedByte()) != '\n') {
+                                       if (b != '\r') {
+                                               bld.append((char) b);
+                                       }
+                               }
+                       }
+                       catch (EOFException eofex) {}
+
+                       if (bld.length() == 0) {
+                               return null;
+                       }
+                       
+                       // trim a trailing carriage return
+                       int len = bld.length();
+                       if (len > 0 && bld.charAt(len - 1) == '\r') {
+                               bld.setLength(len - 1);
+                       }
+                       return bld.toString();
+               }
+
+               @Override
+               public final String readUTF() throws IOException {
+                       final int utflen = readUnsignedShort();
+                       
+                       final byte[] bytearr;
+                       final char[] chararr;
+                       
+                       if (this.utfByteBuffer == null || 
this.utfByteBuffer.length < utflen) {
+                               bytearr = new byte[utflen];
+                               this.utfByteBuffer = bytearr;
+                       } else {
+                               bytearr = this.utfByteBuffer;
+                       }
+                       if (this.utfCharBuffer == null || 
this.utfCharBuffer.length < utflen) {
+                               chararr = new char[utflen];
+                               this.utfCharBuffer = chararr;
+                       } else {
+                               chararr = this.utfCharBuffer;
+                       }
+
+                       int c, char2, char3;
+                       int count = 0;
+                       int chararr_count = 0;
+
+                       readFully(bytearr, 0, utflen);
+
+                       while (count < utflen) {
+                               c = (int) bytearr[count] & 0xff;
+                               if (c > 127) {
+                                       break;
+                               }
+                               count++;
+                               chararr[chararr_count++] = (char) c;
+                       }
+
+                       while (count < utflen) {
+                               c = (int) bytearr[count] & 0xff;
+                               switch (c >> 4) {
+                               case 0:
+                               case 1:
+                               case 2:
+                               case 3:
+                               case 4:
+                               case 5:
+                               case 6:
+                               case 7:
+                                       count++;
+                                       chararr[chararr_count++] = (char) c;
+                                       break;
+                               case 12:
+                               case 13:
+                                       count += 2;
+                                       if (count > utflen) {
+                                               throw new 
UTFDataFormatException("malformed input: partial character at end");
+                                       }
+                                       char2 = (int) bytearr[count - 1];
+                                       if ((char2 & 0xC0) != 0x80) {
+                                               throw new 
UTFDataFormatException("malformed input around byte " + count);
+                                       }
+                                       chararr[chararr_count++] = (char) (((c 
& 0x1F) << 6) | (char2 & 0x3F));
+                                       break;
+                               case 14:
+                                       count += 3;
+                                       if (count > utflen) {
+                                               throw new 
UTFDataFormatException("malformed input: partial character at end");
+                                       }
+                                       char2 = (int) bytearr[count - 2];
+                                       char3 = (int) bytearr[count - 1];
+                                       if (((char2 & 0xC0) != 0x80) || ((char3 
& 0xC0) != 0x80)) {
+                                               throw new 
UTFDataFormatException("malformed input around byte " + (count - 1));
+                                       }
+                                       chararr[chararr_count++] = (char) (((c 
& 0x0F) << 12) | ((char2 & 0x3F) << 6) | ((char3 & 0x3F) << 0));
+                                       break;
+                               default:
+                                       throw new 
UTFDataFormatException("malformed input around byte " + count);
+                               }
+                       }
+                       // The number of chars produced may be less than utflen
+                       return new String(chararr, 0, chararr_count);
+               }
+               
+               @Override
+               public final int skipBytes(int n) throws IOException {
+                       if (n < 0) {
+                               throw new IllegalArgumentException();
+                       }
+                       
+                       int toSkip = Math.min(n, remaining());
+                       this.position += toSkip;
+                       return toSkip;
+               }
+
+               @Override
+               public void skipBytesToRead(int numBytes) throws IOException {
+                       int skippedBytes = skipBytes(numBytes);
+
+                       if(skippedBytes < numBytes){
+                               throw new EOFException("Could not skip " + 
numBytes + " bytes.");
+                       }
+               }
+
+               @Override
+               public int read(byte[] b, int off, int len) throws IOException {
+                       if(b == null){
+                               throw new NullPointerException("Byte array b 
cannot be null.");
+                       }
+
+                       if(off < 0){
+                               throw new IllegalArgumentException("The offset 
off cannot be negative.");
+                       }
+
+                       if(len < 0){
+                               throw new IllegalArgumentException("The length 
len cannot be negative.");
+                       }
+
+                       int toRead = Math.min(len, remaining());
+                       this.segment.get(this.position,b,off, toRead);
+                       this.position += toRead;
+
+                       return toRead;
+               }
+
+               @Override
+               public int read(byte[] b) throws IOException {
+                       return read(b, 0, b.length);
+               }
+       }
+
+       // 
-----------------------------------------------------------------------------------------------------------------
+       
+       private static final class SpanningWrapper {
+               
+               private final byte[] initialBuffer = new byte[1024];
+               
+               private final String[] tempDirs;
+               
+               private final Random rnd = new Random();
+
+               private final DataInputDeserializer serializationReadBuffer;
+
+               private final ByteBuffer lengthBuffer;
+               
+               private FileChannel spillingChannel;
+               
+               private byte[] buffer;
+
+               private int recordLength;
+               
+               private int accumulatedRecordBytes;
+
+               private MemorySegment leftOverData;
+
+               private int leftOverStart;
+
+               private int leftOverLimit;
+               
+               private File spillFile;
+               
+               private InputViewDataInputStreamWrapper spillFileReader;
+               
+               public SpanningWrapper(String[] tempDirs) {
+                       this.tempDirs = tempDirs;
+                       
+                       this.lengthBuffer = ByteBuffer.allocate(4);
+                       this.lengthBuffer.order(ByteOrder.BIG_ENDIAN);
+
+                       this.recordLength = -1;
+
+                       this.serializationReadBuffer = new 
DataInputDeserializer();
+                       this.buffer = initialBuffer;
+               }
+               
+               private void initializeWithPartialRecord(NonSpanningWrapper 
partial, int nextRecordLength) throws IOException {
+                       // set the length and copy what is available to the 
buffer
+                       this.recordLength = nextRecordLength;
+                       
+                       final int numBytesChunk = partial.remaining();
+                       
+                       if (nextRecordLength > THRESHOLD_FOR_SPILLING) {
+                               // create a spilling channel and put the data 
there
+                               this.spillingChannel = createSpillingChannel();
+                               
+                               ByteBuffer toWrite = 
partial.segment.wrap(partial.position, numBytesChunk);
+                               this.spillingChannel.write(toWrite);
+                       }
+                       else {
+                               // collect in memory
+                               ensureBufferCapacity(numBytesChunk);
+                               partial.segment.get(partial.position, buffer, 
0, numBytesChunk);
+                       }
+                       
+                       this.accumulatedRecordBytes = numBytesChunk;
+               }
+               
+               private void initializeWithPartialLength(NonSpanningWrapper 
partial) throws IOException {
+                       // copy what we have to the length buffer
+                       partial.segment.get(partial.position, 
this.lengthBuffer, partial.remaining());
+               }
+               
+               private void addNextChunkFromMemorySegment(MemorySegment 
segment, int numBytesInSegment) throws IOException {
+                       int segmentPosition = 0;
+                       
+                       // check where to go. if we have a partial length, we 
need to complete it first
+                       if (this.lengthBuffer.position() > 0) {
+                               int toPut = 
Math.min(this.lengthBuffer.remaining(), numBytesInSegment);
+                               segment.get(0, this.lengthBuffer, toPut);
+                               
+                               // did we complete the length?
+                               if (this.lengthBuffer.hasRemaining()) {
+                                       return;
+                               } else {
+                                       this.recordLength = 
this.lengthBuffer.getInt(0);
+                                       this.lengthBuffer.clear();
+                                       segmentPosition = toPut;
+                                       
+                                       if (this.recordLength > 
THRESHOLD_FOR_SPILLING) {
+                                               this.spillingChannel = 
createSpillingChannel();
+                                       }
+                               }
+                       }
+
+                       // copy as much as we need or can for this next 
spanning record
+                       int needed = this.recordLength - 
this.accumulatedRecordBytes;
+                       int available = numBytesInSegment - segmentPosition;
+                       int toCopy = Math.min(needed, available);
+
+                       if (spillingChannel != null) {
+                               // spill to file
+                               ByteBuffer toWrite = 
segment.wrap(segmentPosition, toCopy);
+                               this.spillingChannel.write(toWrite);
+                       }
+                       else {
+                               ensureBufferCapacity(accumulatedRecordBytes + 
toCopy);
+                               segment.get(segmentPosition, buffer, 
this.accumulatedRecordBytes, toCopy);
+                       }
+                       
+                       this.accumulatedRecordBytes += toCopy;
+                       
+                       if (toCopy < available) {
+                               // there is more data in the segment
+                               this.leftOverData = segment;
+                               this.leftOverStart = segmentPosition + toCopy;
+                               this.leftOverLimit = numBytesInSegment;
+                       }
+                       
+                       if (accumulatedRecordBytes == recordLength) {
+                               // we have the full record
+                               if (spillingChannel == null) {
+                                       
this.serializationReadBuffer.setBuffer(buffer, 0, recordLength);
+                               }
+                               else {
+                                       spillingChannel.close();
+                                       
+                                       DataInputStream inStream = new 
DataInputStream(new BufferedInputStream(new FileInputStream(spillFile), 2 * 
1024 * 1024));
+                                       this.spillFileReader = new 
InputViewDataInputStreamWrapper(inStream);
+                               }
+                       }
+               }
+               
+               private void 
moveRemainderToNonSpanningDeserializer(NonSpanningWrapper deserializer) {
+                       deserializer.clear();
+                       
+                       if (leftOverData != null) {
+                               
deserializer.initializeFromMemorySegment(leftOverData, leftOverStart, 
leftOverLimit);
+                       }
+               }
+               
+               private boolean hasFullRecord() {
+                       return this.recordLength >= 0 && 
this.accumulatedRecordBytes >= this.recordLength;
+               }
+               
+               private int getNumGatheredBytes() {
+                       return this.accumulatedRecordBytes + (this.recordLength 
>= 0 ? 4 : lengthBuffer.position());
+               }
+
+               public void clear() {
+                       this.buffer = initialBuffer;
+                       this.serializationReadBuffer.releaseArrays();
+
+                       this.recordLength = -1;
+                       this.lengthBuffer.clear();
+                       this.leftOverData = null;
+                       this.accumulatedRecordBytes = 0;
+                       
+                       if (spillingChannel != null) {
+                               try {
+                                       spillingChannel.close();
+                               }
+                               catch (Throwable t) {
+                                       // ignore
+                               }
+                               spillingChannel = null;
+                       }
+                       if (spillFileReader != null) {
+                               try {
+                                       spillFileReader.close();
+                               }
+                               catch (Throwable t) {
+                                       // ignore
+                               }
+                               spillFileReader = null;
+                       }
+                       if (spillFile != null) {
+                               spillFile.delete();
+                               spillFile = null;
+                       }
+               }
+               
+               public DataInputView getInputView() {
+                       if (spillFileReader == null) {
+                               return serializationReadBuffer; 
+                       }
+                       else {
+                               return spillFileReader;
+                       }
+               }
+               
+               private void ensureBufferCapacity(int minLength) {
+                       if (buffer.length < minLength) {
+                               byte[] newBuffer = new byte[Math.max(minLength, 
buffer.length * 2)];
+                               System.arraycopy(buffer, 0, newBuffer, 0, 
accumulatedRecordBytes);
+                               buffer = newBuffer;
+                       }
+               }
+               
+               @SuppressWarnings("resource")
+               private FileChannel createSpillingChannel() throws IOException {
+                       if (spillFile != null) {
+                               throw new IllegalStateException("Spilling file 
already exists.");
+                       }
+                       
+                       String directory = 
tempDirs[rnd.nextInt(tempDirs.length)];
+                       spillFile = new File(directory, randomString(rnd) + 
".inputchannel");
+                       
+                       return new RandomAccessFile(spillFile, 
"rw").getChannel();
+               }
+               
+               private static String randomString(Random random) {
+                       final byte[] bytes = new byte[20];
+                       random.nextBytes(bytes);
+                       return StringUtils.byteToHexString(bytes);
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java
index 793a2e3..35b6f0d 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java
@@ -71,7 +71,7 @@ public class DataInputDeserializer implements DataInputView {
                        throw new NullPointerException();
                }
 
-               if (start < 0 || len < 0 || start + len >= buffer.length) {
+               if (start < 0 || len < 0 || start + len > buffer.length) {
                        throw new IllegalArgumentException();
                }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java
index 1e59463..6ceb05a 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java
@@ -16,20 +16,21 @@
  * limitations under the License.
  */
 
+package org.apache.flink.runtime.io.network.serialization;
 
-package org.apache.flink.runtime.io.network.api.serialization;
-
-import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
 import org.junit.Assert;
 
 import org.apache.flink.core.memory.MemorySegment;
-import org.apache.flink.runtime.io.network.buffer.Buffer;
-import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
-import 
org.apache.flink.runtime.io.network.api.serialization.types.SerializationTestType;
-import 
org.apache.flink.runtime.io.network.api.serialization.types.SerializationTestTypeFactory;
-import org.apache.flink.runtime.io.network.api.serialization.types.Util;
+import org.apache.flink.runtime.io.network.Buffer;
+import 
org.apache.flink.runtime.io.network.serialization.AdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.serialization.RecordDeserializer;
+import org.apache.flink.runtime.io.network.serialization.RecordSerializer;
+import 
org.apache.flink.runtime.io.network.serialization.SpanningRecordSerializer;
+import 
org.apache.flink.runtime.io.network.serialization.RecordDeserializer.DeserializationResult;
+import 
org.apache.flink.runtime.io.network.serialization.types.SerializationTestType;
+import 
org.apache.flink.runtime.io.network.serialization.types.SerializationTestTypeFactory;
+import org.apache.flink.runtime.io.network.serialization.types.Util;
 import org.junit.Test;
-import org.mockito.Mockito;
 
 import java.util.ArrayDeque;
 
@@ -41,8 +42,10 @@ public class SpanningRecordSerializationTest {
                final int NUM_VALUES = 10;
 
                try {
-                       test(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
-               } catch (Exception e) {
+                       
testNonSpillingDeserializer(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
+                       testSpillingDeserializer(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
+               }
+               catch (Exception e) {
                        e.printStackTrace();
                        Assert.fail("Test encountered an unexpected 
exception.");
                }
@@ -54,8 +57,10 @@ public class SpanningRecordSerializationTest {
                final int NUM_VALUES = 64;
 
                try {
-                       test(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
-               } catch (Exception e) {
+                       
testNonSpillingDeserializer(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
+                       testSpillingDeserializer(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
+               }
+               catch (Exception e) {
                        e.printStackTrace();
                        Assert.fail("Test encountered an unexpected 
exception.");
                }
@@ -67,8 +72,10 @@ public class SpanningRecordSerializationTest {
                final int NUM_VALUES = 248;
 
                try {
-                       test(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
-               } catch (Exception e) {
+                       
testNonSpillingDeserializer(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
+                       testSpillingDeserializer(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
+               }
+               catch (Exception e) {
                        e.printStackTrace();
                        Assert.fail("Test encountered an unexpected 
exception.");
                }
@@ -80,8 +87,10 @@ public class SpanningRecordSerializationTest {
                final int NUM_VALUES = 10000;
 
                try {
-                       test(Util.randomRecords(NUM_VALUES), SEGMENT_SIZE);
-               } catch (Exception e) {
+                       
testNonSpillingDeserializer(Util.randomRecords(NUM_VALUES), SEGMENT_SIZE);
+                       
testSpillingDeserializer(Util.randomRecords(NUM_VALUES), SEGMENT_SIZE);
+               }
+               catch (Exception e) {
                        e.printStackTrace();
                        Assert.fail("Test encountered an unexpected 
exception.");
                }
@@ -89,6 +98,20 @@ public class SpanningRecordSerializationTest {
 
        // 
-----------------------------------------------------------------------------------------------------------------
 
+       private void testNonSpillingDeserializer(Util.MockRecords records, int 
segmentSize) throws Exception {
+               RecordSerializer<SerializationTestType> serializer = new 
SpanningRecordSerializer<SerializationTestType>();
+               RecordDeserializer<SerializationTestType> deserializer = new 
AdaptiveSpanningRecordDeserializer<SerializationTestType>();
+               
+               test(records, segmentSize, serializer, deserializer);
+       }
+       
+       private void testSpillingDeserializer(Util.MockRecords records, int 
segmentSize) throws Exception {
+               RecordSerializer<SerializationTestType> serializer = new 
SpanningRecordSerializer<SerializationTestType>();
+               RecordDeserializer<SerializationTestType> deserializer = new 
SpillingAdaptiveSpanningRecordDeserializer<SerializationTestType>();
+               
+               test(records, segmentSize, serializer, deserializer);
+       }
+       
        /**
         * Iterates over the provided records and tests whether {@link 
SpanningRecordSerializer} and {@link AdaptiveSpanningRecordDeserializer}
         * interact as expected.
@@ -98,13 +121,14 @@ public class SpanningRecordSerializationTest {
         * @param records records to test
         * @param segmentSize size for the {@link MemorySegment}
         */
-       private void test (Util.MockRecords records, int segmentSize) throws 
Exception {
+       private void test(Util.MockRecords records, int segmentSize, 
+                       RecordSerializer<SerializationTestType> serializer,
+                       RecordDeserializer<SerializationTestType> deserializer)
+               throws Exception
+       {
                final int SERIALIZATION_OVERHEAD = 4; // length encoding
 
-               final RecordSerializer<SerializationTestType> serializer = new 
SpanningRecordSerializer<SerializationTestType>();
-               final RecordDeserializer<SerializationTestType> deserializer = 
new AdaptiveSpanningRecordDeserializer<SerializationTestType>();
-
-               final Buffer buffer = new Buffer(new MemorySegment(new 
byte[segmentSize]), Mockito.mock(BufferRecycler.class));
+               final Buffer buffer = new Buffer(new MemorySegment(new 
byte[segmentSize]), segmentSize, null);
 
                final ArrayDeque<SerializationTestType> serializedRecords = new 
ArrayDeque<SerializationTestType>();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java
index 568a599..920d683 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializerTest.java
@@ -16,22 +16,19 @@
  * limitations under the License.
  */
 
+package org.apache.flink.runtime.io.network.serialization;
 
-package org.apache.flink.runtime.io.network.api.serialization;
-
-import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
 import org.junit.Assert;
 
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.MemorySegment;
-import org.apache.flink.runtime.io.network.buffer.Buffer;
-import 
org.apache.flink.runtime.io.network.api.serialization.RecordSerializer.SerializationResult;
-import 
org.apache.flink.runtime.io.network.api.serialization.types.SerializationTestType;
-import 
org.apache.flink.runtime.io.network.api.serialization.types.SerializationTestTypeFactory;
-import org.apache.flink.runtime.io.network.api.serialization.types.Util;
+import org.apache.flink.runtime.io.network.Buffer;
+import 
org.apache.flink.runtime.io.network.serialization.RecordSerializer.SerializationResult;
+import 
org.apache.flink.runtime.io.network.serialization.types.SerializationTestType;
+import 
org.apache.flink.runtime.io.network.serialization.types.SerializationTestTypeFactory;
+import org.apache.flink.runtime.io.network.serialization.types.Util;
 import org.junit.Test;
-import org.mockito.Mockito;
 
 import java.io.IOException;
 import java.util.Random;
@@ -43,7 +40,7 @@ public class SpanningRecordSerializerTest {
                final int SEGMENT_SIZE = 16;
 
                final SpanningRecordSerializer<SerializationTestType> 
serializer = new SpanningRecordSerializer<SerializationTestType>();
-               final Buffer buffer = new Buffer(new MemorySegment(new 
byte[SEGMENT_SIZE]), Mockito.mock(BufferRecycler.class));
+               final Buffer buffer = new Buffer(new MemorySegment(new 
byte[SEGMENT_SIZE]), SEGMENT_SIZE, null);
                final SerializationTestType randomIntRecord = 
Util.randomRecord(SerializationTestTypeFactory.INT);
 
                Assert.assertFalse(serializer.hasData());
@@ -65,10 +62,11 @@ public class SpanningRecordSerializerTest {
 
                        serializer.addRecord(randomIntRecord);
                        Assert.assertTrue(serializer.hasData());
-               } catch (IOException e) {
+               }
+               catch (Exception e) {
                        e.printStackTrace();
+                       Assert.fail(e.getMessage());
                }
-
        }
 
        @Test
@@ -76,7 +74,7 @@ public class SpanningRecordSerializerTest {
                final int SEGMENT_SIZE = 11;
 
                final SpanningRecordSerializer<SerializationTestType> 
serializer = new SpanningRecordSerializer<SerializationTestType>();
-               final Buffer buffer = new Buffer(new MemorySegment(new 
byte[SEGMENT_SIZE]), Mockito.mock(BufferRecycler.class));
+               final Buffer buffer = new Buffer(new MemorySegment(new 
byte[SEGMENT_SIZE]), SEGMENT_SIZE, null);
 
                try {
                        Assert.assertEquals(SerializationResult.FULL_RECORD, 
serializer.setNextBuffer(buffer));
@@ -97,12 +95,10 @@ public class SpanningRecordSerializerTest {
                                }
 
                                @Override
-                               public void write(DataOutputView out) throws 
IOException {
-                               }
+                               public void write(DataOutputView out) {}
 
                                @Override
-                               public void read(DataInputView in) throws 
IOException {
-                               }
+                               public void read(DataInputView in) {}
 
                                @Override
                                public int hashCode() {
@@ -126,8 +122,10 @@ public class SpanningRecordSerializerTest {
 
                        result = serializer.setNextBuffer(buffer);
                        Assert.assertEquals(SerializationResult.FULL_RECORD, 
result);
-               } catch (IOException e) {
+               }
+               catch (Exception e) {
                        e.printStackTrace();
+                       Assert.fail(e.getMessage());
                }
        }
 
@@ -138,7 +136,8 @@ public class SpanningRecordSerializerTest {
 
                try {
                        test(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
-               } catch (Exception e) {
+               }
+               catch (Exception e) {
                        e.printStackTrace();
                        Assert.fail("Test encountered an unexpected 
exception.");
                }
@@ -151,7 +150,8 @@ public class SpanningRecordSerializerTest {
 
                try {
                        test(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
-               } catch (Exception e) {
+               }
+               catch (Exception e) {
                        e.printStackTrace();
                        Assert.fail("Test encountered an unexpected 
exception.");
                }
@@ -164,7 +164,8 @@ public class SpanningRecordSerializerTest {
 
                try {
                        test(Util.randomRecords(NUM_VALUES, 
SerializationTestTypeFactory.INT), SEGMENT_SIZE);
-               } catch (Exception e) {
+               }
+               catch (Exception e) {
                        e.printStackTrace();
                        Assert.fail("Test encountered an unexpected 
exception.");
                }
@@ -177,7 +178,8 @@ public class SpanningRecordSerializerTest {
 
                try {
                        test(Util.randomRecords(NUM_VALUES), SEGMENT_SIZE);
-               } catch (Exception e) {
+               }
+               catch (Exception e) {
                        e.printStackTrace();
                        Assert.fail("Test encountered an unexpected 
exception.");
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/be8e1f18/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/serialization/LargeRecordsTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/serialization/LargeRecordsTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/serialization/LargeRecordsTest.java
index 6000fee..6c1fd64 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/serialization/LargeRecordsTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/serialization/LargeRecordsTest.java
@@ -129,4 +129,101 @@ public class LargeRecordsTest {
                        fail(e.getMessage());
                }
        }
+       
+       @Test
+       public void testHandleMixedLargeRecordsSpillingAdaptiveSerializer() {
+               try {
+                       final int NUM_RECORDS = 99;
+                       final int SEGMENT_SIZE = 32 * 1024;
+
+                       final RecordSerializer<SerializationTestType> 
serializer = new SpanningRecordSerializer<SerializationTestType>();
+                       final RecordDeserializer<SerializationTestType> 
deserializer = new 
SpillingAdaptiveSpanningRecordDeserializer<SerializationTestType>();
+
+                       final Buffer buffer = new Buffer(new MemorySegment(new 
byte[SEGMENT_SIZE]), SEGMENT_SIZE, null);
+
+                       List<SerializationTestType> originalRecords = new 
ArrayList<SerializationTestType>();
+                       List<SerializationTestType> deserializedRecords = new 
ArrayList<SerializationTestType>();
+                       
+                       LargeObjectType genLarge = new LargeObjectType();
+                       
+                       Random rnd = new Random();
+                       
+                       for (int i = 0; i < NUM_RECORDS; i++) {
+                               if (i % 2 == 0) {
+                                       originalRecords.add(new IntType(42));
+                                       deserializedRecords.add(new IntType());
+                               } else {
+                                       
originalRecords.add(genLarge.getRandom(rnd));
+                                       deserializedRecords.add(new 
LargeObjectType());
+                               }
+                       }
+
+                       // 
-------------------------------------------------------------------------------------------------------------
+
+                       serializer.setNextBuffer(buffer);
+                       
+                       int numRecordsDeserialized = 0;
+                       
+                       for (SerializationTestType record : originalRecords) {
+
+                               // serialize record
+                               if 
(serializer.addRecord(record).isFullBuffer()) {
+
+                                       // buffer is full => move to 
deserializer
+                                       
deserializer.setNextMemorySegment(serializer.getCurrentBuffer().getMemorySegment(),
 SEGMENT_SIZE);
+
+                                       // deserialize records, as many 
complete as there are
+                                       while (numRecordsDeserialized < 
deserializedRecords.size()) {
+                                               SerializationTestType next = 
deserializedRecords.get(numRecordsDeserialized);
+                                       
+                                               if 
(deserializer.getNextRecord(next).isFullRecord()) {
+                                                       
assertEquals(originalRecords.get(numRecordsDeserialized), next);
+                                                       
numRecordsDeserialized++;
+                                               } else {
+                                                       break;
+                                               }
+                                       }
+
+                                       // move buffers as long as necessary 
(for long records)
+                                       while 
(serializer.setNextBuffer(buffer).isFullBuffer()) {
+                                               
deserializer.setNextMemorySegment(serializer.getCurrentBuffer().getMemorySegment(),
 SEGMENT_SIZE);
+                                       }
+                                       
+                                       // deserialize records, as many as 
there are in the last buffer
+                                       while (numRecordsDeserialized < 
deserializedRecords.size()) {
+                                               SerializationTestType next = 
deserializedRecords.get(numRecordsDeserialized);
+                                       
+                                               if 
(deserializer.getNextRecord(next).isFullRecord()) {
+                                                       
assertEquals(originalRecords.get(numRecordsDeserialized), next);
+                                                       
numRecordsDeserialized++;
+                                               } else {
+                                                       break;
+                                               }
+                                       }
+                               }
+                       }
+                       
+                       // move the last (incomplete buffer)
+                       Buffer last = serializer.getCurrentBuffer();
+                       
deserializer.setNextMemorySegment(last.getMemorySegment(), last.size());
+                       serializer.clear();
+                       
+                       // deserialize records, as many as there are in the 
last buffer
+                       while (numRecordsDeserialized < 
deserializedRecords.size()) {
+                               SerializationTestType next = 
deserializedRecords.get(numRecordsDeserialized);
+                       
+                               
assertTrue(deserializer.getNextRecord(next).isFullRecord());
+                               
assertEquals(originalRecords.get(numRecordsDeserialized), next);
+                               numRecordsDeserialized++;
+                       }
+                       
+                       // might be that the last big records has not yet been 
fully moved, and a small one is missing
+                       assertFalse(serializer.hasData());
+                       assertFalse(deserializer.hasUnfinishedData());
+               }
+               catch (Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
 }

Reply via email to