This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 80fd4dfb [FLINK-30933] Fix missing max watermark when executing join in iteration body
80fd4dfb is described below
commit 80fd4dfb843aee1d9cfd93130cfff016a9966b7b
Author: Zhipeng Zhang <[email protected]>
AuthorDate: Wed Mar 15 14:36:47 2023 +0800
[FLINK-30933] Fix missing max watermark when executing join in iteration body
This closes #206.
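
For context: OutputOperator used to forward only records of type RECORD out of
the iteration, so the epoch watermark of the final round was dropped and
downstream event-time operators (such as the window join exercised by the new
test) never received a max watermark and could not fire. The gist of the fix,
as a simplified sketch of the OutputOperator change below (where "record"
stands for streamRecord.getValue()):

    // On the terminal epoch watermark (epoch == Integer.MAX_VALUE), emit
    // Long.MAX_VALUE so downstream event-time operators can fire and finish.
    if (record.getType() == IterationRecord.Type.EPOCH_WATERMARK
            && record.getEpoch() == Integer.MAX_VALUE) {
        output.emitWatermark(new Watermark(Long.MAX_VALUE));
    }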
---
.../flink/iteration/operator/HeadOperator.java | 2 +-
.../flink/iteration/operator/OutputOperator.java | 5 +
.../BoundedPerRoundStreamIterationITCase.java | 115 ++++++++++++++++++++-
3 files changed, 116 insertions(+), 6 deletions(-)
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
index bdbe657a..e5238929 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
@@ -566,7 +566,7 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>

         private MailboxExecutorWithYieldTimeout(MailboxExecutor mailboxExecutor) {
             this.mailboxExecutor = mailboxExecutor;
-            this.timer = new Timer();
+            this.timer = new Timer(true);
         }

         @Override
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java
index a584c5f4..d0e69712 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java
@@ -19,9 +19,11 @@
 package org.apache.flink.iteration.operator;

 import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.IterationRecord.Type;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;

 /**
@@ -48,6 +50,9 @@ public class OutputOperator<T> extends AbstractStreamOperator<T>
         if (streamRecord.getValue().getType() == IterationRecord.Type.RECORD) {
             reusable.replace(streamRecord.getValue().getValue(), streamRecord.getTimestamp());
             output.collect(reusable);
+        } else if (streamRecord.getValue().getType() == Type.EPOCH_WATERMARK
+                && streamRecord.getValue().getEpoch() == Integer.MAX_VALUE) {
+            output.emitWatermark(new Watermark(Long.MAX_VALUE));
         }
     }
 }
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
index 5f453b72..6b79b66c 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
@@ -18,18 +18,28 @@

 package org.apache.flink.test.iteration;

+import org.apache.flink.api.common.functions.JoinFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
 import org.apache.flink.iteration.IterationBodyResult;
 import org.apache.flink.iteration.IterationConfig;
 import org.apache.flink.iteration.Iterations;
 import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIter;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.minicluster.MiniCluster;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.test.iteration.operators.CollectSink;
 import org.apache.flink.test.iteration.operators.EpochRecord;
 import org.apache.flink.test.iteration.operators.OutputRecord;
@@ -61,14 +71,16 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger {

     private MiniCluster miniCluster;

-    private SharedReference<BlockingQueue<OutputRecord<Integer>>> result;
+    private SharedReference<BlockingQueue<OutputRecord<Integer>>> collectedOutputRecord;
+    private SharedReference<BlockingQueue<Long>> collectedWatermarks;

     @Before
     public void setup() throws Exception {
         miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2));
         miniCluster.start();

-        result = sharedObjects.add(new LinkedBlockingQueue<>());
+        collectedOutputRecord = sharedObjects.add(new LinkedBlockingQueue<>());
+        collectedWatermarks = sharedObjects.add(new LinkedBlockingQueue<>());
     }

     @After
@@ -80,15 +92,50 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger {

     @Test
     public void testPerRoundIteration() throws Exception {
-        JobGraph jobGraph = createPerRoundJobGraph(4, 1000, 5, result);
+        JobGraph jobGraph = createPerRoundJobGraph(4, 1000, 5, collectedOutputRecord);
         miniCluster.executeJobBlocking(jobGraph);

-        assertEquals(5, result.get().size());
+        assertEquals(5, collectedOutputRecord.get().size());
         Map<Integer, Tuple2<Integer, Integer>> roundsStat =
-                computeRoundStat(result.get(), OutputRecord.Event.TERMINATED, 5);
+                computeRoundStat(collectedOutputRecord.get(), OutputRecord.Event.TERMINATED, 5);
         verifyResult(roundsStat, 5, 1, 4 * (0 + 999) * 1000 / 2);
     }

+    @Test
+    public void testPerRoundIterationWithJoin() throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(4);
+
+        DataStream<Tuple2<Long, Integer>> input1 = env.fromElements(Tuple2.of(1L, 1));
+
+        DataStream<Tuple2<Long, Long>> input2 = env.fromElements(Tuple2.of(1L, 2L));
+
+        DataStream<Tuple2<Long, Long>> iterationWithJoinResult =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                                DataStreamList.of(input1),
+                                ReplayableDataStreamList.replay(input2),
+                                IterationConfig.newBuilder()
+                                        .setOperatorLifeCycle(
+                                                IterationConfig.OperatorLifeCycle.PER_ROUND)
+                                        .build(),
+                                new IterationBodyWithJoin())
+                        .get(0);
+        DataStream<Long> watermarks =
+                iterationWithJoinResult.transform(
+                        "CollectingWatermark", Types.LONG, new CollectingWatermark());
+
+        watermarks.addSink(new LongSink(collectedWatermarks));
+
+        JobGraph graph = env.getStreamGraph().getJobGraph();
+        miniCluster.executeJobBlocking(graph);
+
+        assertEquals(env.getParallelism(), collectedWatermarks.get().size());
+        collectedWatermarks
+                .get()
+                .iterator()
+                .forEachRemaining(x -> assertEquals(Long.MAX_VALUE, (long) x));
+    }
+
     private static JobGraph createPerRoundJobGraph(
             int numSources,
             int numRecordsPerSource,
@@ -148,4 +195,62 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger {

         return env.getStreamGraph().getJobGraph();
     }
+
+    private static class IterationBodyWithJoin implements IterationBody {
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<Long, Integer>> input1 = variableStreams.get(0);
+            DataStream<Tuple2<Long, Long>> input2 = dataStreams.get(0);
+
+            DataStream<Long> terminationCriteria = input1.flatMap(new TerminateOnMaxIter(1));
+
+            DataStream<Tuple2<Long, Long>> res =
+                    input1.join(input2)
+                            .where(x -> x.f0)
+                            .equalTo(x -> x.f0)
+                            .window(EndOfStreamWindows.get())
+                            .apply(
+                                    new JoinFunction<
+                                            Tuple2<Long, Integer>,
+                                            Tuple2<Long, Long>,
+                                            Tuple2<Long, Long>>() {
+                                        @Override
+                                        public Tuple2<Long, Long> join(
+                                                Tuple2<Long, Integer> longIntegerTuple2,
+                                                Tuple2<Long, Long> longLongTuple2) {
+                                            return longLongTuple2;
+                                        }
+                                    });
+
+            return new IterationBodyResult(
+                    DataStreamList.of(input1), DataStreamList.of(res), terminationCriteria);
+        }
+    }
+
+    private static class LongSink implements SinkFunction<Long> {
+        private final SharedReference<BlockingQueue<Long>> collectedLong;
+
+        public LongSink(SharedReference<BlockingQueue<Long>> collectedLong) {
+            this.collectedLong = collectedLong;
+        }
+
+        @Override
+        public void invoke(Long value, Context context) {
+            collectedLong.get().add(value);
+        }
+    }
+
+    private static class CollectingWatermark extends AbstractStreamOperator<Long>
+            implements OneInputStreamOperator<Tuple2<Long, Long>, Long> {
+
+        @Override
+        public void processElement(StreamRecord<Tuple2<Long, Long>> streamRecord) {}
+
+        @Override
+        public void processWatermark(Watermark mark) throws Exception {
+            super.processWatermark(mark);
+            output.collect(new StreamRecord<>(mark.getTimestamp()));
+        }
+    }
 }