[FLINK-2308] [runtime] Give proper error messages in case user-defined serialization is broken and detected in the network stack.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/f5c1768a Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/f5c1768a Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/f5c1768a Branch: refs/heads/master Commit: f5c1768aa730fb6be74c0ebf480675acb1488d4f Parents: 1b97505 Author: Stephan Ewen <[email protected]> Authored: Sun Jun 14 16:39:47 2015 +0200 Committer: Stephan Ewen <[email protected]> Committed: Wed Jul 1 16:11:22 2015 +0200 ---------------------------------------------------------------------- .../api/reader/AbstractRecordReader.java | 2 +- ...llingAdaptiveSpanningRecordDeserializer.java | 32 ++- .../test/misc/CustomSerializationITCase.java | 269 +++++++++++++++++++ 3 files changed, 296 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/f5c1768a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java index bf43c72..56e5d33 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java @@ -83,7 +83,7 @@ abstract class AbstractRecordReader<T extends IOReadableWritable> extends Abstra // sanity check for leftover data in deserializers. 
events should only come between // records, not in the middle of a fragment if (recordDeserializers[bufferOrEvent.getChannelIndex()].hasUnfinishedData()) { - throw new IllegalStateException( + throw new IOException( "Received an event in channel " + bufferOrEvent.getChannelIndex() + " while still having " + "data from a record. This indicates broken serialization logic. " + "If you are using custom serialization code (Writable or Value types), check their " http://git-wip-us.apache.org/repos/asf/flink/blob/f5c1768a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpillingAdaptiveSpanningRecordDeserializer.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpillingAdaptiveSpanningRecordDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpillingAdaptiveSpanningRecordDeserializer.java index 453d448..6b0d836 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpillingAdaptiveSpanningRecordDeserializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpillingAdaptiveSpanningRecordDeserializer.java @@ -46,8 +46,15 @@ import java.util.Random; */ public class SpillingAdaptiveSpanningRecordDeserializer<T extends IOReadableWritable> implements RecordDeserializer<T> { + private static final String BROKEN_SERIALIZATION_ERROR_MESSAGE = + "Serializer consumed more bytes than the record had. " + + "This indicates broken serialization. If you are using custom serialization types " + + "(Value or Writable), check their serialization methods. 
If you are using a " + + "Kryo-serialized type, check the corresponding Kryo serializer."; + private static final int THRESHOLD_FOR_SPILLING = 5 * 1024 * 1024; // 5 MiBytes + private final NonSpanningWrapper nonSpanningWrapper; private final SpanningWrapper spanningWrapper; @@ -107,12 +114,25 @@ public class SpillingAdaptiveSpanningRecordDeserializer<T extends IOReadableWrit if (len <= nonSpanningRemaining - 4) { // we can get a full record from here - target.read(this.nonSpanningWrapper); - - return (this.nonSpanningWrapper.remaining() == 0) ? - DeserializationResult.LAST_RECORD_FROM_BUFFER : - DeserializationResult.INTERMEDIATE_RECORD_FROM_BUFFER; - } else { + try { + target.read(this.nonSpanningWrapper); + + int remaining = this.nonSpanningWrapper.remaining(); + if (remaining > 0) { + return DeserializationResult.INTERMEDIATE_RECORD_FROM_BUFFER; + } + else if (remaining == 0) { + return DeserializationResult.LAST_RECORD_FROM_BUFFER; + } + else { + throw new IndexOutOfBoundsException("Remaining = " + remaining); + } + } + catch (IndexOutOfBoundsException e) { + throw new IOException(BROKEN_SERIALIZATION_ERROR_MESSAGE, e); + } + } + else { // we got the length, but we need the rest from the spanning deserializer // and need to wait for more buffers this.spanningWrapper.initializeWithPartialRecord(this.nonSpanningWrapper, len); http://git-wip-us.apache.org/repos/asf/flink/blob/f5c1768a/flink-tests/src/test/java/org/apache/flink/test/misc/CustomSerializationITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/misc/CustomSerializationITCase.java b/flink-tests/src/test/java/org/apache/flink/test/misc/CustomSerializationITCase.java new file mode 100644 index 0000000..4e7da83 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/misc/CustomSerializationITCase.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.misc;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.client.program.ProgramInvocationException;
+import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.test.util.ForkableFlinkMiniCluster;
+import org.apache.flink.types.Value;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * Tests that user-defined types whose deserialization consumes more or fewer bytes
+ * than their serialization produced make the program fail with an {@link IOException}
+ * whose message points at broken serialization.
+ */
+@SuppressWarnings("serial")
+public class CustomSerializationITCase {
+
+	private static final int PARALLELISM = 5;
+
+	private static ForkableFlinkMiniCluster cluster;
+
+	@BeforeClass
+	public static void startCluster() {
+		try {
+			Configuration config = new Configuration();
+			config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, PARALLELISM);
+			config.setInteger(ConfigConstants.TASK_MANAGER_MEMORY_SIZE_KEY, 30);
+			cluster = new ForkableFlinkMiniCluster(config, false);
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail("Failed to start test cluster: " + e.getMessage());
+		}
+	}
+
+	@AfterClass
+	public static void shutdownCluster() {
+		// if startCluster() failed there is nothing to stop; avoid masking
+		// the original failure with a NullPointerException here
+		if (cluster == null) {
+			return;
+		}
+		try {
+			cluster.shutdown();
+			cluster = null;
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail("Failed to stop test cluster: " + e.getMessage());
+		}
+	}
+
+	@Test
+	public void testIncorrectSerializer1() {
+		try {
+			ExecutionEnvironment env = createEnvironment();
+
+			env
+				.generateSequence(1, 10 * PARALLELISM)
+				.map(new MapFunction<Long, ConsumesTooMuch>() {
+					@Override
+					public ConsumesTooMuch map(Long value) throws Exception {
+						return new ConsumesTooMuch();
+					}
+				})
+				.rebalance()
+				.output(new DiscardingOutputFormat<ConsumesTooMuch>());
+
+			env.execute();
+			// reaching this point means the broken serializer went undetected
+			fail("The program should have failed due to broken serialization.");
+		}
+		catch (ProgramInvocationException e) {
+			assertBrokenSerializationCause(e);
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	@Test
+	public void testIncorrectSerializer2() {
+		try {
+			ExecutionEnvironment env = createEnvironment();
+
+			env
+				.generateSequence(1, 10 * PARALLELISM)
+				.map(new MapFunction<Long, ConsumesTooMuchSpanning>() {
+					@Override
+					public ConsumesTooMuchSpanning map(Long value) throws Exception {
+						return new ConsumesTooMuchSpanning();
+					}
+				})
+				.rebalance()
+				.output(new DiscardingOutputFormat<ConsumesTooMuchSpanning>());
+
+			env.execute();
+			// reaching this point means the broken serializer went undetected
+			fail("The program should have failed due to broken serialization.");
+		}
+		catch (ProgramInvocationException e) {
+			assertBrokenSerializationCause(e);
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	@Test
+	public void testIncorrectSerializer3() {
+		try {
+			ExecutionEnvironment env = createEnvironment();
+
+			env
+				.generateSequence(1, 10 * PARALLELISM)
+				.map(new MapFunction<Long, ConsumesTooLittle>() {
+					@Override
+					public ConsumesTooLittle map(Long value) throws Exception {
+						return new ConsumesTooLittle();
+					}
+				})
+				.rebalance()
+				.output(new DiscardingOutputFormat<ConsumesTooLittle>());
+
+			env.execute();
+			// reaching this point means the broken serializer went undetected
+			fail("The program should have failed due to broken serialization.");
+		}
+		catch (ProgramInvocationException e) {
+			assertBrokenSerializationCause(e);
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	@Test
+	public void testIncorrectSerializer4() {
+		try {
+			ExecutionEnvironment env = createEnvironment();
+
+			env
+				.generateSequence(1, 10 * PARALLELISM)
+				.map(new MapFunction<Long, ConsumesTooLittleSpanning>() {
+					@Override
+					public ConsumesTooLittleSpanning map(Long value) throws Exception {
+						return new ConsumesTooLittleSpanning();
+					}
+				})
+				.rebalance()
+				.output(new DiscardingOutputFormat<ConsumesTooLittleSpanning>());
+
+			env.execute();
+			// reaching this point means the broken serializer went undetected
+			fail("The program should have failed due to broken serialization.");
+		}
+		catch (ProgramInvocationException e) {
+			assertBrokenSerializationCause(e);
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	//  Utilities
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Creates a remote environment that submits to the shared test cluster,
+	 * configured with the test parallelism and with sysout logging disabled.
+	 */
+	private static ExecutionEnvironment createEnvironment() {
+		ExecutionEnvironment env =
+				ExecutionEnvironment.createRemoteEnvironment("localhost", cluster.getJobManagerRPCPort());
+		env.setParallelism(PARALLELISM);
+		env.getConfig().disableSysoutLogging();
+		return env;
+	}
+
+	/**
+	 * Asserts that the cause chain of the given failure contains an {@link IOException}
+	 * whose message mentions "broken serialization". The chain is walked instead of
+	 * assuming a fixed wrapping depth, so that additional or fewer wrapper exceptions
+	 * do not produce a misleading NullPointerException or assertion failure.
+	 *
+	 * @param failure the exception thrown by the program submission
+	 */
+	private static void assertBrokenSerializationCause(Throwable failure) {
+		for (Throwable t = failure; t != null; t = t.getCause()) {
+			String message = t.getMessage();
+			if (t instanceof IOException && message != null && message.contains("broken serialization")) {
+				return;
+			}
+		}
+		failure.printStackTrace();
+		fail("Expected an IOException reporting broken serialization in the cause chain of: " + failure);
+	}
+
+	// ------------------------------------------------------------------------
+	//  Custom Data Types with broken Serialization Logic
+	// ------------------------------------------------------------------------
+
+	/** Writes 4 bytes but reads 8 bytes, consuming more than the record contains. */
+	public static class ConsumesTooMuch implements Value {
+
+		@Override
+		public void write(DataOutputView out) throws IOException {
+			// write 4 bytes
+			out.writeInt(42);
+		}
+
+		@Override
+		public void read(DataInputView in) throws IOException {
+			// read 8 bytes
+			in.readLong();
+		}
+	}
+
+	/** Writes 22541 bytes but reads 32941 bytes, consuming more than the record contains. */
+	public static class ConsumesTooMuchSpanning implements Value {
+
+		@Override
+		public void write(DataOutputView out) throws IOException {
+			byte[] bytes = new byte[22541];
+			out.write(bytes);
+		}
+
+		@Override
+		public void read(DataInputView in) throws IOException {
+			byte[] bytes = new byte[32941];
+			in.readFully(bytes);
+		}
+	}
+
+	/** Writes 8 bytes but reads only 4 bytes, leaving unread data in the record. */
+	public static class ConsumesTooLittle implements Value {
+
+		@Override
+		public void write(DataOutputView out) throws IOException {
+			// write 8 bytes
+			out.writeLong(42L);
+		}
+
+		@Override
+		public void read(DataInputView in) throws IOException {
+			// read 4 bytes
+			in.readInt();
+		}
+	}
+
+	/** Writes 32941 bytes but reads only 22541 bytes, leaving unread data in the record. */
+	public static class ConsumesTooLittleSpanning implements Value {
+
+		@Override
+		public void write(DataOutputView out) throws IOException {
+			byte[] bytes = new byte[32941];
+			out.write(bytes);
+		}
+
+		@Override
+		public void read(DataInputView in) throws IOException {
+			byte[] bytes = new byte[22541];
+			in.readFully(bytes);
+		}
+	}
+}
