This is an automated email from the ASF dual-hosted git repository.

zhijiang pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git

commit aaa3bbc8495474aea03b2e5184e5a15296ed8809
Author: Arvid Heise <[email protected]>
AuthorDate: Wed Jun 3 13:48:49 2020 +0200

    [FLINK-17322][network] Fixes BroadcastRecordWriter overwriting memory segments on first finished BufferConsumer.

    BroadcastRecordWriter#randomEmit initialized the buffer consumers for the
    non-target channels incorrectly, leading to separate buffer reference
    counting; as a result, buffers were released too early.

    This commit uses the new BufferConsumer#copyWithReaderPosition method to
    copy the buffer while updating the read index to the last committed write
    index of the builder.
---
 .../network/api/writer/BroadcastRecordWriter.java  |  6 +-
 .../runtime/io/network/buffer/BufferBuilder.java   |  6 +-
 .../runtime/io/network/buffer/BufferConsumer.java  | 14 ++++
 .../api/writer/BroadcastRecordWriterTest.java      | 54 +++++++++++++
 .../io/network/api/writer/RecordWriterTest.java    |  8 +-
 .../streaming/runtime/LatencyMarkerITCase.java     | 93 ++++++++++++++++++++++
 6 files changed, 176 insertions(+), 5 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriter.java
index 9964b20..132fefa 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriter.java
@@ -50,6 +50,8 @@ public final class BroadcastRecordWriter<T extends IOReadableWritable> extends R
 	 */
 	private boolean randomTriggered;
 
+	private BufferConsumer randomTriggeredConsumer;
+
 	BroadcastRecordWriter(
 			ResultPartitionWriter writer,
 			long timeout,
@@ -84,7 +86,7 @@ public final class BroadcastRecordWriter<T extends IOReadableWritable> extends R
 		if (bufferBuilder != null) {
 			for (int index = 0; index < numberOfChannels; index++) {
 				if (index != targetChannelIndex) {
-					addBufferConsumer(bufferBuilder.createBufferConsumer(), index);
+					addBufferConsumer(randomTriggeredConsumer.copyWithReaderPosition(bufferBuilder.getCommittedBytes()), index);
 				}
 			}
 		}
@@ -130,7 +132,7 @@ public final class BroadcastRecordWriter<T extends IOReadableWritable> extends R
 		BufferBuilder builder = super.requestNewBufferBuilder(targetChannel);
 
 		if (randomTriggered) {
-			addBufferConsumer(builder.createBufferConsumer(), targetChannel);
+			addBufferConsumer(randomTriggeredConsumer = builder.createBufferConsumer(), targetChannel);
 		} else {
 			try (BufferConsumer bufferConsumer = builder.createBufferConsumer()) {
 				for (int channel = 0; channel < numberOfChannels; channel++) {
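For illustration only (not part of the commit): the two hunks above show the shape of the fix. The random-emit target channel keeps the single consumer created by createBufferConsumer(), and every other channel now receives a retained copy of that consumer, so all channels share one reference count on the backing buffer instead of counting references separately. A minimal sketch of that behavior, assuming the 1.11-era internal API (the public BufferBuilder constructor, FreeingBufferRecycler, MemorySegment#isFreed); the class name SharedRefCountSketch is made up:

import java.nio.ByteBuffer;

import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;

public class SharedRefCountSketch {
	public static void main(String[] args) {
		MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(32);
		BufferBuilder builder = new BufferBuilder(segment, FreeingBufferRecycler.INSTANCE);

		// The one "real" consumer, created for the random-emit target channel.
		BufferConsumer target = builder.createBufferConsumer();
		builder.appendAndCommit(ByteBuffer.allocate(8));

		// Every other channel receives a retained copy that shares the
		// target's reference count instead of a consumer of its own.
		BufferConsumer other = target.copyWithReaderPosition(builder.getCommittedBytes());

		target.close();
		System.out.println(segment.isFreed()); // false: "other" still holds a reference

		other.close();
		System.out.println(segment.isFreed()); // true: last reference released, segment freed
	}
}

Before the fix, each non-target channel got a consumer with its own reference counting on the shared segment, so the segment could be returned to the pool while other channels were still reading it, which is exactly the premature release described in the commit message.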
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferBuilder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferBuilder.java
index b18569e..7780ba8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferBuilder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferBuilder.java
@@ -123,6 +123,10 @@ public class BufferBuilder {
 		return getMaxCapacity() - positionMarker.getCached();
 	}
 
+	public int getCommittedBytes() {
+		return positionMarker.getCached();
+	}
+
 	public int getMaxCapacity() {
 		return memorySegment.size();
 	}
@@ -167,7 +171,7 @@ public class BufferBuilder {
 	 *
 	 * <p>Remember to commit the {@link SettablePositionMarker} to make the changes visible.
 	 */
-	private static class SettablePositionMarker implements PositionMarker {
+	static class SettablePositionMarker implements PositionMarker {
 		private volatile int position = 0;
 
 		/**
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferConsumer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferConsumer.java
index 863b231..70c8457 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferConsumer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferConsumer.java
@@ -25,6 +25,7 @@ import javax.annotation.concurrent.NotThreadSafe;
 
 import java.io.Closeable;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
@@ -79,6 +80,7 @@ public class BufferConsumer implements Closeable {
 	private BufferConsumer(Buffer buffer, BufferBuilder.PositionMarker currentWriterPosition, int currentReaderPosition) {
 		this.buffer = checkNotNull(buffer);
 		this.writerPosition = new CachedPositionMarker(checkNotNull(currentWriterPosition));
+		checkArgument(currentReaderPosition <= writerPosition.getCached(), "Reader position larger than writer position");
 		this.currentReaderPosition = currentReaderPosition;
 	}
 
@@ -118,6 +120,18 @@ public class BufferConsumer implements Closeable {
 		return new BufferConsumer(buffer.retainBuffer(), writerPosition.positionMarker, currentReaderPosition);
 	}
 
+	/**
+	 * Returns a retained copy with separate indexes and sets the reader position to the given value. This allows
+	 * reading from the same {@link MemorySegment} twice, starting from the supplied position.
+	 *
+	 * @param readerPosition the new reader position. Can be less than the {@link #currentReaderPosition}, but may not
+	 *     exceed the current writer's position.
+	 * @return a retained copy of self with separate indexes
+	 */
+	public BufferConsumer copyWithReaderPosition(int readerPosition) {
+		return new BufferConsumer(buffer.retainBuffer(), writerPosition.positionMarker, readerPosition);
+	}
+
 	public boolean isBuffer() {
 		return buffer.isBuffer();
 	}
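For illustration only (again not part of the commit), a short sketch of the two new accessors working together, under the same assumptions about the 1.11-era internal API; the class name CopyWithReaderPositionSketch is made up. BufferBuilder#getCommittedBytes exposes the committed write index, copyWithReaderPosition uses it so the copy skips everything already written for the random-emit target channel, and the new checkArgument in the constructor rejects reader positions past that index:

import java.nio.ByteBuffer;

import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;

public class CopyWithReaderPositionSketch {
	public static void main(String[] args) {
		BufferBuilder builder = new BufferBuilder(
			MemorySegmentFactory.allocateUnpooledSegment(32), FreeingBufferRecycler.INSTANCE);
		BufferConsumer consumer = builder.createBufferConsumer();

		builder.appendAndCommit(ByteBuffer.allocate(8)); // commits 8 bytes

		// A copy whose reader index starts at the committed write position
		// sees none of the 8 bytes; the original consumer sees all of them.
		BufferConsumer skipping = consumer.copyWithReaderPosition(builder.getCommittedBytes());

		Buffer all = consumer.build();
		Buffer none = skipping.build();
		System.out.println(all.getSize());  // 8
		System.out.println(none.getSize()); // 0

		// A position beyond the committed write index now fails fast:
		// consumer.copyWithReaderPosition(builder.getCommittedBytes() + 1); // IllegalArgumentException

		all.recycleBuffer();
		none.recycleBuffer();
		consumer.close();
		skipping.close();
	}
}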
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriterTest.java
index d95b87c..dccfd1d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriterTest.java
@@ -21,8 +21,10 @@ package org.apache.flink.runtime.io.network.api.writer;
 import org.apache.flink.core.io.IOReadableWritable;
 import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
 import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
 import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider;
+import org.apache.flink.testutils.serialization.types.IntType;
 import org.apache.flink.testutils.serialization.types.SerializationTestType;
 import org.apache.flink.testutils.serialization.types.SerializationTestTypeFactory;
 import org.apache.flink.testutils.serialization.types.Util;
@@ -30,7 +32,9 @@ import org.apache.flink.testutils.serialization.types.Util;
 import org.junit.Test;
 
 import java.util.ArrayDeque;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Queue;
 
@@ -109,4 +113,54 @@ public class BroadcastRecordWriterTest extends RecordWriterTest {
 			numberOfTotalRecords);
 		}
 	}
+
+	/**
+	 * FLINK-17780: Tests that a shared buffer (or memory segment) of a buffer builder is only freed when all
+	 * consumers are closed.
+	 */
+	@Test
+	public void testRandomEmitAndBufferRecycling() throws Exception {
+		int recordSize = 8;
+
+		final TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(2, 2 * recordSize);
+		final KeepingPartitionWriter partitionWriter = new KeepingPartitionWriter(bufferProvider) {
+			@Override
+			public int getNumberOfSubpartitions() {
+				return 2;
+			}
+		};
+		final BroadcastRecordWriter<SerializationTestType> writer = new BroadcastRecordWriter<>(partitionWriter, 0, "test");
+
+		// force materialization of both buffers for easier availability tests
+		List<Buffer> buffers = Arrays.asList(bufferProvider.requestBuffer(), bufferProvider.requestBuffer());
+		buffers.forEach(Buffer::recycleBuffer);
+		assertEquals(2, bufferProvider.getNumberOfAvailableBuffers());
+
+		// fill first buffer
+		writer.randomEmit(new IntType(1), 0);
+		writer.broadcastEmit(new IntType(2));
+		assertEquals(1, bufferProvider.getNumberOfAvailableBuffers());
+
+		// simulate consumption of first buffer consumer; this should not free the buffer
+		assertEquals(1, partitionWriter.getAddedBufferConsumers(0).size());
+		closeConsumer(partitionWriter, 0, 2 * recordSize);
+		assertEquals(1, bufferProvider.getNumberOfAvailableBuffers());
+
+		// use second buffer
+		writer.broadcastEmit(new IntType(3));
+		assertEquals(0, bufferProvider.getNumberOfAvailableBuffers());
+
+		// fully free first buffer
+		assertEquals(2, partitionWriter.getAddedBufferConsumers(1).size());
+		closeConsumer(partitionWriter, 1, recordSize);
+		assertEquals(1, bufferProvider.getNumberOfAvailableBuffers());
+	}
+
+	public void closeConsumer(KeepingPartitionWriter partitionWriter, int subpartitionIndex, int expectedSize) {
+		BufferConsumer bufferConsumer = partitionWriter.getAddedBufferConsumers(subpartitionIndex).get(0);
+		Buffer buffer = bufferConsumer.build();
+		bufferConsumer.close();
+		assertEquals(expectedSize, buffer.getSize());
+		buffer.recycleBuffer();
+	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
index 8d7e9ad..84ec497 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
@@ -686,11 +686,11 @@ public class RecordWriterTest {
 		}
 	}
 
-	private static class KeepingPartitionWriter extends MockResultPartitionWriter {
+	static class KeepingPartitionWriter extends MockResultPartitionWriter {
 		private final BufferProvider bufferProvider;
 		private Map<Integer, List<BufferConsumer>> produced = new HashMap<>();
 
-		private KeepingPartitionWriter(BufferProvider bufferProvider) {
+		KeepingPartitionWriter(BufferProvider bufferProvider) {
 			this.bufferProvider = bufferProvider;
 		}
 
@@ -712,6 +712,10 @@ public class RecordWriterTest {
 			return true;
 		}
 
+		public List<BufferConsumer> getAddedBufferConsumers(int subpartitionIndex) {
+			return produced.get(subpartitionIndex);
+		}
+
 		@Override
 		public void close() {
 			for (List<BufferConsumer> bufferConsumers : produced.values()) {
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/LatencyMarkerITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/LatencyMarkerITCase.java
new file mode 100644
index 0000000..017642c
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/LatencyMarkerITCase.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.streaming.runtime;
+
+import org.apache.flink.api.common.JobExecutionResult;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
+import org.apache.flink.test.checkpointing.utils.MigrationTestUtils.AccumulatorCountingSink;
+import org.apache.flink.util.Collector;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Tests the latency marker.
+ */
+public class LatencyMarkerITCase {
+	/**
+	 * FLINK-17780: Tests that streams are not corrupted and no records are lost when latency markers are used
+	 * together with broadcast.
+	 */
+	@Test
+	public void testBroadcast() throws Exception {
+		int inputCount = 100000;
+		int parallelism = 4;
+
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
+		env.setParallelism(parallelism);
+		env.getConfig().setLatencyTrackingInterval(2000);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+
+		List<Integer> broadcastData = IntStream.range(0, inputCount).boxed().collect(Collectors.toList());
+		DataStream<Integer> broadcastDataStream = env.fromCollection(broadcastData)
+			.setParallelism(1);
+
+		// broadcast the configurations and create the broadcast state
+		DataStream<String> streamWithoutData = env.fromCollection(Collections.emptyList(), TypeInformation.of(String.class));
+
+		MapStateDescriptor<String, Integer> stateDescriptor = new MapStateDescriptor<>("BroadcastState", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO);
+
+		SingleOutputStreamOperator<Integer> processor = streamWithoutData
+			.connect(broadcastDataStream.broadcast(stateDescriptor))
+			.process(new BroadcastProcessFunction<String, Integer, Integer>() {
+				int expected = 0;
+
+				public void processElement(String value, ReadOnlyContext ctx, Collector<Integer> out) {
+				}
+
+				public void processBroadcastElement(Integer value, Context ctx, Collector<Integer> out) {
+					if (value != expected++) {
+						throw new AssertionError(String.format("Value was supposed to be: '%s', but was: '%s'", expected - 1, value));
+					}
+					out.collect(value);
+				}
+			});
+
+		processor.addSink(new AccumulatorCountingSink<>())
+			.setParallelism(1);
+
+		JobExecutionResult executionResult = env.execute();
+
+		Integer count = executionResult.getAccumulatorResult(AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR);
+		Assert.assertEquals(inputCount * parallelism, count.intValue());
+	}
+}
