This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 8269bd9 [FLINK-24279] support withBroadcast in DataStream
8269bd9 is described below
commit 8269bd9fdf3e5744b2d635697db5c705b9e598f5
Author: zhangzp <[email protected]>
AuthorDate: Fri Nov 5 15:33:04 2021 +0800
[FLINK-24279] support withBroadcast in DataStream
This closes #18.
---
.../datacache/nonkeyed/DataCacheReader.java | 6 -
.../datacache/nonkeyed/DataCacheWriter.java | 14 +
flink-ml-lib/pom.xml | 39 ++
.../ml/common/broadcast/BroadcastContext.java | 113 ++++
.../BroadcastStreamingRuntimeContext.java | 96 ++++
.../flink/ml/common/broadcast/BroadcastUtils.java | 194 +++++++
.../operator/AbstractBroadcastWrapperOperator.java | 613 +++++++++++++++++++++
.../BroadcastVariableReceiverOperator.java | 154 ++++++
.../BroadcastVariableReceiverOperatorFactory.java | 54 ++
.../broadcast/operator/BroadcastWrapper.java | 96 ++++
.../operator/OneInputBroadcastWrapperOperator.java | 82 +++
.../operator/TwoInputBroadcastWrapperOperator.java | 114 ++++
.../ml/common/broadcast/typeinfo/CacheElement.java | 79 +++
.../broadcast/typeinfo/CacheElementSerializer.java | 208 +++++++
.../broadcast/typeinfo/CacheElementTypeInfo.java | 103 ++++
.../ml/common/broadcast/BroadcastUtilsTest.java | 176 ++++++
.../apache/flink/ml/common/broadcast/TestSink.java | 86 +++
.../flink/ml/common/broadcast/TestSource.java | 105 ++++
.../BroadcastVariableReceiverOperatorTest.java | 85 +++
.../operator/BroadcastWrapperOperatorFactory.java | 49 ++
.../OneInputBroadcastWrapperOperatorTest.java | 103 ++++
.../common/broadcast/operator/TestOneInputOp.java | 54 ++
.../common/broadcast/operator/TestTwoInputOp.java | 62 +++
.../TwoInputBroadcastWrapperOperatorTest.java | 105 ++++
.../src/test/resources/log4j2-test.properties | 28 +
25 files changed, 2812 insertions(+), 6 deletions(-)
diff --git
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java
index 1840104..fc47e9f 100644
---
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java
+++
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java
@@ -30,8 +30,6 @@ import java.io.IOException;
import java.util.Iterator;
import java.util.List;
-import static org.apache.flink.util.Preconditions.checkArgument;
-
/** Reads the cached data from a list of paths. */
public class DataCacheReader<T> implements Iterator<T> {
@@ -47,10 +45,6 @@ public class DataCacheReader<T> implements Iterator<T> {
TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment>
segments)
throws IOException {
- for (Segment segment : segments) {
- checkArgument(segment.getCount() > 0, "Do not support empty
segment");
- }
-
this.serializer = serializer;
this.fileSystem = fileSystem;
this.segments = segments;
diff --git
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
index 35256eb..ec8d73d 100644
---
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
+++
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
@@ -26,6 +26,8 @@ import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.util.function.SupplierWithException;
+import javax.annotation.Nullable;
+
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@@ -49,11 +51,23 @@ public class DataCacheWriter<T> {
FileSystem fileSystem,
SupplierWithException<Path, IOException> pathGenerator)
throws IOException {
+ this(serializer, fileSystem, pathGenerator, null);
+ }
+
+ /**
+ * Creates a writer whose finished segments start with the given previously written segments.
+ *
+ * @param serializer serializer for the records of type T.
+ * @param fileSystem the file system kept by this writer (presumably where segment files are
+ * created — TODO confirm).
+ * @param pathGenerator supplies the path of each new segment; it is invoked once here to open
+ * the first segment.
+ * @param writtenSegments segments written earlier; they are appended to this writer's finished
+ * segments. May be null, in which case the writer starts without finished segments.
+ * @throws IOException if generating the path or creating the first segment writer fails.
+ */
+ public DataCacheWriter(
+ TypeSerializer<T> serializer,
+ FileSystem fileSystem,
+ SupplierWithException<Path, IOException> pathGenerator,
+ @Nullable List<Segment> writtenSegments)
+ throws IOException {
this.serializer = serializer;
this.fileSystem = fileSystem;
this.pathGenerator = pathGenerator;
this.finishSegments = new ArrayList<>();
+ // Keeps the already written segments so they are reported together with new ones.
+ if (null != writtenSegments) {
+ finishSegments.addAll(writtenSegments);
+ }
this.currentSegment = new SegmentWriter(pathGenerator.get());
}
diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml
index 17d3e89..bee7f25 100644
--- a/flink-ml-lib/pom.xml
+++ b/flink-ml-lib/pom.xml
@@ -37,6 +37,12 @@ under the License.
<scope>provided</scope>
</dependency>
<dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-ml-iteration</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-api-java</artifactId>
<version>${flink.version}</version>
@@ -65,6 +71,39 @@ under the License.
<artifactId>core</artifactId>
<version>1.1.2</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-clients_${scala.binary.version}</artifactId>
+ <version>${flink.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-streaming-java_${scala.binary.version}</artifactId>
+ <version>${flink.version}</version>
+ <scope>test</scope>
+ <type>test-jar</type>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-runtime</artifactId>
+ <version>${flink.version}</version>
+ <scope>test</scope>
+ <type>test-jar</type>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-test-utils-junit</artifactId>
+ <version>${flink.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <version>2.21.0</version>
+ <type>jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
new file mode 100644
index 0000000..ac6950b
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.operators.MailboxExecutor;
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import javax.annotation.Nullable;
+
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Context to hold the broadcast variables and provides some utility functions for accessing
+ * broadcast variables.
+ */
+public class BroadcastContext {
+
+    /**
+     * Stores broadcast data streams in a map. The key is broadcastName-partitionId and the value is
+     * {@link BroadcastItem}.
+     */
+    private static final ConcurrentHashMap<String, BroadcastItem> BROADCAST_VARIABLES =
+            new ConcurrentHashMap<>();
+
+    /**
+     * Stores the cached list and the readiness flag of a broadcast variable, keeping the consumer's
+     * mailbox executor if one has already been registered under {@code key}.
+     *
+     * @param key broadcastName-partitionId.
+     * @param variable {@code f0} is whether the cache is ready, {@code f1} is the cached list.
+     */
+    @VisibleForTesting
+    public static void putBroadcastVariable(String key, Tuple2<Boolean, List<?>> variable) {
+        BROADCAST_VARIABLES.compute(
+                key,
+                (k, v) ->
+                        null == v
+                                ? new BroadcastItem(variable.f0, variable.f1, null)
+                                : new BroadcastItem(variable.f0, variable.f1, v.mailboxExecutor));
+    }
+
+    /**
+     * Registers the mailbox executor of the consumer of {@code key}, keeping the readiness flag and
+     * cached list already stored. A new entry starts as not ready and without a cached list.
+     *
+     * @param key broadcastName-partitionId.
+     * @param mailboxExecutor the consumer's mailbox executor.
+     */
+    @VisibleForTesting
+    public static void putMailBoxExecutor(String key, MailboxExecutor mailboxExecutor) {
+        BROADCAST_VARIABLES.compute(
+                key,
+                (k, v) ->
+                        null == v
+                                ? new BroadcastItem(false, null, mailboxExecutor)
+                                : new BroadcastItem(v.cacheReady, v.cacheList, mailboxExecutor));
+    }
+
+    /**
+     * Returns the cached list stored under {@code key}. The list may be null if only a mailbox
+     * executor has been registered. Throws a {@link NullPointerException} if nothing is stored
+     * under {@code key}.
+     */
+    @VisibleForTesting
+    @SuppressWarnings({"unchecked"})
+    public static <T> List<T> getBroadcastVariable(String key) {
+        return (List<T>) BROADCAST_VARIABLES.get(key).cacheList;
+    }
+
+    /** Removes everything stored under {@code key}. */
+    @VisibleForTesting
+    public static void remove(String key) {
+        BROADCAST_VARIABLES.remove(key);
+    }
+
+    /**
+     * Marks the broadcast variable stored under {@code key} as ready. If a consumer mailbox
+     * executor is registered, it then receives an empty mail. Does nothing if nothing is stored
+     * under {@code key}.
+     *
+     * @param key broadcastName-partitionId.
+     */
+    @VisibleForTesting
+    public static void markCacheFinished(String key) {
+        BroadcastItem finished =
+                BROADCAST_VARIABLES.computeIfPresent(
+                        key, (k, v) -> new BroadcastItem(true, v.cacheList, v.mailboxExecutor));
+        // Sends a dummy mail to wake up a consumer that may be waiting on its mailbox. The mail is
+        // sent only after the "ready" item has been published. If it were sent from inside the
+        // remapping function, the consumer could handle it and call isCacheFinished() while the
+        // map still returns the previous (not ready) value. It would then wait again with no
+        // further mail to wake it up.
+        if (null != finished && null != finished.mailboxExecutor) {
+            finished.mailboxExecutor.execute(() -> {}, "empty mail");
+        }
+    }
+
+    /**
+     * Returns whether the broadcast variable stored under {@code key} is ready to be consumed.
+     * Throws a {@link NullPointerException} if nothing is stored under {@code key}.
+     */
+    @VisibleForTesting
+    public static boolean isCacheFinished(String key) {
+        return BROADCAST_VARIABLES.get(key).cacheReady;
+    }
+
+    /** Immutable holder of one broadcast variable. */
+    private static class BroadcastItem {
+
+        /** Whether this broadcast variable is ready to be consumed. */
+        private final boolean cacheReady;
+
+        /** The cached list. */
+        private final List<?> cacheList;
+
+        /** The mailbox executor of the consumer, used to wake up a possibly stuck consumer. */
+        private final MailboxExecutor mailboxExecutor;
+
+        BroadcastItem(
+                boolean cacheReady,
+                @Nullable List<?> cacheList,
+                @Nullable MailboxExecutor mailboxExecutor) {
+            this.cacheReady = cacheReady;
+            this.cacheList = cacheList;
+            this.mailboxExecutor = mailboxExecutor;
+        }
+    }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastStreamingRuntimeContext.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastStreamingRuntimeContext.java
new file mode 100644
index 0000000..809fdac
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastStreamingRuntimeContext.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A subclass of {@link StreamingRuntimeContext} that gives user functions access to broadcast
+ * variables.
+ */
+public class BroadcastStreamingRuntimeContext extends StreamingRuntimeContext {
+
+    /** Broadcast variables visible to the user function, keyed by their user-facing name. */
+    Map<String, List<?>> broadcastVariables = new HashMap<>();
+
+    public BroadcastStreamingRuntimeContext(
+            Environment env,
+            Map<String, Accumulator<?, ?>> accumulators,
+            OperatorMetricGroup operatorMetricGroup,
+            OperatorID operatorID,
+            ProcessingTimeService processingTimeService,
+            @Nullable KeyedStateStore keyedStateStore,
+            ExternalResourceInfoProvider externalResourceInfoProvider) {
+        super(
+                env,
+                accumulators,
+                operatorMetricGroup,
+                operatorID,
+                processingTimeService,
+                keyedStateStore,
+                externalResourceInfoProvider);
+    }
+
+    @Override
+    public boolean hasBroadcastVariable(String name) {
+        return broadcastVariables.containsKey(name);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public <RT> List<RT> getBroadcastVariable(String name) {
+        ensureBroadcastVariablePresent(name);
+        return (List<RT>) broadcastVariables.get(name);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public <T, C> C getBroadcastVariableWithInitializer(
+            String name, BroadcastVariableInitializer<T, C> initializer) {
+        ensureBroadcastVariablePresent(name);
+        List<T> data = (List<T>) broadcastVariables.get(name);
+        return initializer.initializeBroadcastVariable(data);
+    }
+
+    /** Registers (or replaces) the broadcast variable visible under {@code name}. */
+    @Internal
+    public void setBroadcastVariable(String name, List<?> broadcastVariable) {
+        broadcastVariables.put(name, broadcastVariable);
+    }
+
+    /** Fails if no broadcast variable has been registered under {@code name} yet. */
+    private void ensureBroadcastVariablePresent(String name) {
+        if (!broadcastVariables.containsKey(name)) {
+            throw new RuntimeException(
+                    "Cannot get broadcast variables before processing elements.");
+        }
+    }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
new file mode 100644
index 0000000..bb12647
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
+import
org.apache.flink.ml.common.broadcast.operator.BroadcastVariableReceiverOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
+import
org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.function.Function;
+
+/** Utility class to support withBroadcast in DataStream. */
+public class BroadcastUtils {
+    /**
+     * Supports withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of an operator that extends {@code
+     * org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator<OUT, ? extends
+     * org.apache.flink.api.common.functions.RichFunction>}. Users can access the broadcast
+     * variables by {@code RichFunction.getRuntimeContext().getBroadcastVariable(...)} or {@code
+     * RichFunction.getRuntimeContext().hasBroadcastVariable(...)} or {@code
+     * RichFunction.getRuntimeContext().getBroadcastVariableWithInitializer(...)}.
+     *
+     * <p>In detail, the broadcast input data streams are consumed first and the non-broadcast
+     * inputs are consumed afterwards. For now the non-broadcast inputs are cached by default to
+     * avoid possible deadlocks.
+     *
+     * @param inputList non-broadcast input list. Must not be empty.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream. Note that users can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     * @throws IllegalArgumentException if {@code inputList} is empty.
+     * @throws IllegalStateException if {@code userDefinedFunction} adds more than one operator.
+     * @throws UnsupportedOperationException if the chaining strategy of the generated operators
+     *     cannot be set.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> userDefinedFunction) {
+        Preconditions.checkArgument(
+                !inputList.isEmpty(),
+                "withBroadcastStream requires at least one non-broadcast input.");
+
+        StreamExecutionEnvironment env = inputList.get(0).getExecutionEnvironment();
+        int numBroadcastInputs = bcStreams.size();
+        String[] broadcastNames = new String[numBroadcastInputs];
+        DataStream<?>[] broadcastInputs = new DataStream[numBroadcastInputs];
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInputs];
+        // Prefixes every user-facing name with an id unique to this call. The id is a hex string,
+        // so the first '-' separates it from the user-facing name.
+        final String broadcastId = new AbstractID().toHexString();
+        int idx = 0;
+        for (Map.Entry<String, DataStream<?>> entry : bcStreams.entrySet()) {
+            broadcastNames[idx] = broadcastId + "-" + entry.getKey();
+            broadcastInputs[idx] = entry.getValue();
+            broadcastInTypes[idx] = entry.getValue().getType();
+            idx++;
+        }
+
+        DataStream<OUT> resultStream =
+                getResultStream(env, inputList, broadcastNames, userDefinedFunction);
+        TypeInformation<OUT> outType = resultStream.getType();
+        final String coLocationKey = "broadcast-co-location-" + UUID.randomUUID();
+        DataStream<OUT> cachedBroadcastInputs =
+                cacheBroadcastVariables(
+                        env,
+                        broadcastNames,
+                        broadcastInputs,
+                        broadcastInTypes,
+                        resultStream.getParallelism(),
+                        outType);
+
+        // Makes both operators chain heads and puts them into one co-location group, so that the
+        // caching operator and the result operator are co-located subtask by subtask.
+        boolean canCoLocate =
+                cachedBroadcastInputs.getTransformation() instanceof PhysicalTransformation
+                        && resultStream.getTransformation() instanceof PhysicalTransformation;
+        if (canCoLocate) {
+            ((PhysicalTransformation<?>) cachedBroadcastInputs.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+            ((PhysicalTransformation<?>) resultStream.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+        } else {
+            throw new UnsupportedOperationException(
+                    "cannot set chaining strategy on "
+                            + cachedBroadcastInputs.getTransformation()
+                            + " and "
+                            + resultStream.getTransformation()
+                            + ".");
+        }
+        cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
+        resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
+
+        // The caching operator emits nothing, so the union only carries the records of
+        // resultStream.
+        return cachedBroadcastInputs.union(resultStream);
+    }
+
+    /**
+     * Caches all broadcast input data streams in static variables and returns the result
+     * multi-input stream operator. The result multi-input stream operator emits nothing, and its
+     * only functionality is to cache all the input records in {@link BroadcastContext}.
+     *
+     * @param env execution environment.
+     * @param broadcastInputNames names of the broadcast input data streams.
+     * @param broadcastInputs list of the broadcast data streams.
+     * @param broadcastInTypes output types of the broadcast input data streams.
+     * @param parallelism parallelism.
+     * @param outType output type.
+     * @param <OUT> output type.
+     * @return the result multi-input stream operator.
+     */
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            String[] broadcastInputNames,
+            DataStream<?>[] broadcastInputs,
+            TypeInformation<?>[] broadcastInTypes,
+            int parallelism,
+            TypeInformation<OUT> outType) {
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<>(
+                        "broadcastInputs",
+                        new BroadcastVariableReceiverOperatorFactory<>(
+                                broadcastInputNames, broadcastInTypes),
+                        outType,
+                        parallelism);
+        // Every broadcast input is broadcast so that each parallel instance sees all records.
+        for (DataStream<?> dataStream : broadcastInputs) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    /**
+     * Uses {@link DraftExecutionEnvironment} to execute the userDefinedFunction and returns the
+     * resultStream.
+     *
+     * @param env execution environment.
+     * @param inputList non-broadcast input list.
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param graphBuilder user-defined logic.
+     * @param <OUT> output type of the result stream.
+     * @return the result stream by applying user-defined logic on the input list.
+     * @throws IllegalStateException if {@code graphBuilder} adds more than one operator.
+     */
+    private static <OUT> DataStream<OUT> getResultStream(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation<?>[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            inTypes[i] = inputList.get(i).getType();
+        }
+        // None of the non-broadcast input edges is blocked by default; they are cached instead.
+        boolean[] isBlocked = new boolean[inputList.size()];
+        Arrays.fill(isBlocked, false);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocked));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (DataStream<?> dataStream : inputList) {
+            draftSources.add(draftEnv.addDraftSource(dataStream, dataStream.getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+        // The draft graph holds one draft source per input plus the single user operator.
+        Preconditions.checkState(
+                draftEnv.getStreamGraph(false).getStreamNodes().size() == 1 + inputList.size(),
+                "cannot add more than one operator in withBroadcastStream's lambda function.");
+        draftEnv.copyToActualEnvironment();
+        return draftEnv.getActualStream(draftOutStream.getId());
+    }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
new file mode 100644
index 0000000..200ad94
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
@@ -0,0 +1,613 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.api.common.operators.MailboxExecutor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.ml.common.broadcast.BroadcastStreamingRuntimeContext;
+import org.apache.flink.ml.common.broadcast.typeinfo.CacheElement;
+import org.apache.flink.ml.common.broadcast.typeinfo.CacheElementTypeInfo;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.metrics.groups.InternalOperatorIOMetricGroup;
+import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutpusStreamDecorator;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler;
+import
org.apache.flink.streaming.api.operators.StreamOperatorStateHandler.CheckpointedStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailbox;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.ThrowingConsumer;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+
+/** Base class for the broadcast wrapper operators. */
+public abstract class AbstractBroadcastWrapperOperator<T, S extends
StreamOperator<T>>
+ implements StreamOperator<T>,
StreamOperatorStateHandler.CheckpointedStreamOperator {
+
+ /** Logger of this class. */
+ private static final Logger LOG =
+ LoggerFactory.getLogger(AbstractBroadcastWrapperOperator.class);
+
+ /** Parameters this wrapper operator was created with. */
+ protected final StreamOperatorParameters<T> parameters;
+
+ /** Stream config taken from {@link #parameters}. */
+ protected final StreamConfig streamConfig;
+
+ /** Task containing this operator, taken from {@link #parameters}. */
+ protected final StreamTask<?, ?> containingTask;
+
+ /** Output taken from {@link #parameters}; it is also handed to the wrapped operator. */
+ protected final Output<StreamRecord<T>> output;
+
+ /** Factory used to create {@link #wrappedOperator}. */
+ protected final StreamOperatorFactory<T> operatorFactory;
+
+ /** Metric group of this wrapper operator, created in the constructor. */
+ protected final OperatorMetricGroup metrics;
+
+ /** The operator created from {@link #operatorFactory} that this class wraps. */
+ protected final S wrappedOperator;
+
+ /** State handler; not assigned in the constructor (presumably set during state init — TODO confirm). */
+ protected transient StreamOperatorStateHandler stateHandler;
+
+ /** Time service manager; not assigned in the constructor (presumably set during state init — TODO confirm). */
+ protected transient InternalTimeServiceManager<?> timeServiceManager;
+
+ /**
+ * Mailbox executor of the containing task at {@code TaskMailbox.MIN_PRIORITY}; registered in
+ * {@link BroadcastContext} for every broadcast variable of this subtask.
+ */
+ protected final MailboxExecutor mailboxExecutor;
+
+ /**
+ * Variables specific for withBroadcast functionality. Names of the broadcast variables consumed
+ * by this operator; each name is combined with the subtask index to form the key used in
+ * {@link BroadcastContext}.
+ */
+ protected final String[] broadcastStreamNames;
+
+ /**
+ * Whether each input is blocked. Inputs with broadcast variables can only process their input
+ * records after the broadcast variables are ready. An input is non-blocked if it can consume its
+ * inputs (by caching) while the broadcast variables are not ready. Otherwise it has to block the
+ * processing and wait until the broadcast variables are ready to be accessed.
+ */
+ protected final boolean[] isBlocked;
+
+ /** Type information of each input. */
+ protected final TypeInformation<?>[] inTypes;
+
+ /** Whether all broadcast variables of this operator are ready. */
+ protected boolean broadcastVariablesReady;
+
+ /** Index of this subtask. */
+ protected final transient int indexOfSubtask;
+
+ /** Number of the inputs of this operator. */
+ protected final int numInputs;
+
+ /** Runtime context of the rich function in the wrapped operator. */
+ BroadcastStreamingRuntimeContext wrappedOperatorRuntimeContext;
+
+ /**
+ * Base path used to store the cached records. It could be on a local or a remote file system.
+ */
+ private Path basePath;
+
+ /** DataCacheWriter for each input. */
+ @SuppressWarnings("rawtypes")
+ protected DataCacheWriter[] dataCacheWriters;
+
+ /** Whether each input has pending elements; initialized to true for every input. */
+ protected boolean[] hasPendingElements;
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ AbstractBroadcastWrapperOperator(
+ StreamOperatorParameters<T> parameters,
+ StreamOperatorFactory<T> operatorFactory,
+ String[] broadcastStreamNames,
+ TypeInformation<?>[] inTypes,
+ boolean[] isBlocked) {
+ this.parameters = Objects.requireNonNull(parameters);
+ this.streamConfig =
Objects.requireNonNull(parameters.getStreamConfig());
+ this.containingTask =
Objects.requireNonNull(parameters.getContainingTask());
+ this.output = Objects.requireNonNull(parameters.getOutput());
+ this.operatorFactory = Objects.requireNonNull(operatorFactory);
+ this.metrics =
createOperatorMetricGroup(containingTask.getEnvironment(), streamConfig);
+ this.wrappedOperator =
+ (S)
+ StreamOperatorFactoryUtil.<T, S>createOperator(
+ operatorFactory,
+ (StreamTask) containingTask,
+ streamConfig,
+ output,
+
parameters.getOperatorEventDispatcher())
+ .f0;
+
+ boolean hasRichFunction =
+ wrappedOperator instanceof AbstractUdfStreamOperator
+ && ((AbstractUdfStreamOperator)
wrappedOperator).getUserFunction()
+ instanceof RichFunction;
+
+ if (hasRichFunction) {
+ wrappedOperatorRuntimeContext =
+ new BroadcastStreamingRuntimeContext(
+ containingTask.getEnvironment(),
+
containingTask.getEnvironment().getAccumulatorRegistry().getUserMap(),
+ wrappedOperator.getMetricGroup(),
+ wrappedOperator.getOperatorID(),
+ ((AbstractUdfStreamOperator) wrappedOperator)
+ .getProcessingTimeService(),
+ null,
+
containingTask.getEnvironment().getExternalResourceInfoProvider());
+
+ ((RichFunction) ((AbstractUdfStreamOperator)
wrappedOperator).getUserFunction())
+ .setRuntimeContext(wrappedOperatorRuntimeContext);
+ } else {
+ throw new RuntimeException(
+ "The operator is not a instance of "
+ + AbstractUdfStreamOperator.class.getSimpleName()
+ + " that contains a "
+ + RichFunction.class.getSimpleName());
+ }
+
+ this.mailboxExecutor =
+
containingTask.getMailboxExecutorFactory().createExecutor(TaskMailbox.MIN_PRIORITY);
+ // variables specific for withBroadcast functionality.
+ this.broadcastStreamNames = broadcastStreamNames;
+ this.isBlocked = isBlocked;
+ this.inTypes = inTypes;
+ this.broadcastVariablesReady = false;
+ this.indexOfSubtask = containingTask.getIndexInSubtaskGroup();
+ this.numInputs = inTypes.length;
+
+ // puts in mailboxExecutor
+ for (String name : broadcastStreamNames) {
+ BroadcastContext.putMailBoxExecutor(name + "-" + indexOfSubtask,
mailboxExecutor);
+ }
+
+ basePath =
+ OperatorUtils.getDataCachePath(
+
containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+ containingTask
+ .getEnvironment()
+ .getIOManager()
+ .getSpillingDirectoriesPaths());
+ dataCacheWriters = new DataCacheWriter[numInputs];
+ hasPendingElements = new boolean[numInputs];
+ Arrays.fill(hasPendingElements, true);
+ }
+
+ /**
+ * checks whether all of broadcast variables are ready. Besides it
maintains a state
+ * {broadcastVariablesReady} to avoiding invoking {@code
BroadcastContext.isCacheFinished(...)}
+ * repeatedly. Finally, it sets broadcast variables for
{wrappedOperatorRuntimeContext} if the
+ * broadcast variables are ready.
+ *
+ * @return true if all broadcast variables are ready, false otherwise.
+ */
+ protected boolean areBroadcastVariablesReady() {
+ if (broadcastVariablesReady) {
+ return true;
+ }
+ for (String name : broadcastStreamNames) {
+ if (!BroadcastContext.isCacheFinished(name + "-" +
indexOfSubtask)) {
+ return false;
+ } else {
+ String key = name + "-" + indexOfSubtask;
+ String userKey = name.substring(name.indexOf('-') + 1);
+ wrappedOperatorRuntimeContext.setBroadcastVariable(
+ userKey, BroadcastContext.getBroadcastVariable(key));
+ }
+ }
+ broadcastVariablesReady = true;
+ return true;
+ }
+
+ private OperatorMetricGroup createOperatorMetricGroup(
+ Environment environment, StreamConfig streamConfig) {
+ try {
+ OperatorMetricGroup operatorMetricGroup =
+ environment
+ .getMetricGroup()
+ .getOrAddOperator(
+ streamConfig.getOperatorID(),
streamConfig.getOperatorName());
+ if (streamConfig.isChainEnd()) {
+ ((InternalOperatorIOMetricGroup)
operatorMetricGroup.getIOMetricGroup())
+ .reuseOutputMetricsForTask();
+ }
+ return operatorMetricGroup;
+ } catch (Exception e) {
+ LOG.warn("An error occurred while instantiating task metrics.", e);
+ return
UnregisteredMetricGroups.createUnregisteredOperatorMetricGroup();
+ }
+ }
+
+ /**
+ * extracts common processing logic in subclasses' processing elements.
+ *
+ * @param streamRecord the input record.
+ * @param inputIndex input id, starts from zero.
+ * @param elementConsumer the consumer function of StreamRecord, i.e.,
+ * operator.processElement(...).
+ * @param watermarkConsumer the consumer function of WaterMark, i.e.,
+ * operator.processWatermark(...).
+ * @throws Exception possible exception.
+ */
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ protected void processElementX(
+ StreamRecord streamRecord,
+ int inputIndex,
+ ThrowingConsumer<StreamRecord, Exception> elementConsumer,
+ ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+ throws Exception {
+ if (!isBlocked[inputIndex]) {
+ if (areBroadcastVariablesReady()) {
+ if (hasPendingElements[inputIndex]) {
+ processPendingElementsAndWatermarks(
+ inputIndex, elementConsumer, watermarkConsumer);
+ hasPendingElements[inputIndex] = false;
+ }
+ elementConsumer.accept(streamRecord);
+
+ } else {
+ dataCacheWriters[inputIndex].addRecord(
+ CacheElement.newRecord(streamRecord.getValue()));
+ }
+
+ } else {
+ while (!areBroadcastVariablesReady()) {
+ mailboxExecutor.yield();
+ }
+ elementConsumer.accept(streamRecord);
+ }
+ }
+
+ /**
+ * extracts common processing logic in subclasses' processing watermarks.
+ *
+ * @param watermark the input watermark.
+ * @param inputIndex input id, starts from zero.
+ * @param elementConsumer the consumer function of StreamRecord, i.e.,
+ * operator.processElement(...).
+ * @param watermarkConsumer the consumer function of WaterMark, i.e.,
+ * operator.processWatermark(...).
+ * @throws Exception possible exception.
+ */
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ protected void processWatermarkX(
+ Watermark watermark,
+ int inputIndex,
+ ThrowingConsumer<StreamRecord, Exception> elementConsumer,
+ ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+ throws Exception {
+ if (!isBlocked[inputIndex]) {
+ if (areBroadcastVariablesReady()) {
+ if (hasPendingElements[inputIndex]) {
+ processPendingElementsAndWatermarks(
+ inputIndex, elementConsumer, watermarkConsumer);
+ hasPendingElements[inputIndex] = false;
+ }
+ watermarkConsumer.accept(watermark);
+
+ } else {
+ dataCacheWriters[inputIndex].addRecord(
+ CacheElement.newWatermark(watermark.getTimestamp()));
+ }
+
+ } else {
+ while (!areBroadcastVariablesReady()) {
+ mailboxExecutor.yield();
+ }
+ watermarkConsumer.accept(watermark);
+ }
+ }
+
+ /**
+ * extracts common processing logic in subclasses' endInput(...).
+ *
+ * @param inputIndex input id, starts from zero.
+ * @param elementConsumer the consumer function of StreamRecord, i.e.,
+ * operator.processElement(...).
+ * @param watermarkConsumer the consumer function of WaterMark, i.e.,
+ * operator.processWatermark(...).
+ * @throws Exception possible exception.
+ */
+ @SuppressWarnings("rawtypes")
+ protected void endInputX(
+ int inputIndex,
+ ThrowingConsumer<StreamRecord, Exception> elementConsumer,
+ ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+ throws Exception {
+ while (!areBroadcastVariablesReady()) {
+ mailboxExecutor.yield();
+ }
+ if (hasPendingElements[inputIndex]) {
+ processPendingElementsAndWatermarks(inputIndex, elementConsumer,
watermarkConsumer);
+ hasPendingElements[inputIndex] = false;
+ }
+ }
+
+ /**
+ * processes the pending elements that are cached by {@link
DataCacheWriter}.
+ *
+ * @param inputIndex input id, starts from zero.
+ * @param elementConsumer the consumer function of StreamRecord, i.e.,
+ * operator.processElement(...).
+ * @param watermarkConsumer the consumer function of WaterMark, i.e.,
+ * operator.processWatermark(...).
+ * @throws Exception possible exception.
+ */
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ private void processPendingElementsAndWatermarks(
+ int inputIndex,
+ ThrowingConsumer<StreamRecord, Exception> elementConsumer,
+ ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+ throws Exception {
+ dataCacheWriters[inputIndex].finishCurrentSegment();
+ List<Segment> pendingSegments =
dataCacheWriters[inputIndex].getFinishSegments();
+ if (pendingSegments.size() != 0) {
+ DataCacheReader dataCacheReader =
+ new DataCacheReader<>(
+ new CacheElementTypeInfo<>(inTypes[inputIndex])
+
.createSerializer(containingTask.getExecutionConfig()),
+ basePath.getFileSystem(),
+ pendingSegments);
+ while (dataCacheReader.hasNext()) {
+ CacheElement cacheElement = (CacheElement)
dataCacheReader.next();
+ switch (cacheElement.getType()) {
+ case RECORD:
+ elementConsumer.accept(new
StreamRecord(cacheElement.getRecord()));
+ break;
+ case WATERMARK:
+ watermarkConsumer.accept(new
Watermark(cacheElement.getWatermark()));
+ break;
+ default:
+ throw new RuntimeException(
+ "Unsupported CacheElement type: " +
cacheElement.getType());
+ }
+ }
+ }
+ }
+
+ @Override
+ public void open() throws Exception {
+ wrappedOperator.open();
+ }
+
+ @Override
+ public void close() throws Exception {
+ wrappedOperator.close();
+ for (String name : broadcastStreamNames) {
+ BroadcastContext.remove(name + "-" + indexOfSubtask);
+ }
+ }
+
+ @Override
+ public void finish() throws Exception {
+ wrappedOperator.finish();
+ }
+
+ @Override
+ public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+ wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
+ }
+
+ @Override
+ public void initializeState(StreamTaskStateInitializer
streamTaskStateManager)
+ throws Exception {
+ final TypeSerializer<?> keySerializer =
+
streamConfig.getStateKeySerializer(containingTask.getUserCodeClassLoader());
+
+ StreamOperatorStateContext streamOperatorStateContext =
+ streamTaskStateManager.streamOperatorStateContext(
+ getOperatorID(),
+ getClass().getSimpleName(),
+ parameters.getProcessingTimeService(),
+ this,
+ keySerializer,
+ containingTask.getCancelables(),
+ metrics,
+
streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
+ ManagedMemoryUseCase.STATE_BACKEND,
+ containingTask
+ .getEnvironment()
+ .getTaskManagerInfo()
+ .getConfiguration(),
+ containingTask.getUserCodeClassLoader()),
+ false);
+ stateHandler =
+ new StreamOperatorStateHandler(
+ streamOperatorStateContext,
+ containingTask.getExecutionConfig(),
+ containingTask.getCancelables());
+ stateHandler.initializeOperatorState(this);
+
+ timeServiceManager =
streamOperatorStateContext.internalTimerServiceManager();
+
+ broadcastVariablesReady = false;
+
+ wrappedOperator.initializeState(
+ (operatorID,
+ operatorClassName,
+ processingTimeService,
+ keyContext,
+ keySerializerX,
+ streamTaskCloseableRegistry,
+ metricGroup,
+ managedMemoryFraction,
+ isUsingCustomRawKeyedState) ->
+ new ProxyStreamOperatorStateContext(
+ streamOperatorStateContext, "wrapped-"));
+ }
+
+ @Override
+ public OperatorSnapshotFutures snapshotState(
+ long checkpointId,
+ long timestamp,
+ CheckpointOptions checkpointOptions,
+ CheckpointStreamFactory storageLocation)
+ throws Exception {
+ return stateHandler.snapshotState(
+ this,
+ Optional.ofNullable(timeServiceManager),
+ streamConfig.getOperatorName(),
+ checkpointId,
+ timestamp,
+ checkpointOptions,
+ storageLocation,
+ false);
+ }
+
+ @Override
+ @SuppressWarnings("unchecked, rawtypes")
+ public void initializeState(StateInitializationContext
stateInitializationContext)
+ throws Exception {
+ List<StatePartitionStreamProvider> inputs =
+ IteratorUtils.toList(
+
stateInitializationContext.getRawOperatorStateInputs().iterator());
+ Preconditions.checkState(
+ inputs.size() < 2, "The input from raw operator state should
be one or zero.");
+ if (inputs.size() == 0) {
+ for (int i = 0; i < numInputs; i++) {
+ dataCacheWriters[i] =
+ new DataCacheWriter(
+ new CacheElementTypeInfo<>(inTypes[i])
+
.createSerializer(containingTask.getExecutionConfig()),
+ basePath.getFileSystem(),
+ OperatorUtils.createDataCacheFileGenerator(
+ basePath, "cache",
streamConfig.getOperatorID()));
+ }
+ } else {
+ InputStream inputStream = inputs.get(0).getStream();
+ DataInputStream dis =
+ new DataInputStream(new
NonClosingInputStreamDecorator(inputStream));
+ Preconditions.checkState(dis.readInt() == numInputs, "Number of
input is wrong.");
+ for (int i = 0; i < numInputs; i++) {
+ DataCacheSnapshot dataCacheSnapshot =
+ DataCacheSnapshot.recover(
+ inputStream,
+ basePath.getFileSystem(),
+ OperatorUtils.createDataCacheFileGenerator(
+ basePath, "cache",
streamConfig.getOperatorID()));
+ dataCacheWriters[i] =
+ new DataCacheWriter(
+ new CacheElementTypeInfo<>(inTypes[i])
+
.createSerializer(containingTask.getExecutionConfig()),
+ basePath.getFileSystem(),
+ OperatorUtils.createDataCacheFileGenerator(
+ basePath, "cache",
streamConfig.getOperatorID()),
+ dataCacheSnapshot.getSegments());
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void snapshotState(StateSnapshotContext stateSnapshotContext)
throws Exception {
+ if (wrappedOperator instanceof
StreamOperatorStateHandler.CheckpointedStreamOperator) {
+ ((CheckpointedStreamOperator)
wrappedOperator).snapshotState(stateSnapshotContext);
+ }
+
+ OperatorStateCheckpointOutputStream checkpointOutputStream =
+ stateSnapshotContext.getRawOperatorStateOutput();
+ checkpointOutputStream.startNewPartition();
+ try (DataOutputStream dos =
+ new DataOutputStream(new
NonClosingOutpusStreamDecorator(checkpointOutputStream))) {
+ dos.writeInt(numInputs);
+ }
+ for (int i = 0; i < numInputs; i++) {
+ dataCacheWriters[i].finishCurrentSegment();
+ DataCacheSnapshot dataCacheSnapshot =
+ new DataCacheSnapshot(
+ basePath.getFileSystem(),
+ null,
+ dataCacheWriters[i].getFinishSegments());
+ dataCacheSnapshot.writeTo(checkpointOutputStream);
+ }
+ }
+
+ @Override
+ public void setKeyContextElement1(StreamRecord<?> record) throws Exception
{
+ wrappedOperator.setKeyContextElement1(record);
+ }
+
+ @Override
+ public void setKeyContextElement2(StreamRecord<?> record) throws Exception
{
+ wrappedOperator.setKeyContextElement2(record);
+ }
+
+ @Override
+ public OperatorMetricGroup getMetricGroup() {
+ return wrappedOperator.getMetricGroup();
+ }
+
+ @Override
+ public OperatorID getOperatorID() {
+ return wrappedOperator.getOperatorID();
+ }
+
+ @Override
+ public void notifyCheckpointComplete(long checkpointId) throws Exception {
+ wrappedOperator.notifyCheckpointComplete(checkpointId);
+ }
+
+ @Override
+ public void notifyCheckpointAborted(long checkpointId) throws Exception {
+ wrappedOperator.notifyCheckpointAborted(checkpointId);
+ }
+
+ @Override
+ public void setCurrentKey(Object key) {
+ wrappedOperator.setCurrentKey(key);
+ }
+
+ @Override
+ public Object getCurrentKey() {
+ return wrappedOperator.getCurrentKey();
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
new file mode 100644
index 0000000..170fccc
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractInput;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/** The operator that process all broadcast inputs and stores them in {@link
BroadcastContext}. */
+public class BroadcastVariableReceiverOperator<OUT> extends
AbstractStreamOperatorV2<OUT>
+ implements MultipleInputStreamOperator<OUT>, BoundedMultiInput,
Serializable {
+
+ /** names of the broadcast data streams. */
+ private final String[] broadcastStreamNames;
+
+ /** output types of input data streams. */
+ private final TypeInformation<?>[] inTypes;
+
+ /** input list of the multi-input operator. */
+ @SuppressWarnings("rawtypes")
+ private final List<Input> inputList;
+
+ /** caches of the broadcast inputs. */
+ @SuppressWarnings("rawtypes")
+ private final List[] caches;
+
+ /** state storage of the broadcast inputs. */
+ private ListState<?>[] cacheStates;
+
+ /** cacheReady state storage of the broadcast inputs. */
+ private ListState<Boolean>[] cacheReadyStates;
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ BroadcastVariableReceiverOperator(
+ StreamOperatorParameters<OUT> parameters,
+ String[] broadcastStreamNames,
+ TypeInformation<?>[] inTypes) {
+ super(parameters, broadcastStreamNames.length);
+ this.broadcastStreamNames = broadcastStreamNames;
+ this.inTypes = inTypes;
+ inputList = new ArrayList<>();
+ for (int i = 0; i < inTypes.length; i++) {
+ inputList.add(new ProxyInput(this, i + 1));
+ }
+ this.caches = new List[inTypes.length];
+ for (int i = 0; i < inTypes.length; i++) {
+ caches[i] = new ArrayList<>();
+ }
+ this.cacheStates = new ListState[inTypes.length];
+ this.cacheReadyStates = new ListState[inTypes.length];
+ }
+
+ @Override
+ public List<Input> getInputs() {
+ return inputList;
+ }
+
+ @Override
+ public void endInput(int i) {
+ BroadcastContext.markCacheFinished(
+ broadcastStreamNames[i - 1] + "-" +
getRuntimeContext().getIndexOfThisSubtask());
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ for (int i = 0; i < inTypes.length; i++) {
+ cacheStates[i].clear();
+ cacheStates[i].addAll(caches[i]);
+ cacheReadyStates[i].clear();
+ boolean isCacheFinished =
+ BroadcastContext.isCacheFinished(
+ broadcastStreamNames[i]
+ + "-"
+ +
getRuntimeContext().getIndexOfThisSubtask());
+ cacheReadyStates[i].add(isCacheFinished);
+ }
+ }
+
+ @Override
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public void initializeState(StateInitializationContext context) throws
Exception {
+ super.initializeState(context);
+ for (int i = 0; i < inTypes.length; i++) {
+ cacheStates[i] =
+ context.getOperatorStateStore()
+ .getListState(new
ListStateDescriptor("cache_data_" + i, inTypes[i]));
+ caches[i] = IteratorUtils.toList(cacheStates[i].get().iterator());
+
+ cacheReadyStates[i] =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "cache_ready_state_" + i,
+ BasicTypeInfo.BOOLEAN_TYPE_INFO));
+ boolean cacheReady =
+ OperatorStateUtils.getUniqueElement(
+ cacheReadyStates[i], "cache_ready_state_"
+ i)
+ .orElse(false);
+ BroadcastContext.putBroadcastVariable(
+ broadcastStreamNames[i] + "-" +
getRuntimeContext().getIndexOfThisSubtask(),
+ Tuple2.of(cacheReady, caches[i]));
+ }
+ }
+
+ private class ProxyInput<IN, OT> extends AbstractInput<IN, OT> {
+
+ public ProxyInput(AbstractStreamOperatorV2<OT> owner, int inputId) {
+ super(owner, inputId);
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public void processElement(StreamRecord<IN> element) {
+ (caches[inputId - 1]).add(element.getValue());
+ }
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorFactory.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorFactory.java
new file mode 100644
index 0000000..75721c7
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorFactory.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+
+import java.io.Serializable;
+
+/** Factory class for {@link BroadcastVariableReceiverOperator}. */
+public class BroadcastVariableReceiverOperatorFactory<OUT>
+ extends AbstractStreamOperatorFactory<OUT> implements Serializable {
+
+ /** names of the broadcast data streams. */
+ private final String[] broadcastNames;
+
+ /** types of the broadcast data streams. */
+ private final TypeInformation<?>[] inTypes;
+
+ public BroadcastVariableReceiverOperatorFactory(
+ String[] broadcastNames, TypeInformation<?>[] inTypes) {
+ this.broadcastNames = broadcastNames;
+ this.inTypes = inTypes;
+ }
+
+ @Override
+ public <T extends StreamOperator<OUT>> T createStreamOperator(
+ StreamOperatorParameters<OUT> parameters) {
+ return (T) new BroadcastVariableReceiverOperator(parameters,
broadcastNames, inTypes);
+ }
+
+ @Override
+ public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader
classLoader) {
+ return BroadcastVariableReceiverOperator.class;
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
new file mode 100644
index 0000000..2e3f88d
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for {@link AbstractBroadcastWrapperOperator}. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+
+ /** names of the broadcast data streams. */
+ private final String[] broadcastStreamNames;
+
+ /** types of input data streams. */
+ private final TypeInformation<?>[] inTypes;
+
+ /** whether each input is blocked or not. */
+ private final boolean[] isBlocked;
+
+ @VisibleForTesting
+ public BroadcastWrapper(String[] broadcastStreamNames,
TypeInformation<?>[] inTypes) {
+ this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+ }
+
+ public BroadcastWrapper(
+ String[] broadcastStreamNames, TypeInformation<?>[] inTypes,
boolean[] isBlocked) {
+ Preconditions.checkArgument(inTypes.length == isBlocked.length);
+ this.broadcastStreamNames = broadcastStreamNames;
+ this.inTypes = inTypes;
+ this.isBlocked = isBlocked;
+ }
+
+ @Override
+ public StreamOperator<T> wrap(
+ StreamOperatorParameters<T> operatorParameters,
+ StreamOperatorFactory<T> operatorFactory) {
+ Class<? extends StreamOperator> operatorClass =
+
operatorFactory.getStreamOperatorClass(getClass().getClassLoader());
+ if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+ return new OneInputBroadcastWrapperOperator<>(
+ operatorParameters, operatorFactory, broadcastStreamNames,
inTypes, isBlocked);
+ } else if
(TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+ return new TwoInputBroadcastWrapperOperator<>(
+ operatorParameters, operatorFactory, broadcastStreamNames,
inTypes, isBlocked);
+ } else {
+ throw new UnsupportedOperationException(
+ "Unsupported operator class for with-broadcast wrapper: "
+ operatorClass);
+ }
+ }
+
+ @Override
+ public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY>
keySelector) {
+ return keySelector;
+ }
+
+ @Override
+ public StreamPartitioner<T> wrapStreamPartitioner(StreamPartitioner<T>
streamPartitioner) {
+ return streamPartitioner;
+ }
+
+ @Override
+ public OutputTag<T> wrapOutputTag(OutputTag<T> outputTag) {
+ return outputTag;
+ }
+
+ @Override
+ public TypeInformation<T> getWrappedTypeInfo(TypeInformation<T> typeInfo) {
+ return typeInfo;
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
new file mode 100644
index 0000000..f1ffe00
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+/** Wrapper for {@link OneInputStreamOperator}. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+ extends AbstractBroadcastWrapperOperator<OUT,
OneInputStreamOperator<IN, OUT>>
+ implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
+
+ OneInputBroadcastWrapperOperator(
+ StreamOperatorParameters<OUT> parameters,
+ StreamOperatorFactory<OUT> operatorFactory,
+ String[] broadcastStreamNames,
+ TypeInformation<?>[] inTypes,
+ boolean[] isBlocking) {
+ super(parameters, operatorFactory, broadcastStreamNames, inTypes,
isBlocking);
+ }
+
+ @Override
+ public void processElement(StreamRecord<IN> streamRecord) throws Exception
{
+ processElementX(
+ streamRecord,
+ 0,
+ wrappedOperator::processElement,
+ wrappedOperator::processWatermark);
+ }
+
+ @Override
+ public void endInput() throws Exception {
+ endInputX(0, wrappedOperator::processElement,
wrappedOperator::processWatermark);
+ OperatorUtils.processOperatorOrUdfIfSatisfy(
+ wrappedOperator, BoundedOneInput.class,
BoundedOneInput::endInput);
+ }
+
+ @Override
+ public void processWatermark(Watermark watermark) throws Exception {
+ processWatermarkX(
+ watermark, 0, wrappedOperator::processElement,
wrappedOperator::processWatermark);
+ }
+
+ @Override
+ public void processWatermarkStatus(WatermarkStatus watermarkStatus) throws
Exception {
+ wrappedOperator.processWatermarkStatus(watermarkStatus);
+ }
+
+ @Override
+ public void processLatencyMarker(LatencyMarker latencyMarker) throws
Exception {
+ wrappedOperator.processLatencyMarker(latencyMarker);
+ }
+
+ @Override
+ public void setKeyContextElement(StreamRecord<IN> streamRecord) throws
Exception {
+ wrappedOperator.setKeyContextElement(streamRecord);
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperator.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperator.java
new file mode 100644
index 0000000..07871d4
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperator.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+/** Wrapper for {@link TwoInputStreamOperator}. */
+public class TwoInputBroadcastWrapperOperator<IN1, IN2, OUT>
+ extends AbstractBroadcastWrapperOperator<OUT,
TwoInputStreamOperator<IN1, IN2, OUT>>
+ implements TwoInputStreamOperator<IN1, IN2, OUT>, BoundedMultiInput {
+
+ TwoInputBroadcastWrapperOperator(
+ StreamOperatorParameters<OUT> parameters,
+ StreamOperatorFactory<OUT> operatorFactory,
+ String[] broadcastStreamNames,
+ TypeInformation<?>[] inTypes,
+ boolean[] isBlocking) {
+ super(parameters, operatorFactory, broadcastStreamNames, inTypes,
isBlocking);
+ }
+
+ @Override
+ public void processElement1(StreamRecord<IN1> streamRecord) throws
Exception {
+ processElementX(
+ streamRecord,
+ 0,
+ wrappedOperator::processElement1,
+ wrappedOperator::processWatermark1);
+ }
+
+ @Override
+ public void processElement2(StreamRecord<IN2> streamRecord) throws
Exception {
+ processElementX(
+ streamRecord,
+ 1,
+ wrappedOperator::processElement2,
+ wrappedOperator::processWatermark2);
+ }
+
+ @Override
+ public void endInput(int inputId) throws Exception {
+ if (inputId == 1) {
+ endInputX(
+ inputId - 1,
+ wrappedOperator::processElement1,
+ wrappedOperator::processWatermark1);
+ } else {
+ endInputX(
+ inputId - 1,
+ wrappedOperator::processElement2,
+ wrappedOperator::processWatermark2);
+ }
+ OperatorUtils.processOperatorOrUdfIfSatisfy(
+ wrappedOperator,
+ BoundedMultiInput.class,
+ boundedMultipleInput ->
boundedMultipleInput.endInput(inputId));
+ }
+
+ @Override
+ public void processWatermark1(Watermark watermark) throws Exception {
+ processWatermarkX(
+ watermark, 0, wrappedOperator::processElement1,
wrappedOperator::processWatermark1);
+ }
+
+ @Override
+ public void processWatermark2(Watermark watermark) throws Exception {
+ processWatermarkX(
+ watermark, 1, wrappedOperator::processElement2,
wrappedOperator::processWatermark2);
+ }
+
+ @Override
+ public void processLatencyMarker1(LatencyMarker latencyMarker) throws
Exception {
+ wrappedOperator.processLatencyMarker1(latencyMarker);
+ }
+
+ @Override
+ public void processLatencyMarker2(LatencyMarker latencyMarker) throws
Exception {
+ wrappedOperator.processLatencyMarker2(latencyMarker);
+ }
+
+ @Override
+ public void processWatermarkStatus1(WatermarkStatus watermarkStatus)
throws Exception {
+ wrappedOperator.processWatermarkStatus1(watermarkStatus);
+ }
+
+ @Override
+ public void processWatermarkStatus2(WatermarkStatus watermarkStatus)
throws Exception {
+ wrappedOperator.processWatermarkStatus2(watermarkStatus);
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElement.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElement.java
new file mode 100644
index 0000000..63cf6c7
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElement.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.typeinfo;
+
+/**
+ * The wrapper class for possible cached elements used in {@link
+ *
org.apache.flink.ml.common.broadcast.operator.AbstractBroadcastWrapperOperator}.
It could be
+ * either {@link org.apache.flink.streaming.api.watermark.Watermark}, {@link
+ * org.apache.flink.streaming.runtime.streamrecord.StreamRecord}, etc.
+ *
+ * @param <T> the record type.
+ */
+public class CacheElement<T> {
+ private T record;
+ private long watermark;
+ private Type type;
+
+ public CacheElement(T record, long watermark, Type type) {
+ this.record = record;
+ this.watermark = watermark;
+ this.type = type;
+ }
+
+ public static <T> CacheElement<T> newRecord(T record) {
+ return new CacheElement<>(record, -1, Type.RECORD);
+ }
+
+ public static <T> CacheElement<T> newWatermark(long watermark) {
+ return new CacheElement<>(null, watermark, Type.WATERMARK);
+ }
+
+ public T getRecord() {
+ return record;
+ }
+
+ public void setRecord(T record) {
+ this.record = record;
+ }
+
+ public long getWatermark() {
+ return watermark;
+ }
+
+ public void setWatermark(long watermark) {
+ this.watermark = watermark;
+ }
+
+ public Type getType() {
+ return type;
+ }
+
+ public void setType(Type type) {
+ this.type = type;
+ }
+
+ /** The type of cached elements. */
+ public enum Type {
+ /** record type. */
+ RECORD,
+ /** watermark type. */
+ WATERMARK,
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementSerializer.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementSerializer.java
new file mode 100644
index 0000000..ed600c1
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementSerializer.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.typeinfo;
+
+import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.common.broadcast.typeinfo.CacheElement.Type;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * TypeSerializer for {@link CacheElement}.
+ *
+ * @param <T> the record type.
+ */
+public class CacheElementSerializer<T> extends TypeSerializer<CacheElement<T>>
{
+
+ private final TypeSerializer<T> recordSerializer;
+
+ public CacheElementSerializer(TypeSerializer<T> recordSerializer) {
+ this.recordSerializer = recordSerializer;
+ }
+
+ @Override
+ public boolean isImmutableType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<CacheElement<T>> duplicate() {
+ return new CacheElementSerializer<>(recordSerializer.duplicate());
+ }
+
+ @Override
+ public CacheElement<T> createInstance() {
+ return null;
+ }
+
+ @Override
+ public CacheElement<T> copy(CacheElement<T> from) {
+ switch (from.getType()) {
+ case RECORD:
+ return
CacheElement.newRecord(recordSerializer.copy(from.getRecord()));
+ case WATERMARK:
+ return CacheElement.newWatermark(from.getWatermark());
+ default:
+ throw new RuntimeException(
+ "Unsupported Record or Watermark type " +
from.getType());
+ }
+ }
+
+ @Override
+ public CacheElement<T> copy(CacheElement<T> from, CacheElement<T> reuse) {
+ switch (from.getType()) {
+ case RECORD:
+ if (reuse.getRecord() != null) {
+ recordSerializer.copy(from.getRecord(), reuse.getRecord());
+ } else {
+ reuse.setRecord(recordSerializer.copy(from.getRecord()));
+ }
+ break;
+ case WATERMARK:
+ reuse.setWatermark(from.getWatermark());
+ break;
+ default:
+ throw new RuntimeException(
+ "Unsupported Record or Watermark type " +
from.getType());
+ }
+
+ return reuse;
+ }
+
+ @Override
+ public int getLength() {
+ return -1;
+ }
+
+ @Override
+ public void serialize(CacheElement<T> record, DataOutputView target)
throws IOException {
+ target.writeByte((byte) record.getType().ordinal());
+
+ switch (record.getType()) {
+ case RECORD:
+ recordSerializer.serialize(record.getRecord(), target);
+ break;
+ case WATERMARK:
+ LongSerializer.INSTANCE.serialize(record.getWatermark(),
target);
+ break;
+ default:
+ throw new RuntimeException(
+ "Unsupported Record or Watermark type " +
record.getType());
+ }
+ }
+
+ @Override
+ public CacheElement<T> deserialize(DataInputView source) throws
IOException {
+ int type = source.readByte();
+ switch (CacheElement.Type.values()[type]) {
+ case RECORD:
+ T value = recordSerializer.deserialize(source);
+ return CacheElement.newRecord(value);
+ case WATERMARK:
+ long watermark = LongSerializer.INSTANCE.deserialize(source);
+ return CacheElement.newWatermark(watermark);
+ default:
+ throw new RuntimeException("Unsupported Record or Watermark
type " + type);
+ }
+ }
+
+ @Override
+ public CacheElement<T> deserialize(CacheElement<T> reuse, DataInputView
source)
+ throws IOException {
+ int type = source.readByte();
+ switch (CacheElement.Type.values()[type]) {
+ case RECORD:
+ reuse.setType(Type.RECORD);
+ reuse.setRecord(recordSerializer.deserialize(source));
+ break;
+ case WATERMARK:
+ reuse.setType(Type.WATERMARK);
+
reuse.setWatermark(LongSerializer.INSTANCE.deserialize(source));
+ break;
+ default:
+ throw new RuntimeException("Unsupported Record or Watermark
type " + type);
+ }
+ return reuse;
+ }
+
+ @Override
+ public void copy(DataInputView source, DataOutputView target) throws
IOException {
+ CacheElement<T> cacheElement = deserialize(source);
+ serialize(cacheElement, target);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+
+ if (obj == null || getClass() != obj.getClass()) {
+ return false;
+ }
+ CacheElementSerializer<?> that = (CacheElementSerializer<?>) obj;
+ return Objects.equals(recordSerializer, that.recordSerializer);
+ }
+
+ @Override
+ public int hashCode() {
+ return recordSerializer != null ? recordSerializer.hashCode() : 0;
+ }
+
+ @Override
+ public TypeSerializerSnapshot<CacheElement<T>> snapshotConfiguration() {
+ return new CacheElementSerializerSnapshot<>();
+ }
+
+ /** The serializer snapshot class for {@link CacheElementSerializer}. */
+ private static final class CacheElementSerializerSnapshot<T>
+ extends CompositeTypeSerializerSnapshot<CacheElement<T>,
CacheElementSerializer<T>> {
+
+ private static final int CURRENT_VERSION = 1;
+
+ public CacheElementSerializerSnapshot() {
+ super(CacheElementSerializer.class);
+ }
+
+ @Override
+ protected int getCurrentOuterSnapshotVersion() {
+ return CURRENT_VERSION;
+ }
+
+ @Override
+ protected TypeSerializer<?>[] getNestedSerializers(
+ CacheElementSerializer<T> tIterationRecordSerializer) {
+ return new TypeSerializer[]
{tIterationRecordSerializer.recordSerializer};
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ protected CacheElementSerializer<T>
createOuterSerializerWithNestedSerializers(
+ TypeSerializer<?>[] typeSerializers) {
+ TypeSerializer<T> elementSerializer = (TypeSerializer<T>)
typeSerializers[0];
+ return new CacheElementSerializer<>(elementSerializer);
+ }
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementTypeInfo.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementTypeInfo.java
new file mode 100644
index 0000000..55dc4de
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementTypeInfo.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+
+import java.util.Objects;
+
+/**
+ * TypeInformation for {@link CacheElement}.
+ *
+ * @param <T> the record type.
+ */
+public class CacheElementTypeInfo<T> extends TypeInformation<CacheElement<T>> {
+
+ private final TypeInformation<T> recordTypeInfo;
+
+ public CacheElementTypeInfo(TypeInformation<T> recordTypeInfo) {
+ this.recordTypeInfo = recordTypeInfo;
+ }
+
+ @Override
+ public boolean isBasicType() {
+ return false;
+ }
+
+ @Override
+ public boolean isTupleType() {
+ return false;
+ }
+
+ @Override
+ public int getArity() {
+ return 1;
+ }
+
+ @Override
+ public int getTotalFields() {
+ return 1;
+ }
+
+ @Override
+ public Class<CacheElement<T>> getTypeClass() {
+ return (Class) CacheElement.class;
+ }
+
+ @Override
+ public boolean isKeyType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<CacheElement<T>> createSerializer(ExecutionConfig
config) {
+ return new
CacheElementSerializer<>(recordTypeInfo.createSerializer(config));
+ }
+
+ @Override
+ public String toString() {
+ return "RecordOrWatermark Type";
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+
+ if (null == obj || getClass() != obj.getClass()) {
+ return false;
+ }
+
+ CacheElementTypeInfo<T> that = (CacheElementTypeInfo<T>) obj;
+ return Objects.equals(recordTypeInfo, that.recordTypeInfo);
+ }
+
+ @Override
+ public int hashCode() {
+ return recordTypeInfo != null ? recordTypeInfo.hashCode() : 0;
+ }
+
+ @Override
+ public boolean canEqual(Object obj) {
+ return obj instanceof CacheElementTypeInfo;
+ }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/BroadcastUtilsTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/BroadcastUtilsTest.java
new file mode 100644
index 0000000..08f9498
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/BroadcastUtilsTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.ml.common.broadcast.operator.TestOneInputOp;
+import org.apache.flink.ml.common.broadcast.operator.TestTwoInputOp;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.streaming.api.CheckpointingMode;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/** Tests the {@link BroadcastUtils}. */
+public class BroadcastUtilsTest {
+
+    @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final int NUM_RECORDS_PER_PARTITION = 10;
+
+    private static final int NUM_TM = 2;
+
+    private static final int NUM_SLOT = 2;
+
+    private static final String[] BROADCAST_NAMES = new String[] {"source1", "source2"};
+
+    /** Expected content of each broadcast variable: all values emitted by one TestSource. */
+    private static final List<Integer> BROADCAST_INPUT =
+            IntStream.range(0, NUM_TM * NUM_SLOT * NUM_RECORDS_PER_PARTITION)
+                    .boxed()
+                    .collect(Collectors.toList());
+
+    private MiniClusterConfiguration createMiniClusterConfiguration() throws IOException {
+        Configuration configuration = new Configuration();
+        configuration.set(RestOptions.PORT, 18082);
+        configuration.set(
+                IterationOptions.DATA_CACHE_PATH,
+                "file://" + tempFolder.newFolder().getAbsolutePath());
+        configuration.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        return new MiniClusterConfiguration.Builder()
+                .setConfiguration(configuration)
+                .setNumTaskManagers(NUM_TM)
+                .setNumSlotsPerTaskManager(NUM_SLOT)
+                .build();
+    }
+
+    @Test
+    public void testOneInputGraph() throws Exception {
+        runJob(1);
+    }
+
+    @Test
+    public void testTwoInputGraph() throws Exception {
+        runJob(2);
+    }
+
+    /** Starts a fresh MiniCluster, runs the job from {@link #getJobGraph(int)} and waits for it. */
+    private void runJob(int numNonBroadcastInputs) throws Exception {
+        try (MiniCluster miniCluster = new MiniCluster(createMiniClusterConfiguration())) {
+            miniCluster.start();
+            JobGraph jobGraph = getJobGraph(numNonBroadcastInputs);
+            miniCluster.executeJobBlocking(jobGraph);
+        }
+    }
+
+    private JobGraph getJobGraph(int numNonBroadcastInputs) {
+        // A plain Configuration instead of a double-brace-initialized anonymous subclass.
+        Configuration envConfig = new Configuration();
+        envConfig.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env =
+                StreamExecutionEnvironment.getExecutionEnvironment(envConfig);
+        env.enableCheckpointing(500, CheckpointingMode.EXACTLY_ONCE);
+        env.setParallelism(NUM_SLOT * NUM_TM);
+
+        DataStream<Integer> source1 = env.addSource(new TestSource(NUM_RECORDS_PER_PARTITION));
+        DataStream<Integer> source2 = env.addSource(new TestSource(NUM_RECORDS_PER_PARTITION));
+        HashMap<String, DataStream<?>> bcStreamsMap = new HashMap<>();
+        bcStreamsMap.put(BROADCAST_NAMES[0], source1);
+        bcStreamsMap.put(BROADCAST_NAMES[1], source2);
+
+        List<DataStream<?>> inputList = new ArrayList<>(numNonBroadcastInputs);
+        // Deliberately reuse the broadcast stream source1 as a non-broadcast input too. This sets
+        // up a potential deadlock between the broadcast and non-broadcast sides, which the job
+        // must still get through.
+        inputList.add(source1);
+        for (int i = 0; i < numNonBroadcastInputs - 1; i++) {
+            inputList.add(env.addSource(new TestSource(NUM_RECORDS_PER_PARTITION)));
+        }
+
+        Function<List<DataStream<?>>, DataStream<Integer>> func = getFunc(numNonBroadcastInputs);
+
+        DataStream<Integer> result =
+                BroadcastUtils.withBroadcastStream(inputList, bcStreamsMap, func);
+
+        List<Integer> expectedNumSequence =
+                new ArrayList<>(
+                        NUM_TM * NUM_SLOT * NUM_RECORDS_PER_PARTITION * numNonBroadcastInputs);
+        for (int i = 0; i < NUM_TM * NUM_SLOT * NUM_RECORDS_PER_PARTITION; i++) {
+            for (int j = 0; j < numNonBroadcastInputs; j++) {
+                expectedNumSequence.add(i);
+            }
+        }
+        result.addSink(new TestSink(expectedNumSequence)).setParallelism(1);
+
+        return env.getStreamGraph().getJobGraph();
+    }
+
+    /**
+     * Builds the user function applied to the non-broadcast inputs.
+     *
+     * @throws IllegalArgumentException if {@code numInputs} is not 1 or 2.
+     */
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private static Function<List<DataStream<?>>, DataStream<Integer>> getFunc(int numInputs) {
+        switch (numInputs) {
+            case 1:
+                return dataStreams -> {
+                    DataStream input = dataStreams.get(0);
+                    return input.transform(
+                                    "one-input",
+                                    BasicTypeInfo.INT_TYPE_INFO,
+                                    new TestOneInputOp(
+                                            new AbstractRichFunction() {},
+                                            BROADCAST_NAMES,
+                                            Arrays.asList(BROADCAST_INPUT, BROADCAST_INPUT)))
+                            .name("broadcast");
+                };
+            case 2:
+                return dataStreams -> {
+                    DataStream input1 = dataStreams.get(0);
+                    DataStream input2 = dataStreams.get(1);
+                    return input1.connect(input2)
+                            .transform(
+                                    "two-input",
+                                    BasicTypeInfo.INT_TYPE_INFO,
+                                    new TestTwoInputOp(
+                                            new AbstractRichFunction() {},
+                                            BROADCAST_NAMES,
+                                            Arrays.asList(BROADCAST_INPUT, BROADCAST_INPUT)))
+                            .name("broadcast");
+                };
+            default:
+                // Previously returned null, which only failed later with an NPE.
+                throw new IllegalArgumentException(
+                        "Unsupported number of non-broadcast inputs: " + numInputs);
+        }
+    }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSink.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSink.java
new file mode 100644
index 0000000..e1043cc
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSink.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import
org.apache.flink.ml.common.broadcast.operator.BroadcastVariableReceiverOperatorTest;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.List;
+
+/**
+ * Test sink that collects every received record and, in {@link #finish()}, compares them with
+ * the expected records via {@code BroadcastVariableReceiverOperatorTest.compareLists}, which is
+ * meant to throw if they differ. Received records are checkpointed in operator list state and
+ * restored in {@link #initializeState}, so they survive failover.
+ */
+public class TestSink extends RichSinkFunction<Integer> implements CheckpointedFunction {
+
+    /** Records the sink expects to have received by the end of the job. */
+    private final List<Integer> expectReceivedRecords;
+
+    /** Records received so far; rebuilt from {@link #receivedRecordsState} on (re)start. */
+    private List<Integer> receivedRecords;
+
+    /** Checkpointed copy of {@link #receivedRecords}. */
+    private ListState<Integer> receivedRecordsState;
+
+    public TestSink(List<Integer> expectReceivedRecords) {
+        this.expectReceivedRecords = expectReceivedRecords;
+    }
+
+    @Override
+    public void initializeState(FunctionInitializationContext context) throws Exception {
+        ListStateDescriptor<Integer> descriptor =
+                new ListStateDescriptor<>("receivedRecords", BasicTypeInfo.INT_TYPE_INFO);
+        this.receivedRecordsState = context.getOperatorStateStore().getListState(descriptor);
+        this.receivedRecords = IteratorUtils.toList(this.receivedRecordsState.get().iterator());
+    }
+
+    @Override
+    public void snapshotState(FunctionSnapshotContext context) throws Exception {
+        ListState<Integer> state = this.receivedRecordsState;
+        state.clear();
+        state.addAll(this.receivedRecords);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+    }
+
+    @Override
+    public void invoke(Integer value, Context context) {
+        this.receivedRecords.add(value);
+    }
+
+    @Override
+    public void finish() {
+        BroadcastVariableReceiverOperatorTest.compareLists(
+                this.expectReceivedRecords, this.receivedRecords);
+    }
+
+    @Override
+    public void close() {}
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSource.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSource.java
new file mode 100644
index 0000000..46e1880
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSource.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+
+import java.util.Iterator;
+
+/**
+ * Test source that generates an int stream and throws an exception once to exercise failover.
+ * Given {@code numElementsPerPartition}, subtask {@code i} of a source with parallelism {@code p}
+ * emits {@code i, i + p, i + 2p, ...}, so all subtasks together generate every number in {@code
+ * [0, p * numElementsPerPartition)}.
+ *
+ * <p>For example, when the parallelism is 2 and {@code numElementsPerPartition} is 5, this source
+ * generates {0,1,2,3,4,5,6,7,8,9}.
+ *
+ * <p>The failure flag is static and never reset. Once any instance in the JVM has thrown, no
+ * later instance throws again, including instances in later jobs run in the same JVM.
+ */
+public class TestSource extends RichParallelSourceFunction<Integer>
+        implements CheckpointedFunction {
+
+    private static volatile boolean hasThrown = false;
+
+    private final int numElementsPerPartition;
+
+    private ListState<Integer> currentIdxState;
+
+    /** Number of elements this subtask has emitted so far; checkpointed and restored. */
+    private int currentIdx;
+
+    /** Index of this subtask, used as the offset of every value it emits. */
+    private int mod;
+
+    /** Parallelism of this source, used as the stride between emitted values. */
+    private int numPartitions;
+
+    private transient volatile boolean running = true;
+
+    public TestSource(int numElementsPerPartition) {
+        this.numElementsPerPartition = numElementsPerPartition;
+    }
+
+    @Override
+    public void open(Configuration parameters) {
+        this.mod = getRuntimeContext().getIndexOfThisSubtask();
+        this.numPartitions = getRuntimeContext().getNumberOfParallelSubtasks();
+        running = true;
+    }
+
+    @Override
+    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+        this.currentIdxState.clear();
+        this.currentIdxState.add(currentIdx);
+    }
+
+    @Override
+    public void initializeState(FunctionInitializationContext functionInitializationContext)
+            throws Exception {
+        currentIdxState =
+                functionInitializationContext
+                        .getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "currentIdx", BasicTypeInfo.INT_TYPE_INFO));
+        Iterator<Integer> iterator = currentIdxState.get().iterator();
+        currentIdx = iterator.hasNext() ? iterator.next() : 0;
+    }
+
+    @Override
+    public void run(SourceContext<Integer> sourceContext) throws Exception {
+        while (running && currentIdx < numElementsPerPartition) {
+            synchronized (sourceContext.getCheckpointLock()) {
+                sourceContext.collect(currentIdx * numPartitions + mod);
+                currentIdx++;
+            }
+            Thread.sleep(1);
+            // Inject a single failure halfway through to force a restart from the last checkpoint.
+            if (currentIdx == numElementsPerPartition / 2 && !hasThrown) {
+                hasThrown = true;
+                throw new RuntimeException("Failing source");
+            }
+        }
+    }
+
+    @Override
+    public void cancel() {
+        running = false;
+    }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
new file mode 100644
index 0000000..2457471
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.MultipleInputStreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
+import
org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** Tests the {@link BroadcastVariableReceiverOperator}. */
+public class BroadcastVariableReceiverOperatorTest {
+
+    /**
+     * Broadcast names used only by this test. {@link BroadcastContext} is accessed through static
+     * methods, so entries registered by other tests in the same JVM may still be present. Other
+     * broadcast tests register keys such as "source1-0" as finished. Class-specific names keep
+     * the "cache not yet finished" assertion below independent of test execution order.
+     */
+    private static final String[] BROADCAST_NAMES =
+            new String[] {
+                "BroadcastVariableReceiverOperatorTest-source1",
+                "BroadcastVariableReceiverOperatorTest-source2"
+            };
+
+    private static final TypeInformation<?>[] TYPE_INFORMATIONS =
+            new TypeInformation[] {BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO};
+
+    @Test
+    public void testCacheStreamOperator() throws Exception {
+        OperatorID operatorId = new OperatorID();
+
+        try (StreamTaskMailboxTestHarness<Integer> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                MultipleInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .setupOutputForSingletonOperatorChain(
+                                new BroadcastVariableReceiverOperatorFactory<>(
+                                        BROADCAST_NAMES, TYPE_INFORMATIONS),
+                                operatorId)
+                        .build()) {
+            // Values 1, 2 go to input 0 and values 3, 4, 5 to input 1.
+            harness.processElement(new StreamRecord<>(1, 2), 0);
+            harness.processElement(new StreamRecord<>(2, 3), 0);
+            harness.processElement(new StreamRecord<>(3, 2), 1);
+            harness.processElement(new StreamRecord<>(4, 2), 1);
+            harness.processElement(new StreamRecord<>(5, 3), 1);
+            // Keys are "<broadcast name>-0"; the suffix is presumably the subtask index.
+            boolean cacheReady1 =
+                    BroadcastContext.isCacheFinished(BROADCAST_NAMES[0] + "-" + 0);
+            boolean cacheReady2 =
+                    BroadcastContext.isCacheFinished(BROADCAST_NAMES[1] + "-" + 0);
+            // check broadcast inputs before task finishes.
+            Assert.assertFalse(cacheReady1 || cacheReady2);
+
+            harness.waitForTaskCompletion();
+            List<?> cache1 = BroadcastContext.getBroadcastVariable(BROADCAST_NAMES[0] + "-" + 0);
+            List<?> cache2 = BroadcastContext.getBroadcastVariable(BROADCAST_NAMES[1] + "-" + 0);
+            // check broadcast inputs after task finishes.
+            compareLists(Arrays.asList(1, 2), cache1);
+            compareLists(Arrays.asList(3, 4, 5), cache2);
+        }
+    }
+
+    /**
+     * Asserts that {@code actual} contains the same integers as {@code expected}, counting
+     * duplicates and ignoring order.
+     *
+     * @param expected the expected elements.
+     * @param actual the actual elements. Must be non-null, and every element must be a non-null
+     *     {@link Integer}.
+     */
+    public static void compareLists(List<Integer> expected, List<?> actual) {
+        // Fail with a readable message instead of a NullPointerException in the stream below.
+        Assert.assertNotNull("Broadcast variable is missing.", actual);
+        int[] actualInts = actual.stream().mapToInt(x -> (Integer) x).toArray();
+        Arrays.sort(actualInts);
+        int[] expectedInts = expected.stream().mapToInt(Integer::intValue).toArray();
+        Arrays.sort(expectedInts);
+        Assert.assertArrayEquals(expectedInts, actualInts);
+    }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapperOperatorFactory.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapperOperatorFactory.java
new file mode 100644
index 0000000..40a56b0
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapperOperatorFactory.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+
+/**
+ * Test operator factory for {@link AbstractBroadcastWrapperOperator}: each operator is created by
+ * passing the inner {@code operatorFactory} to a {@link BroadcastWrapper}.
+ */
+class BroadcastWrapperOperatorFactory<OUT> extends AbstractStreamOperatorFactory<OUT> {
+
+    /** Factory of the operator being wrapped. */
+    private final StreamOperatorFactory<OUT> operatorFactory;
+
+    /** Wrapper applied to every operator this factory creates. */
+    private final BroadcastWrapper<OUT> wrapper;
+
+    /**
+     * Creates the factory.
+     *
+     * @param operatorFactory factory of the operator to wrap.
+     * @param wrapper wrapper used to create the wrapping operator.
+     */
+    public BroadcastWrapperOperatorFactory(
+            StreamOperatorFactory<OUT> operatorFactory, BroadcastWrapper<OUT> wrapper) {
+        this.operatorFactory = operatorFactory;
+        this.wrapper = wrapper;
+    }
+
+    /** Creates the operator by letting {@link #wrapper} wrap {@link #operatorFactory}. */
+    @Override
+    @SuppressWarnings("unchecked")
+    public <T extends StreamOperator<OUT>> T createStreamOperator(
+            StreamOperatorParameters<OUT> parameters) {
+        // T is chosen by the caller; the wrapped operator is assumed to satisfy it.
+        return (T) wrapper.wrap(parameters, operatorFactory);
+    }
+
+    @Override
+    public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader classLoader) {
+        return AbstractBroadcastWrapperOperator.class;
+    }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperatorTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperatorTest.java
new file mode 100644
index 0000000..7d64746
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperatorTest.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
+import
org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.streaming.util.TestHarnessUtil;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+/** Tests the {@link OneInputBroadcastWrapperOperator}. */
+public class OneInputBroadcastWrapperOperatorTest {
+
+    @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final String[] BROADCAST_NAMES = new String[] {"source1", "source2"};
+
+    private static final TypeInformation<?>[] TYPE_INFORMATIONS =
+            new TypeInformation[] {BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO};
+
+    private static final List<Integer> SOURCE_1 = Collections.singletonList(1);
+
+    private static final List<Integer> SOURCE_2 = Arrays.asList(1, 2, 3);
+
+    /** Empty rich function; the tested operator reads broadcast variables via its context. */
+    private static class MyRichFunction extends AbstractRichFunction {}
+
+    /**
+     * Pre-populates both broadcast variables. Then checks that the wrapped operator sees them
+     * while processing records and forwards each record unchanged.
+     */
+    @Test
+    public void testProcessElements() throws Exception {
+        OneInputStreamOperator<Integer, Integer> operatorUnderTest =
+                new TestOneInputOp(
+                        new MyRichFunction(), BROADCAST_NAMES, Arrays.asList(SOURCE_1, SOURCE_2));
+        BroadcastWrapperOperatorFactory<Integer> wrapperFactory =
+                new BroadcastWrapperOperatorFactory<>(
+                        SimpleOperatorFactory.of(operatorUnderTest),
+                        new BroadcastWrapper<>(BROADCAST_NAMES, TYPE_INFORMATIONS));
+
+        Queue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+        for (int value = 0; value < 5; ++value) {
+            expectedOutput.add(new StreamRecord<>(value, 1000));
+        }
+
+        try (StreamTaskMailboxTestHarness<Integer> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .setupOutputForSingletonOperatorChain(wrapperFactory, new OperatorID())
+                        .buildUnrestored()) {
+            configureDataCachePath(harness);
+            harness.getStreamTask().restore();
+
+            // The leading `true` presumably marks each cache as finished.
+            BroadcastContext.putBroadcastVariable(
+                    BROADCAST_NAMES[0] + "-" + 0, Tuple2.of(true, SOURCE_1));
+            BroadcastContext.putBroadcastVariable(
+                    BROADCAST_NAMES[1] + "-" + 0, Tuple2.of(true, SOURCE_2));
+
+            for (int value = 0; value < 5; ++value) {
+                harness.processElement(new StreamRecord<>(value, 1000), 0);
+            }
+            TestHarnessUtil.assertOutputEquals(
+                    "Output was not correct.", expectedOutput, harness.getOutput());
+        }
+    }
+
+    /**
+     * Sets {@link IterationOptions#DATA_CACHE_PATH} in the task manager configuration to a fresh
+     * temporary folder. It is called before {@code restore()}, presumably so the operators pick
+     * the setting up during setup.
+     */
+    private void configureDataCachePath(StreamTaskMailboxTestHarness<Integer> harness)
+            throws Exception {
+        String cachePath = "file://" + tempFolder.newFolder().getAbsolutePath();
+        harness.getStreamTask()
+                .getEnvironment()
+                .getTaskManagerInfo()
+                .getConfiguration()
+                .set(IterationOptions.DATA_CACHE_PATH, cachePath);
+    }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestOneInputOp.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestOneInputOp.java
new file mode 100644
index 0000000..69ccbd4
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestOneInputOp.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.util.List;
+
+/**
+ * Utility class used for testing {@link OneInputBroadcastWrapperOperator}.
+ *
+ * <p>For every record, it asserts that each broadcast variable visible through the user
+ * function's runtime context has the expected contents, then forwards the record unchanged.
+ */
+public class TestOneInputOp extends AbstractUdfStreamOperator<Integer, RichFunction>
+        implements OneInputStreamOperator<Integer, Integer> {
+
+    /** Names of the broadcast variables to check. */
+    private final String[] broadcastNames;
+
+    /** Expected contents of each broadcast variable, indexed like {@link #broadcastNames}. */
+    private final List<List<Integer>> expectedBroadcastInputs;
+
+    /**
+     * Creates the operator.
+     *
+     * @param func user function whose runtime context exposes the broadcast variables.
+     * @param broadcastNames names of the broadcast variables to check.
+     * @param expectedBroadcastInputs expected contents of each broadcast variable, in the same
+     *     order as {@code broadcastNames}.
+     */
+    public TestOneInputOp(
+            RichFunction func,
+            String[] broadcastNames,
+            List<List<Integer>> expectedBroadcastInputs) {
+        super(func);
+        this.broadcastNames = broadcastNames;
+        this.expectedBroadcastInputs = expectedBroadcastInputs;
+    }
+
+    @Override
+    public void processElement(StreamRecord<Integer> streamRecord) {
+        verifyBroadcastVariables();
+        output.collect(streamRecord);
+    }
+
+    /** Compares every broadcast variable against its expected contents, ignoring order. */
+    private void verifyBroadcastVariables() {
+        int index = 0;
+        for (String name : broadcastNames) {
+            List<Integer> expected = expectedBroadcastInputs.get(index++);
+            List<?> actual = userFunction.getRuntimeContext().getBroadcastVariable(name);
+            BroadcastVariableReceiverOperatorTest.compareLists(expected, actual);
+        }
+    }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestTwoInputOp.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestTwoInputOp.java
new file mode 100644
index 0000000..bf7b102
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestTwoInputOp.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.util.List;
+
+/**
+ * Utility class used for testing {@link TwoInputBroadcastWrapperOperator}.
+ *
+ * <p>For every record on either input, it asserts that each broadcast variable visible through
+ * the user function's runtime context has the expected contents, then forwards the record
+ * unchanged.
+ */
+public class TestTwoInputOp extends AbstractUdfStreamOperator<Integer, RichFunction>
+        implements TwoInputStreamOperator<Integer, Integer, Integer> {
+
+    /** Names of the broadcast variables to check. */
+    private final String[] broadcastNames;
+
+    /** Expected contents of each broadcast variable, indexed like {@link #broadcastNames}. */
+    private final List<List<Integer>> expectedBroadcastInputs;
+
+    /**
+     * Creates the operator.
+     *
+     * @param func user function whose runtime context exposes the broadcast variables.
+     * @param broadcastNames names of the broadcast variables to check.
+     * @param expectedBroadcastInputs expected contents (not sizes) of each broadcast variable, in
+     *     the same order as {@code broadcastNames}.
+     */
+    public TestTwoInputOp(
+            RichFunction func,
+            String[] broadcastNames,
+            List<List<Integer>> expectedBroadcastInputs) {
+        super(func);
+        this.broadcastNames = broadcastNames;
+        this.expectedBroadcastInputs = expectedBroadcastInputs;
+    }
+
+    @Override
+    public void processElement1(StreamRecord<Integer> streamRecord) {
+        checkBroadcastVariables();
+        output.collect(streamRecord);
+    }
+
+    @Override
+    public void processElement2(StreamRecord<Integer> streamRecord) {
+        checkBroadcastVariables();
+        output.collect(streamRecord);
+    }
+
+    /** Asserts that every broadcast variable matches its expected contents, ignoring order. */
+    private void checkBroadcastVariables() {
+        for (int i = 0; i < broadcastNames.length; i++) {
+            BroadcastVariableReceiverOperatorTest.compareLists(
+                    expectedBroadcastInputs.get(i),
+                    userFunction.getRuntimeContext().getBroadcastVariable(broadcastNames[i]));
+        }
+    }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperatorTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperatorTest.java
new file mode 100644
index 0000000..73bcbbc
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperatorTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
+import
org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask;
+import org.apache.flink.streaming.util.TestHarnessUtil;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+/** Tests the {@link TwoInputBroadcastWrapperOperator}. */
+public class TwoInputBroadcastWrapperOperatorTest {
+
+    @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final String[] BROADCAST_NAMES = new String[] {"source1", "source2"};
+
+    private static final TypeInformation<?>[] TYPE_INFORMATIONS =
+            new TypeInformation[] {BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO};
+
+    private static final List<Integer> SOURCE_1 = Collections.singletonList(1);
+
+    private static final List<Integer> SOURCE_2 = Arrays.asList(1, 2, 3);
+
+    /** Empty rich function; the tested operator reads broadcast variables via its context. */
+    private static class MyRichFunction extends AbstractRichFunction {}
+
+    /**
+     * Pre-populates both broadcast variables. Then checks that the wrapped two-input operator
+     * sees them while processing records from both inputs and forwards each record unchanged.
+     */
+    @Test
+    public void testProcessElements() throws Exception {
+        TwoInputStreamOperator<Integer, Integer, Integer> inputOp =
+                new TestTwoInputOp(
+                        new MyRichFunction(), BROADCAST_NAMES, Arrays.asList(SOURCE_1, SOURCE_2));
+        BroadcastWrapper<Integer> broadcastWrapper =
+                new BroadcastWrapper<>(BROADCAST_NAMES, TYPE_INFORMATIONS);
+        BroadcastWrapperOperatorFactory<Integer> wrapperFactory =
+                new BroadcastWrapperOperatorFactory<>(
+                        SimpleOperatorFactory.of(inputOp), broadcastWrapper);
+        OperatorID operatorId = new OperatorID();
+
+        try (StreamTaskMailboxTestHarness<Integer> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                TwoInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
+                        .buildUnrestored()) {
+            // Configure the data cache location before restore(), i.e. before operator setup.
+            harness.getStreamTask()
+                    .getEnvironment()
+                    .getTaskManagerInfo()
+                    .getConfiguration()
+                    .set(
+                            IterationOptions.DATA_CACHE_PATH,
+                            "file://" + tempFolder.newFolder().getAbsolutePath());
+            harness.getStreamTask().restore();
+            // The leading `true` presumably marks each cache as finished.
+            BroadcastContext.putBroadcastVariable(
+                    BROADCAST_NAMES[0] + "-" + 0, Tuple2.of(true, SOURCE_1));
+            BroadcastContext.putBroadcastVariable(
+                    BROADCAST_NAMES[1] + "-" + 0, Tuple2.of(true, SOURCE_2));
+
+            // Every record sent to either input is expected to be forwarded unchanged.
+            Queue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+            for (int i = 0; i < 5; ++i) {
+                harness.processElement(new StreamRecord<>(i, 1000), 0);
+                harness.processElement(new StreamRecord<>(i, 1000), 1);
+                expectedOutput.add(new StreamRecord<>(i, 1000));
+                expectedOutput.add(new StreamRecord<>(i, 1000));
+            }
+            TestHarnessUtil.assertOutputEquals(
+                    "Output was not correct", expectedOutput, harness.getOutput());
+        }
+    }
+}
diff --git a/flink-ml-lib/src/test/resources/log4j2-test.properties
b/flink-ml-lib/src/test/resources/log4j2-test.properties
new file mode 100644
index 0000000..835c2ec
--- /dev/null
+++ b/flink-ml-lib/src/test/resources/log4j2-test.properties
@@ -0,0 +1,28 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
# Set the root logger level to OFF so that test runs do not flood the build logs.
# Change it to INFO manually when debugging.
+rootLogger.level = OFF
+rootLogger.appenderRef.test.ref = TestLogger
+
+appender.testlogger.name = TestLogger
+appender.testlogger.type = CONSOLE
+appender.testlogger.target = SYSTEM_ERR
+appender.testlogger.layout.type = PatternLayout
+appender.testlogger.layout.pattern = %-4r [%t] %-5p %c %x - %m%n