This is an automated email from the ASF dual-hosted git repository.

zhijiang pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git

commit aaa3bbc8495474aea03b2e5184e5a15296ed8809
Author: Arvid Heise <[email protected]>
AuthorDate: Wed Jun 3 13:48:49 2020 +0200

    [FLINK-17322][network] Fixes BroadcastRecordWriter overwriting memory 
segments on first finished BufferConsumer.
    
    BroadcastRecordWriter#randomEmit incorrectly initialized buffer consumers for 
the other, non-target channels, which led to separate buffer reference counting 
and, consequently, to buffers being released too early.
    This commit uses the new BufferConsumer#copyWithReaderPosition method to 
copy the buffer while setting its read index to the builder's last committed 
write index.
---
 .../network/api/writer/BroadcastRecordWriter.java  |  6 +-
 .../runtime/io/network/buffer/BufferBuilder.java   |  6 +-
 .../runtime/io/network/buffer/BufferConsumer.java  | 14 ++++
 .../api/writer/BroadcastRecordWriterTest.java      | 54 +++++++++++++
 .../io/network/api/writer/RecordWriterTest.java    |  8 +-
 .../streaming/runtime/LatencyMarkerITCase.java     | 93 ++++++++++++++++++++++
 6 files changed, 176 insertions(+), 5 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriter.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriter.java
index 9964b20..132fefa 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriter.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriter.java
@@ -50,6 +50,8 @@ public final class BroadcastRecordWriter<T extends 
IOReadableWritable> extends R
         */
        private boolean randomTriggered;
 
+       private BufferConsumer randomTriggeredConsumer;
+
        BroadcastRecordWriter(
                        ResultPartitionWriter writer,
                        long timeout,
@@ -84,7 +86,7 @@ public final class BroadcastRecordWriter<T extends 
IOReadableWritable> extends R
                if (bufferBuilder != null) {
                        for (int index = 0; index < numberOfChannels; index++) {
                                if (index != targetChannelIndex) {
-                                       
addBufferConsumer(bufferBuilder.createBufferConsumer(), index);
+                                       
addBufferConsumer(randomTriggeredConsumer.copyWithReaderPosition(bufferBuilder.getCommittedBytes()),
 index);
                                }
                        }
                }
@@ -130,7 +132,7 @@ public final class BroadcastRecordWriter<T extends 
IOReadableWritable> extends R
 
                BufferBuilder builder = 
super.requestNewBufferBuilder(targetChannel);
                if (randomTriggered) {
-                       addBufferConsumer(builder.createBufferConsumer(), 
targetChannel);
+                       addBufferConsumer(randomTriggeredConsumer = 
builder.createBufferConsumer(), targetChannel);
                } else {
                        try (BufferConsumer bufferConsumer = 
builder.createBufferConsumer()) {
                                for (int channel = 0; channel < 
numberOfChannels; channel++) {
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferBuilder.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferBuilder.java
index b18569e..7780ba8 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferBuilder.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferBuilder.java
@@ -123,6 +123,10 @@ public class BufferBuilder {
                return getMaxCapacity() - positionMarker.getCached();
        }
 
+       public int getCommittedBytes() {
+               return positionMarker.getCached();
+       }
+
        public int getMaxCapacity() {
                return memorySegment.size();
        }
@@ -167,7 +171,7 @@ public class BufferBuilder {
         *
         * <p>Remember to commit the {@link SettablePositionMarker} to make the 
changes visible.
         */
-       private static class SettablePositionMarker implements PositionMarker {
+       static class SettablePositionMarker implements PositionMarker {
                private volatile int position = 0;
 
                /**
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferConsumer.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferConsumer.java
index 863b231..70c8457 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferConsumer.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferConsumer.java
@@ -25,6 +25,7 @@ import javax.annotation.concurrent.NotThreadSafe;
 
 import java.io.Closeable;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
@@ -79,6 +80,7 @@ public class BufferConsumer implements Closeable {
        private BufferConsumer(Buffer buffer, BufferBuilder.PositionMarker 
currentWriterPosition, int currentReaderPosition) {
                this.buffer = checkNotNull(buffer);
                this.writerPosition = new 
CachedPositionMarker(checkNotNull(currentWriterPosition));
+               checkArgument(currentReaderPosition <= 
writerPosition.getCached(), "Reader position larger than writer position");
                this.currentReaderPosition = currentReaderPosition;
        }
 
@@ -118,6 +120,18 @@ public class BufferConsumer implements Closeable {
                return new BufferConsumer(buffer.retainBuffer(), 
writerPosition.positionMarker, currentReaderPosition);
        }
 
+       /**
+        * Returns a retained copy with separate indexes and sets the reader 
position to the given value. This allows to
+        * read from the same {@link MemorySegment} twice starting from the 
supplied position.
+        *
+        * @param readerPosition the new reader position. Can be less than the 
{@link #currentReaderPosition}, but may not
+        *                                               exceed the current 
writer's position.
+        * @return a retained copy of self with separate indexes
+        */
+       public BufferConsumer copyWithReaderPosition(int readerPosition) {
+               return new BufferConsumer(buffer.retainBuffer(), 
writerPosition.positionMarker, readerPosition);
+       }
+
        public boolean isBuffer() {
                return buffer.isBuffer();
        }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriterTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriterTest.java
index d95b87c..dccfd1d 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriterTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/BroadcastRecordWriterTest.java
@@ -21,8 +21,10 @@ package org.apache.flink.runtime.io.network.api.writer;
 import org.apache.flink.core.io.IOReadableWritable;
 import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
 import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
 import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider;
+import org.apache.flink.testutils.serialization.types.IntType;
 import org.apache.flink.testutils.serialization.types.SerializationTestType;
 import 
org.apache.flink.testutils.serialization.types.SerializationTestTypeFactory;
 import org.apache.flink.testutils.serialization.types.Util;
@@ -30,7 +32,9 @@ import org.apache.flink.testutils.serialization.types.Util;
 import org.junit.Test;
 
 import java.util.ArrayDeque;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Queue;
 
@@ -109,4 +113,54 @@ public class BroadcastRecordWriterTest extends 
RecordWriterTest {
                                numberOfTotalRecords);
                }
        }
+
+       /**
+        * FLINK-17780: Tests that a shared buffer(or memory segment) of a 
buffer builder is only freed when all consumers
+        * are closed.
+        */
+       @Test
+       public void testRandomEmitAndBufferRecycling() throws Exception {
+               int recordSize = 8;
+
+               final TestPooledBufferProvider bufferProvider = new 
TestPooledBufferProvider(2, 2 * recordSize);
+               final KeepingPartitionWriter partitionWriter = new 
KeepingPartitionWriter(bufferProvider) {
+                       @Override
+                       public int getNumberOfSubpartitions() {
+                               return 2;
+                       }
+               };
+               final BroadcastRecordWriter<SerializationTestType> writer = new 
BroadcastRecordWriter<>(partitionWriter, 0, "test");
+
+               // force materialization of both buffers for easier 
availability tests
+               List<Buffer> buffers = 
Arrays.asList(bufferProvider.requestBuffer(), bufferProvider.requestBuffer());
+               buffers.forEach(Buffer::recycleBuffer);
+               assertEquals(2, bufferProvider.getNumberOfAvailableBuffers());
+
+               // fill first buffer
+               writer.randomEmit(new IntType(1), 0);
+               writer.broadcastEmit(new IntType(2));
+               assertEquals(1, bufferProvider.getNumberOfAvailableBuffers());
+
+               // simulate consumption of first buffer consumer; this should 
not free buffers
+               assertEquals(1, 
partitionWriter.getAddedBufferConsumers(0).size());
+               closeConsumer(partitionWriter, 0, 2 * recordSize);
+               assertEquals(1, bufferProvider.getNumberOfAvailableBuffers());
+
+               // use second buffer
+               writer.broadcastEmit(new IntType(3));
+               assertEquals(0, bufferProvider.getNumberOfAvailableBuffers());
+
+               // fully free first buffer
+               assertEquals(2, 
partitionWriter.getAddedBufferConsumers(1).size());
+               closeConsumer(partitionWriter, 1, recordSize);
+               assertEquals(1, bufferProvider.getNumberOfAvailableBuffers());
+       }
+
+       public void closeConsumer(KeepingPartitionWriter partitionWriter, int 
subpartitionIndex, int expectedSize) {
+               BufferConsumer bufferConsumer = 
partitionWriter.getAddedBufferConsumers(subpartitionIndex).get(0);
+               Buffer buffer = bufferConsumer.build();
+               bufferConsumer.close();
+               assertEquals(expectedSize, buffer.getSize());
+               buffer.recycleBuffer();
+       }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
index 8d7e9ad..84ec497 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
@@ -686,11 +686,11 @@ public class RecordWriterTest {
                }
        }
 
-       private static class KeepingPartitionWriter extends 
MockResultPartitionWriter {
+       static class KeepingPartitionWriter extends MockResultPartitionWriter {
                private final BufferProvider bufferProvider;
                private Map<Integer, List<BufferConsumer>> produced = new 
HashMap<>();
 
-               private KeepingPartitionWriter(BufferProvider bufferProvider) {
+               KeepingPartitionWriter(BufferProvider bufferProvider) {
                        this.bufferProvider = bufferProvider;
                }
 
@@ -712,6 +712,10 @@ public class RecordWriterTest {
                        return true;
                }
 
+               public List<BufferConsumer> getAddedBufferConsumers(int 
subpartitionIndex) {
+                       return produced.get(subpartitionIndex);
+               }
+
                @Override
                public void close() {
                        for (List<BufferConsumer> bufferConsumers : 
produced.values()) {
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/LatencyMarkerITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/LatencyMarkerITCase.java
new file mode 100644
index 0000000..017642c
--- /dev/null
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/LatencyMarkerITCase.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.streaming.runtime;
+
+import org.apache.flink.api.common.JobExecutionResult;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
+import 
org.apache.flink.test.checkpointing.utils.MigrationTestUtils.AccumulatorCountingSink;
+import org.apache.flink.util.Collector;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Tests latency marker.
+ */
+public class LatencyMarkerITCase {
+       /**
+        * FLINK-17780: Tests that streams are not corrupted/records lost when 
using latency markers with broadcast.
+        */
+       @Test
+       public void testBroadcast() throws Exception {
+               int inputCount = 100000;
+               int parallelism = 4;
+
+               StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+               env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
+               env.setParallelism(parallelism);
+               env.getConfig().setLatencyTrackingInterval(2000);
+               env.setRestartStrategy(RestartStrategies.noRestart());
+
+               List<Integer> broadcastData = IntStream.range(0, 
inputCount).boxed().collect(Collectors.toList());
+               DataStream<Integer> broadcastDataStream = 
env.fromCollection(broadcastData)
+                       .setParallelism(1);
+
+               // broadcast the configurations and create the broadcast state
+
+               DataStream<String> streamWithoutData = 
env.fromCollection(Collections.emptyList(), TypeInformation.of(String.class));
+
+               MapStateDescriptor<String, Integer> stateDescriptor = new 
MapStateDescriptor<>("BroadcastState", BasicTypeInfo.STRING_TYPE_INFO, 
BasicTypeInfo.INT_TYPE_INFO);
+
+               SingleOutputStreamOperator<Integer> processor = 
streamWithoutData
+                       .connect(broadcastDataStream.broadcast(stateDescriptor))
+                       .process(new BroadcastProcessFunction<String, Integer, 
Integer>() {
+                               int expected = 0;
+
+                               public void processElement(String value, 
ReadOnlyContext ctx, Collector<Integer> out) {
+                               }
+
+                               public void processBroadcastElement(Integer 
value, Context ctx, Collector<Integer> out) {
+                                       if (value != expected++) {
+                                               throw new 
AssertionError(String.format("Value was supposed to be: '%s', but was: '%s'", 
expected - 1, value));
+                                       }
+                                       out.collect(value);
+                               }
+                       });
+
+               processor.addSink(new AccumulatorCountingSink<>())
+                       .setParallelism(1);
+
+               JobExecutionResult executionResult = env.execute();
+
+               Integer count = 
executionResult.getAccumulatorResult(AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR);
+               Assert.assertEquals(inputCount * parallelism, count.intValue());
+       }
+}

Reply via email to