This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 8269bd9  [FLINK-24279] support withBroadcast in DataStream
8269bd9 is described below

commit 8269bd9fdf3e5744b2d635697db5c705b9e598f5
Author: zhangzp <[email protected]>
AuthorDate: Fri Nov 5 15:33:04 2021 +0800

    [FLINK-24279] support withBroadcast in DataStream
    
    This closes #18.
---
 .../datacache/nonkeyed/DataCacheReader.java        |   6 -
 .../datacache/nonkeyed/DataCacheWriter.java        |  14 +
 flink-ml-lib/pom.xml                               |  39 ++
 .../ml/common/broadcast/BroadcastContext.java      | 113 ++++
 .../BroadcastStreamingRuntimeContext.java          |  96 ++++
 .../flink/ml/common/broadcast/BroadcastUtils.java  | 194 +++++++
 .../operator/AbstractBroadcastWrapperOperator.java | 613 +++++++++++++++++++++
 .../BroadcastVariableReceiverOperator.java         | 154 ++++++
 .../BroadcastVariableReceiverOperatorFactory.java  |  54 ++
 .../broadcast/operator/BroadcastWrapper.java       |  96 ++++
 .../operator/OneInputBroadcastWrapperOperator.java |  82 +++
 .../operator/TwoInputBroadcastWrapperOperator.java | 114 ++++
 .../ml/common/broadcast/typeinfo/CacheElement.java |  79 +++
 .../broadcast/typeinfo/CacheElementSerializer.java | 208 +++++++
 .../broadcast/typeinfo/CacheElementTypeInfo.java   | 103 ++++
 .../ml/common/broadcast/BroadcastUtilsTest.java    | 176 ++++++
 .../apache/flink/ml/common/broadcast/TestSink.java |  86 +++
 .../flink/ml/common/broadcast/TestSource.java      | 105 ++++
 .../BroadcastVariableReceiverOperatorTest.java     |  85 +++
 .../operator/BroadcastWrapperOperatorFactory.java  |  49 ++
 .../OneInputBroadcastWrapperOperatorTest.java      | 103 ++++
 .../common/broadcast/operator/TestOneInputOp.java  |  54 ++
 .../common/broadcast/operator/TestTwoInputOp.java  |  62 +++
 .../TwoInputBroadcastWrapperOperatorTest.java      | 105 ++++
 .../src/test/resources/log4j2-test.properties      |  28 +
 25 files changed, 2812 insertions(+), 6 deletions(-)

diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java
index 1840104..fc47e9f 100644
--- 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java
@@ -30,8 +30,6 @@ import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-import static org.apache.flink.util.Preconditions.checkArgument;
-
 /** Reads the cached data from a list of paths. */
 public class DataCacheReader<T> implements Iterator<T> {
 
@@ -47,10 +45,6 @@ public class DataCacheReader<T> implements Iterator<T> {
             TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment> 
segments)
             throws IOException {
 
-        for (Segment segment : segments) {
-            checkArgument(segment.getCount() > 0, "Do not support empty 
segment");
-        }
-
         this.serializer = serializer;
         this.fileSystem = fileSystem;
         this.segments = segments;
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
index 35256eb..ec8d73d 100644
--- 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
@@ -26,6 +26,8 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
@@ -49,11 +51,23 @@ public class DataCacheWriter<T> {
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
+        this(serializer, fileSystem, pathGenerator, null);
+    }
+
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable List<Segment> writtenSegments)
+            throws IOException {
         this.serializer = serializer;
         this.fileSystem = fileSystem;
         this.pathGenerator = pathGenerator;
 
         this.finishSegments = new ArrayList<>();
+        if (null != writtenSegments) {
+            finishSegments.addAll(writtenSegments);
+        }
 
         this.currentSegment = new SegmentWriter(pathGenerator.get());
     }
diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml
index 17d3e89..bee7f25 100644
--- a/flink-ml-lib/pom.xml
+++ b/flink-ml-lib/pom.xml
@@ -37,6 +37,12 @@ under the License.
       <scope>provided</scope>
     </dependency>
     <dependency>
+    <groupId>org.apache.flink</groupId>
+    <artifactId>flink-ml-iteration</artifactId>
+    <version>${project.version}</version>
+    <scope>provided</scope>
+  </dependency>
+    <dependency>
       <groupId>org.apache.flink</groupId>
       <artifactId>flink-table-api-java</artifactId>
       <version>${flink.version}</version>
@@ -65,6 +71,39 @@ under the License.
       <artifactId>core</artifactId>
       <version>1.1.2</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-clients_${scala.binary.version}</artifactId>
+      <version>${flink.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-streaming-java_${scala.binary.version}</artifactId>
+      <version>${flink.version}</version>
+      <scope>test</scope>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-runtime</artifactId>
+      <version>${flink.version}</version>
+      <scope>test</scope>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-test-utils-junit</artifactId>
+      <version>${flink.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <version>2.21.0</version>
+      <type>jar</type>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
 
   <build>
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
new file mode 100644
index 0000000..ac6950b
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.operators.MailboxExecutor;
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import javax.annotation.Nullable;
+
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Context that holds the broadcast variables and provides utility functions for accessing
+ * broadcast variables.
+ */
+public class BroadcastContext {
+
+    /**
+     * Stores broadcast data streams in a map. The key is broadcastName-partitionId and the value is
+     * {@link BroadcastItem}.
+     */
+    private static final ConcurrentHashMap<String, BroadcastItem> 
BROADCAST_VARIABLES =
+            new ConcurrentHashMap<>();
+
+    @VisibleForTesting
+    public static void putBroadcastVariable(String key, Tuple2<Boolean, 
List<?>> variable) {
+        BROADCAST_VARIABLES.compute(
+                key,
+                (k, v) ->
+                        null == v
+                                ? new BroadcastItem(variable.f0, variable.f1, 
null)
+                                : new BroadcastItem(variable.f0, variable.f1, 
v.mailboxExecutor));
+    }
+
+    @VisibleForTesting
+    public static void putMailBoxExecutor(String key, MailboxExecutor 
mailboxExecutor) {
+        BROADCAST_VARIABLES.compute(
+                key,
+                (k, v) ->
+                        null == v
+                                ? new BroadcastItem(false, null, 
mailboxExecutor)
+                                : new BroadcastItem(v.cacheReady, v.cacheList, 
mailboxExecutor));
+    }
+
+    @VisibleForTesting
+    @SuppressWarnings({"unchecked"})
+    public static <T> List<T> getBroadcastVariable(String key) {
+        return (List<T>) BROADCAST_VARIABLES.get(key).cacheList;
+    }
+
+    @VisibleForTesting
+    public static void remove(String key) {
+        BROADCAST_VARIABLES.remove(key);
+    }
+
+    @VisibleForTesting
+    public static void markCacheFinished(String key) {
+        BROADCAST_VARIABLES.computeIfPresent(
+                key,
+                (k, v) -> {
+                    // Sends a dummy mail so that the consumer does not get stuck.
+                    if (null != v.mailboxExecutor) {
+                        v.mailboxExecutor.execute(() -> {}, "empty mail");
+                    }
+                    return new BroadcastItem(true, v.cacheList, 
v.mailboxExecutor);
+                });
+    }
+
+    @VisibleForTesting
+    public static boolean isCacheFinished(String key) {
+        return BROADCAST_VARIABLES.get(key).cacheReady;
+    }
+
+    /** Utility class to organize broadcast variables. */
+    private static class BroadcastItem {
+
+        /** whether this broadcast variable is ready to be consumed. */
+        private boolean cacheReady;
+
+        /** the cached list. */
+        private List<?> cacheList;
+
+        /** The mailboxExecutor of the consumer, used to keep the consumer from getting stuck. */
+        private MailboxExecutor mailboxExecutor;
+
+        BroadcastItem(
+                boolean cacheReady,
+                @Nullable List<?> cacheList,
+                @Nullable MailboxExecutor mailboxExecutor) {
+            this.cacheReady = cacheReady;
+            this.cacheList = cacheList;
+            this.mailboxExecutor = mailboxExecutor;
+        }
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastStreamingRuntimeContext.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastStreamingRuntimeContext.java
new file mode 100644
index 0000000..809fdac
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastStreamingRuntimeContext.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A subclass of {@link StreamingRuntimeContext} that provides access to broadcast
+ * variables.
+ */
+public class BroadcastStreamingRuntimeContext extends StreamingRuntimeContext {
+
+    Map<String, List<?>> broadcastVariables = new HashMap<>();
+
+    public BroadcastStreamingRuntimeContext(
+            Environment env,
+            Map<String, Accumulator<?, ?>> accumulators,
+            OperatorMetricGroup operatorMetricGroup,
+            OperatorID operatorID,
+            ProcessingTimeService processingTimeService,
+            @Nullable KeyedStateStore keyedStateStore,
+            ExternalResourceInfoProvider externalResourceInfoProvider) {
+        super(
+                env,
+                accumulators,
+                operatorMetricGroup,
+                operatorID,
+                processingTimeService,
+                keyedStateStore,
+                externalResourceInfoProvider);
+    }
+
+    @Override
+    public boolean hasBroadcastVariable(String name) {
+        return broadcastVariables.containsKey(name);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public <RT> List<RT> getBroadcastVariable(String name) {
+        if (broadcastVariables.containsKey(name)) {
+            return (List<RT>) broadcastVariables.get(name);
+        } else {
+            throw new RuntimeException(
+                    "Cannot get broadcast variables before processing 
elements.");
+        }
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public <T, C> C getBroadcastVariableWithInitializer(
+            String name, BroadcastVariableInitializer<T, C> initializer) {
+        if (broadcastVariables.containsKey(name)) {
+            return initializer.initializeBroadcastVariable((List<T>) 
broadcastVariables.get(name));
+        } else {
+            throw new RuntimeException(
+                    "Cannot get broadcast variables before processing 
elements.");
+        }
+    }
+
+    @Internal
+    public void setBroadcastVariable(String name, List<?> broadcastVariable) {
+        broadcastVariables.put(name, broadcastVariable);
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
new file mode 100644
index 0000000..bb12647
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
+import 
org.apache.flink.ml.common.broadcast.operator.BroadcastVariableReceiverOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
+import 
org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.function.Function;
+
+/** Utility class to support withBroadcast in DataStream. */
+public class BroadcastUtils {
+    /**
+     * Supports withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of an operator that extends {@code
+     * org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator<OUT, 
? extends
+     * org.apache.flink.api.common.functions.RichFunction>}. Users can access 
the broadcast
+     * variables by {@code 
RichFunction.getRuntimeContext().getBroadcastVariable(...)} or {@code
+     * RichFunction.getRuntimeContext().hasBroadcastVariable(...)} or {@code
+     * 
RichFunction.getRuntimeContext().getBroadcastVariableWithInitializer(...)}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and further consumed by
+     * non-broadcast inputs. For now the non-broadcast inputs are cached by default to avoid
+     * possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is 
the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can 
access the broadcast
+     *     data streams and produce the output data stream. Note that users 
can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> 
userDefinedFunction) {
+        Preconditions.checkArgument(inputList.size() > 0);
+
+        StreamExecutionEnvironment env = 
inputList.get(0).getExecutionEnvironment();
+        String[] broadcastNames = new String[bcStreams.size()];
+        DataStream<?>[] broadcastInputs = new DataStream[bcStreams.size()];
+        TypeInformation<?>[] broadcastInTypes = new 
TypeInformation[bcStreams.size()];
+        int idx = 0;
+        final String broadcastId = new AbstractID().toHexString();
+        for (String name : bcStreams.keySet()) {
+            broadcastNames[idx] = broadcastId + "-" + name;
+            broadcastInputs[idx] = bcStreams.get(name);
+            broadcastInTypes[idx] = broadcastInputs[idx].getType();
+            idx++;
+        }
+
+        DataStream<OUT> resultStream =
+                getResultStream(env, inputList, broadcastNames, 
userDefinedFunction);
+        TypeInformation<OUT> outType = resultStream.getType();
+        final String coLocationKey = "broadcast-co-location-" + 
UUID.randomUUID();
+        DataStream<OUT> cachedBroadcastInputs =
+                cacheBroadcastVariables(
+                        env,
+                        broadcastNames,
+                        broadcastInputs,
+                        broadcastInTypes,
+                        resultStream.getParallelism(),
+                        outType);
+
+        boolean canCoLocate =
+                cachedBroadcastInputs.getTransformation() instanceof 
PhysicalTransformation
+                        && resultStream.getTransformation() instanceof 
PhysicalTransformation;
+        if (canCoLocate) {
+            ((PhysicalTransformation<?>) 
cachedBroadcastInputs.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+            ((PhysicalTransformation<?>) resultStream.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+        } else {
+            throw new UnsupportedOperationException(
+                    "cannot set chaining strategy on "
+                            + cachedBroadcastInputs.getTransformation()
+                            + " and "
+                            + resultStream.getTransformation()
+                            + ".");
+        }
+        
cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
+        resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
+
+        return cachedBroadcastInputs.union(resultStream);
+    }
+
+    /**
+     * Caches all broadcast input data streams in static variables and returns the result
+     * multi-input stream operator. The result multi-input stream operator emits nothing and the
+     * only functionality of this operator is to cache all the input records in {@link
+     * BroadcastContext}.
+     *
+     * @param env execution environment.
+     * @param broadcastInputNames names of the broadcast input data streams.
+     * @param broadcastInputs list of the broadcast data streams.
+     * @param broadcastInTypes output types of the broadcast input data 
streams.
+     * @param parallelism parallelism.
+     * @param outType output type.
+     * @param <OUT> output type.
+     * @return the result multi-input stream operator.
+     */
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            String[] broadcastInputNames,
+            DataStream<?>[] broadcastInputs,
+            TypeInformation<?>[] broadcastInTypes,
+            int parallelism,
+            TypeInformation<OUT> outType) {
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<>(
+                        "broadcastInputs",
+                        new BroadcastVariableReceiverOperatorFactory<>(
+                                broadcastInputNames, broadcastInTypes),
+                        outType,
+                        parallelism);
+        for (DataStream<?> dataStream : broadcastInputs) {
+            
transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    /**
+     * uses {@link DraftExecutionEnvironment} to execute the 
userDefinedFunction and returns the
+     * resultStream.
+     *
+     * @param env execution environment.
+     * @param inputList non-broadcast input list.
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param graphBuilder user-defined logic.
+     * @param <OUT> output type of the result stream.
+     * @return the result stream by applying user-defined logic on the input 
list.
+     */
+    private static <OUT> DataStream<OUT> getResultStream(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation<?>[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            inTypes[i] = inputList.get(i).getType();
+        }
+        // do not block all non-broadcast input edges by default.
+        boolean[] isBlocked = new boolean[inputList.size()];
+        Arrays.fill(isBlocked, false);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, 
inTypes, isBlocked));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (DataStream<?> dataStream : inputList) {
+            draftSources.add(draftEnv.addDraftSource(dataStream, 
dataStream.getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+        Preconditions.checkState(
+                draftEnv.getStreamGraph(false).getStreamNodes().size() == 1 + 
inputList.size(),
+                "cannot add more than one operator in withBroadcastStream's 
lambda function.");
+        draftEnv.copyToActualEnvironment();
+        return draftEnv.getActualStream(draftOutStream.getId());
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
new file mode 100644
index 0000000..200ad94
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
@@ -0,0 +1,613 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.api.common.operators.MailboxExecutor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.ml.common.broadcast.BroadcastStreamingRuntimeContext;
+import org.apache.flink.ml.common.broadcast.typeinfo.CacheElement;
+import org.apache.flink.ml.common.broadcast.typeinfo.CacheElementTypeInfo;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.metrics.groups.InternalOperatorIOMetricGroup;
+import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutpusStreamDecorator;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler;
+import 
org.apache.flink.streaming.api.operators.StreamOperatorStateHandler.CheckpointedStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailbox;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.ThrowingConsumer;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+
+/** Base class for the broadcast wrapper operators. */
+public abstract class AbstractBroadcastWrapperOperator<T, S extends 
StreamOperator<T>>
+        implements StreamOperator<T>, 
StreamOperatorStateHandler.CheckpointedStreamOperator {
+
+    private static final Logger LOG =
+            LoggerFactory.getLogger(AbstractBroadcastWrapperOperator.class);
+
+    protected final StreamOperatorParameters<T> parameters;
+
+    protected final StreamConfig streamConfig;
+
+    protected final StreamTask<?, ?> containingTask;
+
+    protected final Output<StreamRecord<T>> output;
+
+    protected final StreamOperatorFactory<T> operatorFactory;
+
+    protected final OperatorMetricGroup metrics;
+
+    protected final S wrappedOperator;
+
+    protected transient StreamOperatorStateHandler stateHandler;
+
+    protected transient InternalTimeServiceManager<?> timeServiceManager;
+
+    protected final MailboxExecutor mailboxExecutor;
+
+    /** variables specific for withBroadcast functionality. */
+    protected final String[] broadcastStreamNames;
+
+    /**
+     * whether each input is blocked. Inputs with broadcast variables can only 
process their input
+     * records after broadcast variables are ready. One input is non-blocked 
if it can consume its
+     * inputs (by caching) when broadcast variables are not ready. Otherwise 
it has to block the
+     * processing and wait until the broadcast variables are ready to be 
accessed.
+     */
+    protected final boolean[] isBlocked;
+
+    /** type information of each input. */
+    protected final TypeInformation<?>[] inTypes;
+
+    /** whether all broadcast variables of this operator are ready. */
+    protected boolean broadcastVariablesReady;
+
+    /** index of this subtask. */
+    protected final transient int indexOfSubtask;
+
+    /** number of the inputs of this operator. */
+    protected final int numInputs;
+
+    /** runtimeContext of the rich function in wrapped operator. */
+    BroadcastStreamingRuntimeContext wrappedOperatorRuntimeContext;
+
+    /**
+     * path of the file used to store the cached records. It could be on a local file system or a
+     * remote file system.
+     */
+    private Path basePath;
+
+    /** DataCacheWriter for each input. */
+    @SuppressWarnings("rawtypes")
+    protected DataCacheWriter[] dataCacheWriters;
+
+    /** whether each input has pending elements. */
+    protected boolean[] hasPendingElements;
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    AbstractBroadcastWrapperOperator(
+            StreamOperatorParameters<T> parameters,
+            StreamOperatorFactory<T> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation<?>[] inTypes,
+            boolean[] isBlocked) {
+        this.parameters = Objects.requireNonNull(parameters);
+        this.streamConfig = 
Objects.requireNonNull(parameters.getStreamConfig());
+        this.containingTask = 
Objects.requireNonNull(parameters.getContainingTask());
+        this.output = Objects.requireNonNull(parameters.getOutput());
+        this.operatorFactory = Objects.requireNonNull(operatorFactory);
+        this.metrics = 
createOperatorMetricGroup(containingTask.getEnvironment(), streamConfig);
+        this.wrappedOperator =
+                (S)
+                        StreamOperatorFactoryUtil.<T, S>createOperator(
+                                        operatorFactory,
+                                        (StreamTask) containingTask,
+                                        streamConfig,
+                                        output,
+                                        
parameters.getOperatorEventDispatcher())
+                                .f0;
+
+        boolean hasRichFunction =
+                wrappedOperator instanceof AbstractUdfStreamOperator
+                        && ((AbstractUdfStreamOperator) 
wrappedOperator).getUserFunction()
+                                instanceof RichFunction;
+
+        if (hasRichFunction) {
+            wrappedOperatorRuntimeContext =
+                    new BroadcastStreamingRuntimeContext(
+                            containingTask.getEnvironment(),
+                            
containingTask.getEnvironment().getAccumulatorRegistry().getUserMap(),
+                            wrappedOperator.getMetricGroup(),
+                            wrappedOperator.getOperatorID(),
+                            ((AbstractUdfStreamOperator) wrappedOperator)
+                                    .getProcessingTimeService(),
+                            null,
+                            
containingTask.getEnvironment().getExternalResourceInfoProvider());
+
+            ((RichFunction) ((AbstractUdfStreamOperator) 
wrappedOperator).getUserFunction())
+                    .setRuntimeContext(wrappedOperatorRuntimeContext);
+        } else {
+            throw new RuntimeException(
+                    "The operator is not a instance of "
+                            + AbstractUdfStreamOperator.class.getSimpleName()
+                            + " that contains a "
+                            + RichFunction.class.getSimpleName());
+        }
+
+        this.mailboxExecutor =
+                
containingTask.getMailboxExecutorFactory().createExecutor(TaskMailbox.MIN_PRIORITY);
+        // variables specific for withBroadcast functionality.
+        this.broadcastStreamNames = broadcastStreamNames;
+        this.isBlocked = isBlocked;
+        this.inTypes = inTypes;
+        this.broadcastVariablesReady = false;
+        this.indexOfSubtask = containingTask.getIndexInSubtaskGroup();
+        this.numInputs = inTypes.length;
+
+        // puts in mailboxExecutor
+        for (String name : broadcastStreamNames) {
+            BroadcastContext.putMailBoxExecutor(name + "-" + indexOfSubtask, 
mailboxExecutor);
+        }
+
+        basePath =
+                OperatorUtils.getDataCachePath(
+                        
containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                        containingTask
+                                .getEnvironment()
+                                .getIOManager()
+                                .getSpillingDirectoriesPaths());
+        dataCacheWriters = new DataCacheWriter[numInputs];
+        hasPendingElements = new boolean[numInputs];
+        Arrays.fill(hasPendingElements, true);
+    }
+
+    /**
+     * checks whether all of broadcast variables are ready. Besides it 
maintains a state
+     * {broadcastVariablesReady} to avoiding invoking {@code 
BroadcastContext.isCacheFinished(...)}
+     * repeatedly. Finally, it sets broadcast variables for 
{wrappedOperatorRuntimeContext} if the
+     * broadcast variables are ready.
+     *
+     * @return true if all broadcast variables are ready, false otherwise.
+     */
+    protected boolean areBroadcastVariablesReady() {
+        if (broadcastVariablesReady) {
+            return true;
+        }
+        for (String name : broadcastStreamNames) {
+            if (!BroadcastContext.isCacheFinished(name + "-" + 
indexOfSubtask)) {
+                return false;
+            } else {
+                String key = name + "-" + indexOfSubtask;
+                String userKey = name.substring(name.indexOf('-') + 1);
+                wrappedOperatorRuntimeContext.setBroadcastVariable(
+                        userKey, BroadcastContext.getBroadcastVariable(key));
+            }
+        }
+        broadcastVariablesReady = true;
+        return true;
+    }
+
+    private OperatorMetricGroup createOperatorMetricGroup(
+            Environment environment, StreamConfig streamConfig) {
+        try {
+            OperatorMetricGroup operatorMetricGroup =
+                    environment
+                            .getMetricGroup()
+                            .getOrAddOperator(
+                                    streamConfig.getOperatorID(), 
streamConfig.getOperatorName());
+            if (streamConfig.isChainEnd()) {
+                ((InternalOperatorIOMetricGroup) 
operatorMetricGroup.getIOMetricGroup())
+                        .reuseOutputMetricsForTask();
+            }
+            return operatorMetricGroup;
+        } catch (Exception e) {
+            LOG.warn("An error occurred while instantiating task metrics.", e);
+            return 
UnregisteredMetricGroups.createUnregisteredOperatorMetricGroup();
+        }
+    }
+
+    /**
+     * Shared record-handling routine for the subclasses' element processing.
+     *
+     * <p>A blocked input yields to the mailbox until the broadcast variables are ready and then
+     * forwards the record. A non-blocked input writes the record's value to its {@link
+     * DataCacheWriter} while the broadcast variables are not ready; once they are ready, pending
+     * cached elements are replayed before the record is forwarded.
+     *
+     * @param streamRecord the input record.
+     * @param inputIndex input id, starts from zero.
+     * @param elementConsumer the consumer function of StreamRecord, i.e.,
+     *     operator.processElement(...).
+     * @param watermarkConsumer the consumer function of WaterMark, i.e.,
+     *     operator.processWatermark(...).
+     * @throws Exception possible exception.
+     */
+    @SuppressWarnings({"rawtypes", "unchecked"})
+    protected void processElementX(
+            StreamRecord streamRecord,
+            int inputIndex,
+            ThrowingConsumer<StreamRecord, Exception> elementConsumer,
+            ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+            throws Exception {
+        if (isBlocked[inputIndex]) {
+            // Blocked input: wait for the broadcast variables, then process directly.
+            while (!areBroadcastVariablesReady()) {
+                mailboxExecutor.yield();
+            }
+            elementConsumer.accept(streamRecord);
+            return;
+        }
+
+        if (!areBroadcastVariablesReady()) {
+            // Non-blocked input without broadcast variables: cache the value for later replay.
+            CacheElement cachedRecord = CacheElement.newRecord(streamRecord.getValue());
+            dataCacheWriters[inputIndex].addRecord(cachedRecord);
+            return;
+        }
+
+        // Broadcast variables are ready: replay what was cached before the current record.
+        if (hasPendingElements[inputIndex]) {
+            processPendingElementsAndWatermarks(inputIndex, elementConsumer, watermarkConsumer);
+            hasPendingElements[inputIndex] = false;
+        }
+        elementConsumer.accept(streamRecord);
+    }
+
+    /**
+     * Shared watermark-handling routine for the subclasses' watermark processing.
+     *
+     * <p>A blocked input yields to the mailbox until the broadcast variables are ready and then
+     * forwards the watermark. A non-blocked input writes the watermark's timestamp to its {@link
+     * DataCacheWriter} while the broadcast variables are not ready; once they are ready, pending
+     * cached elements are replayed before the watermark is forwarded.
+     *
+     * @param watermark the input watermark.
+     * @param inputIndex input id, starts from zero.
+     * @param elementConsumer the consumer function of StreamRecord, i.e.,
+     *     operator.processElement(...).
+     * @param watermarkConsumer the consumer function of WaterMark, i.e.,
+     *     operator.processWatermark(...).
+     * @throws Exception possible exception.
+     */
+    @SuppressWarnings({"rawtypes", "unchecked"})
+    protected void processWatermarkX(
+            Watermark watermark,
+            int inputIndex,
+            ThrowingConsumer<StreamRecord, Exception> elementConsumer,
+            ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+            throws Exception {
+        if (isBlocked[inputIndex]) {
+            // Blocked input: wait for the broadcast variables, then process directly.
+            while (!areBroadcastVariablesReady()) {
+                mailboxExecutor.yield();
+            }
+            watermarkConsumer.accept(watermark);
+            return;
+        }
+
+        if (!areBroadcastVariablesReady()) {
+            // Non-blocked input without broadcast variables: cache the timestamp for replay.
+            CacheElement cachedWatermark = CacheElement.newWatermark(watermark.getTimestamp());
+            dataCacheWriters[inputIndex].addRecord(cachedWatermark);
+            return;
+        }
+
+        // Broadcast variables are ready: replay what was cached before the current watermark.
+        if (hasPendingElements[inputIndex]) {
+            processPendingElementsAndWatermarks(inputIndex, elementConsumer, watermarkConsumer);
+            hasPendingElements[inputIndex] = false;
+        }
+        watermarkConsumer.accept(watermark);
+    }
+
+    /**
+     * Shared end-of-input routine for the subclasses' endInput(...).
+     *
+     * <p>Yields to the mailbox until all broadcast variables are ready and then replays the
+     * elements still pending for the given input, if any.
+     *
+     * @param inputIndex input id, starts from zero.
+     * @param elementConsumer the consumer function of StreamRecord, i.e.,
+     *     operator.processElement(...).
+     * @param watermarkConsumer the consumer function of WaterMark, i.e.,
+     *     operator.processWatermark(...).
+     * @throws Exception possible exception.
+     */
+    @SuppressWarnings("rawtypes")
+    protected void endInputX(
+            int inputIndex,
+            ThrowingConsumer<StreamRecord, Exception> elementConsumer,
+            ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+            throws Exception {
+        boolean ready = areBroadcastVariablesReady();
+        while (!ready) {
+            mailboxExecutor.yield();
+            ready = areBroadcastVariablesReady();
+        }
+
+        if (!hasPendingElements[inputIndex]) {
+            return;
+        }
+        processPendingElementsAndWatermarks(inputIndex, elementConsumer, watermarkConsumer);
+        hasPendingElements[inputIndex] = false;
+    }
+
+    /**
+     * Replays the elements cached by the {@link DataCacheWriter} of the given input.
+     *
+     * <p>The writer's current segment is finished first; every cached record is then passed to
+     * {@code elementConsumer} and every cached watermark to {@code watermarkConsumer}, in the
+     * order returned by the {@link DataCacheReader}.
+     *
+     * @param inputIndex input id, starts from zero.
+     * @param elementConsumer the consumer function of StreamRecord, i.e.,
+     *     operator.processElement(...).
+     * @param watermarkConsumer the consumer function of WaterMark, i.e.,
+     *     operator.processWatermark(...).
+     * @throws Exception possible exception.
+     */
+    @SuppressWarnings({"rawtypes", "unchecked"})
+    private void processPendingElementsAndWatermarks(
+            int inputIndex,
+            ThrowingConsumer<StreamRecord, Exception> elementConsumer,
+            ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+            throws Exception {
+        dataCacheWriters[inputIndex].finishCurrentSegment();
+        List<Segment> cachedSegments = dataCacheWriters[inputIndex].getFinishSegments();
+        if (cachedSegments.size() == 0) {
+            // Nothing was cached for this input.
+            return;
+        }
+
+        DataCacheReader cacheReader =
+                new DataCacheReader<>(
+                        new CacheElementTypeInfo<>(inTypes[inputIndex])
+                                .createSerializer(containingTask.getExecutionConfig()),
+                        basePath.getFileSystem(),
+                        cachedSegments);
+        while (cacheReader.hasNext()) {
+            CacheElement cached = (CacheElement) cacheReader.next();
+            switch (cached.getType()) {
+                case RECORD:
+                    elementConsumer.accept(new StreamRecord(cached.getRecord()));
+                    break;
+                case WATERMARK:
+                    watermarkConsumer.accept(new Watermark(cached.getWatermark()));
+                    break;
+                default:
+                    throw new RuntimeException(
+                            "Unsupported CacheElement type: " + cached.getType());
+            }
+        }
+    }
+
+    /** Opens the wrapped operator. */
+    @Override
+    public void open() throws Exception {
+        wrappedOperator.open();
+    }
+
+    /**
+     * Closes the wrapped operator and removes this subtask's entries from {@link
+     * BroadcastContext}.
+     *
+     * <p>The removal runs in a {@code finally} block so that the entries keyed by {@code
+     * "<name>-<subtask index>"} are removed even when closing the wrapped operator fails.
+     *
+     * @throws Exception if closing the wrapped operator fails.
+     */
+    @Override
+    public void close() throws Exception {
+        try {
+            wrappedOperator.close();
+        } finally {
+            for (String name : broadcastStreamNames) {
+                BroadcastContext.remove(name + "-" + indexOfSubtask);
+            }
+        }
+    }
+
+    /** Delegates to the wrapped operator's {@code finish()}. */
+    @Override
+    public void finish() throws Exception {
+        wrappedOperator.finish();
+    }
+
+    /** Delegates to the wrapped operator's {@code prepareSnapshotPreBarrier(checkpointId)}. */
+    @Override
+    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+        wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
+    }
+
+    /**
+     * Initializes the state of this wrapper and then of the wrapped operator.
+     *
+     * <p>A {@link StreamOperatorStateContext} is obtained for this operator and used to build
+     * the {@code stateHandler}, which in turn initializes the wrapper's own operator state. The
+     * wrapped operator is initialized through a {@link ProxyStreamOperatorStateContext} created
+     * from the same context with the prefix {@code "wrapped-"}.
+     *
+     * @param streamTaskStateManager the initializer providing the operator state context.
+     * @throws Exception possible exception.
+     */
+    @Override
+    public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
+            throws Exception {
+        final TypeSerializer<?> keySerializer =
+                streamConfig.getStateKeySerializer(containingTask.getUserCodeClassLoader());
+        final double managedMemoryFraction =
+                streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
+                        ManagedMemoryUseCase.STATE_BACKEND,
+                        containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                        containingTask.getUserCodeClassLoader());
+
+        final StreamOperatorStateContext stateContext =
+                streamTaskStateManager.streamOperatorStateContext(
+                        getOperatorID(),
+                        getClass().getSimpleName(),
+                        parameters.getProcessingTimeService(),
+                        this,
+                        keySerializer,
+                        containingTask.getCancelables(),
+                        metrics,
+                        managedMemoryFraction,
+                        false);
+
+        stateHandler =
+                new StreamOperatorStateHandler(
+                        stateContext,
+                        containingTask.getExecutionConfig(),
+                        containingTask.getCancelables());
+        stateHandler.initializeOperatorState(this);
+        timeServiceManager = stateContext.internalTimerServiceManager();
+
+        broadcastVariablesReady = false;
+
+        // The wrapped operator always receives a proxy over the context created above; the
+        // arguments it passes to the initializer are not used.
+        wrappedOperator.initializeState(
+                (wrappedOperatorId,
+                        wrappedClassName,
+                        wrappedTimeService,
+                        wrappedKeyContext,
+                        wrappedKeySerializer,
+                        wrappedCloseableRegistry,
+                        wrappedMetricGroup,
+                        wrappedMemoryFraction,
+                        wrappedUsesCustomRawKeyedState) ->
+                        new ProxyStreamOperatorStateContext(stateContext, "wrapped-"));
+    }
+
+    /**
+     * Takes a snapshot by delegating to {@code stateHandler}, passing this operator, the
+     * (possibly absent) {@code timeServiceManager} and the operator name from {@code
+     * streamConfig}.
+     *
+     * @param checkpointId id of the checkpoint.
+     * @param timestamp timestamp of the checkpoint.
+     * @param checkpointOptions options of the checkpoint.
+     * @param storageLocation factory for the checkpoint streams.
+     * @return the snapshot futures returned by the state handler.
+     * @throws Exception possible exception.
+     */
+    @Override
+    public OperatorSnapshotFutures snapshotState(
+            long checkpointId,
+            long timestamp,
+            CheckpointOptions checkpointOptions,
+            CheckpointStreamFactory storageLocation)
+            throws Exception {
+        return stateHandler.snapshotState(
+                this,
+                Optional.ofNullable(timeServiceManager),
+                streamConfig.getOperatorName(),
+                checkpointId,
+                timestamp,
+                checkpointOptions,
+                storageLocation,
+                false);
+    }
+
+    /**
+     * Creates the {@link DataCacheWriter} of every input, restoring previously written segments
+     * when raw operator state is present.
+     *
+     * <p>Without raw operator state, each writer starts empty. Otherwise the single raw state
+     * stream is expected to begin with the number of inputs, followed by one {@link
+     * DataCacheSnapshot} per input whose segments are handed to that input's writer.
+     *
+     * @param stateInitializationContext context giving access to the raw operator state inputs.
+     * @throws Exception possible exception.
+     */
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public void initializeState(StateInitializationContext stateInitializationContext)
+            throws Exception {
+        List<StatePartitionStreamProvider> inputs =
+                IteratorUtils.toList(
+                        stateInitializationContext.getRawOperatorStateInputs().iterator());
+        Preconditions.checkState(
+                inputs.size() < 2, "The input from raw operator state should be one or zero.");
+        if (inputs.size() == 0) {
+            // Fresh start: no cached segments to restore.
+            for (int i = 0; i < numInputs; i++) {
+                dataCacheWriters[i] =
+                        new DataCacheWriter(
+                                new CacheElementTypeInfo<>(inTypes[i])
+                                        .createSerializer(containingTask.getExecutionConfig()),
+                                basePath.getFileSystem(),
+                                OperatorUtils.createDataCacheFileGenerator(
+                                        basePath, "cache", streamConfig.getOperatorID()));
+            }
+        } else {
+            InputStream inputStream = inputs.get(0).getStream();
+            // The decorator keeps the underlying state stream open for the snapshot reads below.
+            DataInputStream dis =
+                    new DataInputStream(new NonClosingInputStreamDecorator(inputStream));
+            Preconditions.checkState(dis.readInt() == numInputs, "Number of input is wrong.");
+            for (int i = 0; i < numInputs; i++) {
+                DataCacheSnapshot dataCacheSnapshot =
+                        DataCacheSnapshot.recover(
+                                inputStream,
+                                basePath.getFileSystem(),
+                                OperatorUtils.createDataCacheFileGenerator(
+                                        basePath, "cache", streamConfig.getOperatorID()));
+                dataCacheWriters[i] =
+                        new DataCacheWriter(
+                                new CacheElementTypeInfo<>(inTypes[i])
+                                        .createSerializer(containingTask.getExecutionConfig()),
+                                basePath.getFileSystem(),
+                                OperatorUtils.createDataCacheFileGenerator(
+                                        basePath, "cache", streamConfig.getOperatorID()),
+                                dataCacheSnapshot.getSegments());
+            }
+        }
+    }
+
+    /**
+     * Snapshots the wrapped operator (when it is a {@link CheckpointedStreamOperator}) and then
+     * writes the cache state of this wrapper to a new raw operator state partition: the number
+     * of inputs followed by one {@link DataCacheSnapshot} per input.
+     *
+     * @param stateSnapshotContext context giving access to the raw operator state output.
+     * @throws Exception possible exception.
+     */
+    @SuppressWarnings("unchecked")
+    @Override
+    public void snapshotState(StateSnapshotContext stateSnapshotContext) throws Exception {
+        if (wrappedOperator instanceof CheckpointedStreamOperator) {
+            ((CheckpointedStreamOperator) wrappedOperator).snapshotState(stateSnapshotContext);
+        }
+
+        final OperatorStateCheckpointOutputStream rawStateOutput =
+                stateSnapshotContext.getRawOperatorStateOutput();
+        rawStateOutput.startNewPartition();
+
+        // Header: number of inputs. The decorator keeps the checkpoint stream itself open.
+        try (DataOutputStream headerStream =
+                new DataOutputStream(new NonClosingOutpusStreamDecorator(rawStateOutput))) {
+            headerStream.writeInt(numInputs);
+        }
+
+        // Body: the finished segments of each input's cache, in input order.
+        for (int inputIndex = 0; inputIndex < numInputs; inputIndex++) {
+            dataCacheWriters[inputIndex].finishCurrentSegment();
+            DataCacheSnapshot cacheSnapshot =
+                    new DataCacheSnapshot(
+                            basePath.getFileSystem(),
+                            null,
+                            dataCacheWriters[inputIndex].getFinishSegments());
+            cacheSnapshot.writeTo(rawStateOutput);
+        }
+    }
+
+    /** Delegates to the wrapped operator's {@code setKeyContextElement1(record)}. */
+    @Override
+    public void setKeyContextElement1(StreamRecord<?> record) throws Exception 
{
+        wrappedOperator.setKeyContextElement1(record);
+    }
+
+    /** Delegates to the wrapped operator's {@code setKeyContextElement2(record)}. */
+    @Override
+    public void setKeyContextElement2(StreamRecord<?> record) throws Exception 
{
+        wrappedOperator.setKeyContextElement2(record);
+    }
+
+    /** Returns the metric group of the wrapped operator. */
+    @Override
+    public OperatorMetricGroup getMetricGroup() {
+        return wrappedOperator.getMetricGroup();
+    }
+
+    /** Returns the operator id of the wrapped operator. */
+    @Override
+    public OperatorID getOperatorID() {
+        return wrappedOperator.getOperatorID();
+    }
+
+    /** Delegates to the wrapped operator's {@code notifyCheckpointComplete(checkpointId)}. */
+    @Override
+    public void notifyCheckpointComplete(long checkpointId) throws Exception {
+        wrappedOperator.notifyCheckpointComplete(checkpointId);
+    }
+
+    /** Delegates to the wrapped operator's {@code notifyCheckpointAborted(checkpointId)}. */
+    @Override
+    public void notifyCheckpointAborted(long checkpointId) throws Exception {
+        wrappedOperator.notifyCheckpointAborted(checkpointId);
+    }
+
+    /** Sets the current key on the wrapped operator. */
+    @Override
+    public void setCurrentKey(Object key) {
+        wrappedOperator.setCurrentKey(key);
+    }
+
+    /** Returns the current key of the wrapped operator. */
+    @Override
+    public Object getCurrentKey() {
+        return wrappedOperator.getCurrentKey();
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
new file mode 100644
index 0000000..170fccc
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractInput;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Multi-input operator that caches the elements of every broadcast input and publishes the
+ * cached lists, together with their readiness flag, to {@link BroadcastContext}.
+ */
+public class BroadcastVariableReceiverOperator<OUT> extends AbstractStreamOperatorV2<OUT>
+        implements MultipleInputStreamOperator<OUT>, BoundedMultiInput, Serializable {
+
+    /** names of the broadcast data streams. */
+    private final String[] broadcastStreamNames;
+
+    /** output types of input data streams. */
+    private final TypeInformation<?>[] inTypes;
+
+    /** input list of the multi-input operator; entry {@code i} has input id {@code i + 1}. */
+    @SuppressWarnings("rawtypes")
+    private final List<Input> inputList;
+
+    /** in-memory caches of the broadcast inputs, one list per input. */
+    @SuppressWarnings("rawtypes")
+    private final List[] caches;
+
+    /** state storage of the broadcast inputs. */
+    private ListState<?>[] cacheStates;
+
+    /** cacheReady state storage of the broadcast inputs. */
+    private ListState<Boolean>[] cacheReadyStates;
+
+    /**
+     * Creates the operator with one proxy input and one empty cache per entry of {@code
+     * inTypes}.
+     *
+     * @param parameters parameters of the stream operator.
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param inTypes types of the broadcast data streams.
+     */
+    @SuppressWarnings({"rawtypes", "unchecked"})
+    BroadcastVariableReceiverOperator(
+            StreamOperatorParameters<OUT> parameters,
+            String[] broadcastStreamNames,
+            TypeInformation<?>[] inTypes) {
+        super(parameters, broadcastStreamNames.length);
+        this.broadcastStreamNames = broadcastStreamNames;
+        this.inTypes = inTypes;
+
+        final int numBroadcastInputs = inTypes.length;
+        this.inputList = new ArrayList<>(numBroadcastInputs);
+        this.caches = new List[numBroadcastInputs];
+        for (int index = 0; index < numBroadcastInputs; index++) {
+            // Input ids are one-based.
+            inputList.add(new ProxyInput(this, index + 1));
+            caches[index] = new ArrayList<>();
+        }
+        this.cacheStates = new ListState[numBroadcastInputs];
+        this.cacheReadyStates = new ListState[numBroadcastInputs];
+    }
+
+    @Override
+    public List<Input> getInputs() {
+        return inputList;
+    }
+
+    /**
+     * Marks the cache of the finished input as finished in {@link BroadcastContext}.
+     *
+     * @param i one-based id of the input that has ended.
+     */
+    @Override
+    public void endInput(int i) {
+        BroadcastContext.markCacheFinished(broadcastKey(i - 1));
+    }
+
+    /**
+     * Replaces the content of the cache states with the current caches and stores, per input,
+     * whether {@link BroadcastContext} reports its cache as finished.
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+        for (int index = 0; index < inTypes.length; index++) {
+            cacheStates[index].clear();
+            cacheStates[index].addAll(caches[index]);
+
+            cacheReadyStates[index].clear();
+            cacheReadyStates[index].add(BroadcastContext.isCacheFinished(broadcastKey(index)));
+        }
+    }
+
+    /**
+     * Restores the caches and their readiness flags from operator state (a missing flag counts
+     * as not ready) and registers both in {@link BroadcastContext}.
+     */
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public void initializeState(StateInitializationContext context) throws Exception {
+        super.initializeState(context);
+        for (int index = 0; index < inTypes.length; index++) {
+            cacheStates[index] =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor("cache_data_" + index, inTypes[index]));
+            caches[index] = IteratorUtils.toList(cacheStates[index].get().iterator());
+
+            final String readyStateName = "cache_ready_state_" + index;
+            cacheReadyStates[index] =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            readyStateName, BasicTypeInfo.BOOLEAN_TYPE_INFO));
+            boolean cacheReady =
+                    OperatorStateUtils.getUniqueElement(cacheReadyStates[index], readyStateName)
+                            .orElse(false);
+
+            BroadcastContext.putBroadcastVariable(
+                    broadcastKey(index), Tuple2.of(cacheReady, caches[index]));
+        }
+    }
+
+    /**
+     * Builds the {@link BroadcastContext} key of a broadcast input, i.e. {@code
+     * "<stream name>-<index of this subtask>"}.
+     *
+     * @param index zero-based index of the broadcast input.
+     * @return the key of that input for this subtask.
+     */
+    private String broadcastKey(int index) {
+        return broadcastStreamNames[index] + "-" + getRuntimeContext().getIndexOfThisSubtask();
+    }
+
+    /** Input that appends the value of every received record to the cache of its input id. */
+    private class ProxyInput<IN, OT> extends AbstractInput<IN, OT> {
+
+        public ProxyInput(AbstractStreamOperatorV2<OT> owner, int inputId) {
+            super(owner, inputId);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void processElement(StreamRecord<IN> element) {
+            final int cacheIndex = inputId - 1;
+            caches[cacheIndex].add(element.getValue());
+        }
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorFactory.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorFactory.java
new file mode 100644
index 0000000..75721c7
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorFactory.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+
+import java.io.Serializable;
+
+/**
+ * Factory class for {@link BroadcastVariableReceiverOperator}.
+ *
+ * @param <OUT> the output type of the created operator.
+ */
+public class BroadcastVariableReceiverOperatorFactory<OUT>
+        extends AbstractStreamOperatorFactory<OUT> implements Serializable {
+
+    /** Names of the broadcast data streams. */
+    private final String[] broadcastNames;
+
+    /** Types of the broadcast data streams. */
+    private final TypeInformation<?>[] inTypes;
+
+    /**
+     * Creates a factory that remembers the given arrays and hands them to every operator it
+     * creates. The arrays are stored as passed in, without copying.
+     *
+     * @param broadcastNames names of the broadcast data streams.
+     * @param inTypes types of the broadcast data streams.
+     */
+    public BroadcastVariableReceiverOperatorFactory(
+            String[] broadcastNames, TypeInformation<?>[] inTypes) {
+        this.inTypes = inTypes;
+        this.broadcastNames = broadcastNames;
+    }
+
+    /**
+     * Builds a new {@link BroadcastVariableReceiverOperator} from the given parameters and the
+     * names and types held by this factory.
+     */
+    @Override
+    public <T extends StreamOperator<OUT>> T createStreamOperator(
+            StreamOperatorParameters<OUT> parameters) {
+        final StreamOperator receiver =
+                new BroadcastVariableReceiverOperator(parameters, broadcastNames, inTypes);
+        return (T) receiver;
+    }
+
+    /** Always reports {@link BroadcastVariableReceiverOperator}; the class loader is not used. */
+    @Override
+    public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader classLoader) {
+        return BroadcastVariableReceiverOperator.class;
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
new file mode 100644
index 0000000..2e3f88d
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * The operator wrapper for {@link AbstractBroadcastWrapperOperator}.
+ *
+ * <p>Only the operator itself is wrapped; key selectors, partitioners, output tags and type
+ * information are returned unchanged.
+ *
+ * @param <T> the output type of the wrapped operator.
+ */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+
+    /** Names of the broadcast data streams. */
+    private final String[] broadcastStreamNames;
+
+    /** Types of input data streams. */
+    private final TypeInformation<?>[] inTypes;
+
+    /** Whether each input is blocked or not. */
+    private final boolean[] isBlocked;
+
+    /**
+     * Creates a wrapper in which no input is blocked: the blocked flags default to an all-false
+     * array with one entry per input type.
+     *
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param inTypes types of input data streams.
+     */
+    @VisibleForTesting
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation<?>[] inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    /**
+     * Creates a wrapper with an explicit blocked flag per input.
+     *
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param inTypes types of input data streams.
+     * @param isBlocked whether each input is blocked or not; must have the same length as {@code
+     *     inTypes}.
+     * @throws IllegalArgumentException if the two arrays differ in length.
+     */
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation<?>[] inTypes, boolean[] isBlocked) {
+        Preconditions.checkArgument(inTypes.length == isBlocked.length);
+        this.broadcastStreamNames = broadcastStreamNames;
+        this.inTypes = inTypes;
+        this.isBlocked = isBlocked;
+    }
+
+    /**
+     * Chooses the wrapper operator that matches the class of operator produced by {@code
+     * operatorFactory}. One-input operators are checked first, then two-input operators.
+     *
+     * @throws UnsupportedOperationException if the operator class is neither a one-input nor a
+     *     two-input stream operator.
+     */
+    @Override
+    public StreamOperator<T> wrap(
+            StreamOperatorParameters<T> operatorParameters,
+            StreamOperatorFactory<T> operatorFactory) {
+        final Class<? extends StreamOperator> wrappedClass =
+                operatorFactory.getStreamOperatorClass(getClass().getClassLoader());
+
+        if (OneInputStreamOperator.class.isAssignableFrom(wrappedClass)) {
+            return new OneInputBroadcastWrapperOperator<>(
+                    operatorParameters, operatorFactory, broadcastStreamNames, inTypes, isBlocked);
+        }
+
+        if (TwoInputStreamOperator.class.isAssignableFrom(wrappedClass)) {
+            return new TwoInputBroadcastWrapperOperator<>(
+                    operatorParameters, operatorFactory, broadcastStreamNames, inTypes, isBlocked);
+        }
+
+        throw new UnsupportedOperationException(
+                "Unsupported operator class for with-broadcast wrapper: " + wrappedClass);
+    }
+
+    /** Returns the given key selector unchanged. */
+    @Override
+    public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector) {
+        return keySelector;
+    }
+
+    /** Returns the given stream partitioner unchanged. */
+    @Override
+    public StreamPartitioner<T> wrapStreamPartitioner(StreamPartitioner<T> streamPartitioner) {
+        return streamPartitioner;
+    }
+
+    /** Returns the given output tag unchanged. */
+    @Override
+    public OutputTag<T> wrapOutputTag(OutputTag<T> outputTag) {
+        return outputTag;
+    }
+
+    /** Returns the given type information unchanged. */
+    @Override
+    public TypeInformation<T> getWrappedTypeInfo(TypeInformation<T> typeInfo) {
+        return typeInfo;
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
new file mode 100644
index 0000000..f1ffe00
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+/**
+ * Wrapper for {@link OneInputStreamOperator}.
+ *
+ * <p>Records, watermarks and end-of-input go through the {@code ...X} helpers inherited from
+ * {@link AbstractBroadcastWrapperOperator}; watermark statuses, latency markers and key-context
+ * updates are passed straight to the wrapped operator.
+ *
+ * @param <IN> the input type of the wrapped operator.
+ * @param <OUT> the output type of the wrapped operator.
+ */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
+
+    /** Index handed to the inherited helpers for the only input of this operator. */
+    private static final int SOLE_INPUT_INDEX = 0;
+
+    OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation<?>[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+    }
+
+    /** Routes the record through {@code processElementX} for input index 0. */
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        processElementX(
+                streamRecord,
+                SOLE_INPUT_INDEX,
+                wrappedOperator::processElement,
+                wrappedOperator::processWatermark);
+    }
+
+    /**
+     * Calls {@code endInputX} for input index 0 first, then asks {@link OperatorUtils} to invoke
+     * {@link BoundedOneInput#endInput()} for the wrapped operator.
+     */
+    @Override
+    public void endInput() throws Exception {
+        endInputX(
+                SOLE_INPUT_INDEX,
+                wrappedOperator::processElement,
+                wrappedOperator::processWatermark);
+        OperatorUtils.processOperatorOrUdfIfSatisfy(
+                wrappedOperator, BoundedOneInput.class, BoundedOneInput::endInput);
+    }
+
+    /** Routes the watermark through {@code processWatermarkX} for input index 0. */
+    @Override
+    public void processWatermark(Watermark watermark) throws Exception {
+        processWatermarkX(
+                watermark,
+                SOLE_INPUT_INDEX,
+                wrappedOperator::processElement,
+                wrappedOperator::processWatermark);
+    }
+
+    /** Passes the watermark status directly to the wrapped operator. */
+    @Override
+    public void processWatermarkStatus(WatermarkStatus watermarkStatus) throws Exception {
+        wrappedOperator.processWatermarkStatus(watermarkStatus);
+    }
+
+    /** Passes the latency marker directly to the wrapped operator. */
+    @Override
+    public void processLatencyMarker(LatencyMarker latencyMarker) throws Exception {
+        wrappedOperator.processLatencyMarker(latencyMarker);
+    }
+
+    /** Passes the key-context record directly to the wrapped operator. */
+    @Override
+    public void setKeyContextElement(StreamRecord<IN> streamRecord) throws Exception {
+        wrappedOperator.setKeyContextElement(streamRecord);
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperator.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperator.java
new file mode 100644
index 0000000..07871d4
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperator.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+/**
+ * Wrapper for {@link TwoInputStreamOperator}.
+ *
+ * <p>Records, watermarks and end-of-input go through the {@code ...X} helpers inherited from
+ * {@link AbstractBroadcastWrapperOperator}, using index 0 for the first input and index 1 for the
+ * second; latency markers and watermark statuses are passed straight to the wrapped operator.
+ *
+ * @param <IN1> the type of the first input of the wrapped operator.
+ * @param <IN2> the type of the second input of the wrapped operator.
+ * @param <OUT> the output type of the wrapped operator.
+ */
+public class TwoInputBroadcastWrapperOperator<IN1, IN2, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, TwoInputStreamOperator<IN1, IN2, OUT>>
+        implements TwoInputStreamOperator<IN1, IN2, OUT>, BoundedMultiInput {
+
+    /** Index handed to the inherited helpers for the first input. */
+    private static final int FIRST_INPUT_INDEX = 0;
+
+    /** Index handed to the inherited helpers for the second input. */
+    private static final int SECOND_INPUT_INDEX = 1;
+
+    TwoInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation<?>[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+    }
+
+    /** Routes a record of the first input through {@code processElementX} with index 0. */
+    @Override
+    public void processElement1(StreamRecord<IN1> streamRecord) throws Exception {
+        processElementX(
+                streamRecord,
+                FIRST_INPUT_INDEX,
+                wrappedOperator::processElement1,
+                wrappedOperator::processWatermark1);
+    }
+
+    /** Routes a record of the second input through {@code processElementX} with index 1. */
+    @Override
+    public void processElement2(StreamRecord<IN2> streamRecord) throws Exception {
+        processElementX(
+                streamRecord,
+                SECOND_INPUT_INDEX,
+                wrappedOperator::processElement2,
+                wrappedOperator::processWatermark2);
+    }
+
+    /**
+     * Calls {@code endInputX} with index {@code inputId - 1}, using the handlers of the first
+     * input when {@code inputId} is 1 and those of the second input for any other value. Then
+     * asks {@link OperatorUtils} to invoke {@link BoundedMultiInput#endInput(int)} with the
+     * original {@code inputId} for the wrapped operator.
+     */
+    @Override
+    public void endInput(int inputId) throws Exception {
+        final int inputIndex = inputId - 1;
+        if (inputId != 1) {
+            endInputX(
+                    inputIndex,
+                    wrappedOperator::processElement2,
+                    wrappedOperator::processWatermark2);
+        } else {
+            endInputX(
+                    inputIndex,
+                    wrappedOperator::processElement1,
+                    wrappedOperator::processWatermark1);
+        }
+        OperatorUtils.processOperatorOrUdfIfSatisfy(
+                wrappedOperator,
+                BoundedMultiInput.class,
+                boundedMultipleInput -> boundedMultipleInput.endInput(inputId));
+    }
+
+    /** Routes a watermark of the first input through {@code processWatermarkX} with index 0. */
+    @Override
+    public void processWatermark1(Watermark watermark) throws Exception {
+        processWatermarkX(
+                watermark,
+                FIRST_INPUT_INDEX,
+                wrappedOperator::processElement1,
+                wrappedOperator::processWatermark1);
+    }
+
+    /** Routes a watermark of the second input through {@code processWatermarkX} with index 1. */
+    @Override
+    public void processWatermark2(Watermark watermark) throws Exception {
+        processWatermarkX(
+                watermark,
+                SECOND_INPUT_INDEX,
+                wrappedOperator::processElement2,
+                wrappedOperator::processWatermark2);
+    }
+
+    /** Passes a latency marker of the first input directly to the wrapped operator. */
+    @Override
+    public void processLatencyMarker1(LatencyMarker latencyMarker) throws Exception {
+        wrappedOperator.processLatencyMarker1(latencyMarker);
+    }
+
+    /** Passes a latency marker of the second input directly to the wrapped operator. */
+    @Override
+    public void processLatencyMarker2(LatencyMarker latencyMarker) throws Exception {
+        wrappedOperator.processLatencyMarker2(latencyMarker);
+    }
+
+    /** Passes a watermark status of the first input directly to the wrapped operator. */
+    @Override
+    public void processWatermarkStatus1(WatermarkStatus watermarkStatus) throws Exception {
+        wrappedOperator.processWatermarkStatus1(watermarkStatus);
+    }
+
+    /** Passes a watermark status of the second input directly to the wrapped operator. */
+    @Override
+    public void processWatermarkStatus2(WatermarkStatus watermarkStatus) throws Exception {
+        wrappedOperator.processWatermarkStatus2(watermarkStatus);
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElement.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElement.java
new file mode 100644
index 0000000..63cf6c7
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElement.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.typeinfo;
+
+/**
+ * The wrapper class for possible cached elements used in {@link
+ * org.apache.flink.ml.common.broadcast.operator.AbstractBroadcastWrapperOperator}. It could be
+ * either {@link org.apache.flink.streaming.api.watermark.Watermark}, {@link
+ * org.apache.flink.streaming.runtime.streamrecord.StreamRecord}, etc.
+ *
+ * <p>The {@link Type} tag tells which of the two payload fields is meaningful. All fields are
+ * mutable through setters.
+ *
+ * @param <T> the record type.
+ */
+public class CacheElement<T> {
+    /** Payload of a {@link Type#RECORD} element; the watermark factory sets it to null. */
+    private T record;
+
+    /** Payload of a {@link Type#WATERMARK} element; the record factory sets it to -1. */
+    private long watermark;
+
+    /** Tag that says which payload this element carries. */
+    private Type type;
+
+    /**
+     * Creates an element with all three fields given explicitly.
+     *
+     * @param record the record payload.
+     * @param watermark the watermark payload.
+     * @param type the kind of element.
+     */
+    public CacheElement(T record, long watermark, Type type) {
+        this.type = type;
+        this.record = record;
+        this.watermark = watermark;
+    }
+
+    /** Creates a {@link Type#RECORD} element holding {@code record}, with watermark -1. */
+    public static <T> CacheElement<T> newRecord(T record) {
+        final long noWatermark = -1;
+        return new CacheElement<>(record, noWatermark, Type.RECORD);
+    }
+
+    /** Creates a {@link Type#WATERMARK} element holding {@code watermark}, with a null record. */
+    public static <T> CacheElement<T> newWatermark(long watermark) {
+        final T noRecord = null;
+        return new CacheElement<>(noRecord, watermark, Type.WATERMARK);
+    }
+
+    /** Returns the record payload. */
+    public T getRecord() {
+        return record;
+    }
+
+    /** Replaces the record payload; the type tag is left as it is. */
+    public void setRecord(T record) {
+        this.record = record;
+    }
+
+    /** Returns the watermark payload. */
+    public long getWatermark() {
+        return watermark;
+    }
+
+    /** Replaces the watermark payload; the type tag is left as it is. */
+    public void setWatermark(long watermark) {
+        this.watermark = watermark;
+    }
+
+    /** Returns the kind of this element. */
+    public Type getType() {
+        return type;
+    }
+
+    /** Replaces the kind of this element; the payload fields are left as they are. */
+    public void setType(Type type) {
+        this.type = type;
+    }
+
+    /** The type of cached elements. */
+    public enum Type {
+        /** record type. */
+        RECORD,
+        /** watermark type. */
+        WATERMARK,
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementSerializer.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementSerializer.java
new file mode 100644
index 0000000..ed600c1
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementSerializer.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.typeinfo;
+
+import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.common.broadcast.typeinfo.CacheElement.Type;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * TypeSerializer for {@link CacheElement}.
+ *
+ * @param <T> the record type.
+ */
+public class CacheElementSerializer<T> extends TypeSerializer<CacheElement<T>> 
{
+
+    private final TypeSerializer<T> recordSerializer;
+
+    public CacheElementSerializer(TypeSerializer<T> recordSerializer) {
+        this.recordSerializer = recordSerializer;
+    }
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public TypeSerializer<CacheElement<T>> duplicate() {
+        return new CacheElementSerializer<>(recordSerializer.duplicate());
+    }
+
+    @Override
+    public CacheElement<T> createInstance() {
+        return null;
+    }
+
+    @Override
+    public CacheElement<T> copy(CacheElement<T> from) {
+        switch (from.getType()) {
+            case RECORD:
+                return 
CacheElement.newRecord(recordSerializer.copy(from.getRecord()));
+            case WATERMARK:
+                return CacheElement.newWatermark(from.getWatermark());
+            default:
+                throw new RuntimeException(
+                        "Unsupported Record or Watermark type " + 
from.getType());
+        }
+    }
+
+    @Override
+    public CacheElement<T> copy(CacheElement<T> from, CacheElement<T> reuse) {
+        switch (from.getType()) {
+            case RECORD:
+                if (reuse.getRecord() != null) {
+                    recordSerializer.copy(from.getRecord(), reuse.getRecord());
+                } else {
+                    reuse.setRecord(recordSerializer.copy(from.getRecord()));
+                }
+                break;
+            case WATERMARK:
+                reuse.setWatermark(from.getWatermark());
+                break;
+            default:
+                throw new RuntimeException(
+                        "Unsupported Record or Watermark type " + 
from.getType());
+        }
+
+        return reuse;
+    }
+
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    @Override
+    public void serialize(CacheElement<T> record, DataOutputView target) 
throws IOException {
+        target.writeByte((byte) record.getType().ordinal());
+
+        switch (record.getType()) {
+            case RECORD:
+                recordSerializer.serialize(record.getRecord(), target);
+                break;
+            case WATERMARK:
+                LongSerializer.INSTANCE.serialize(record.getWatermark(), 
target);
+                break;
+            default:
+                throw new RuntimeException(
+                        "Unsupported Record or Watermark type " + 
record.getType());
+        }
+    }
+
+    @Override
+    public CacheElement<T> deserialize(DataInputView source) throws 
IOException {
+        int type = source.readByte();
+        switch (CacheElement.Type.values()[type]) {
+            case RECORD:
+                T value = recordSerializer.deserialize(source);
+                return CacheElement.newRecord(value);
+            case WATERMARK:
+                long watermark = LongSerializer.INSTANCE.deserialize(source);
+                return CacheElement.newWatermark(watermark);
+            default:
+                throw new RuntimeException("Unsupported Record or Watermark 
type " + type);
+        }
+    }
+
+    @Override
+    public CacheElement<T> deserialize(CacheElement<T> reuse, DataInputView 
source)
+            throws IOException {
+        int type = source.readByte();
+        switch (CacheElement.Type.values()[type]) {
+            case RECORD:
+                reuse.setType(Type.RECORD);
+                reuse.setRecord(recordSerializer.deserialize(source));
+                break;
+            case WATERMARK:
+                reuse.setType(Type.WATERMARK);
+                
reuse.setWatermark(LongSerializer.INSTANCE.deserialize(source));
+                break;
+            default:
+                throw new RuntimeException("Unsupported Record or Watermark 
type " + type);
+        }
+        return reuse;
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws 
IOException {
+        CacheElement<T> cacheElement = deserialize(source);
+        serialize(cacheElement, target);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+        CacheElementSerializer<?> that = (CacheElementSerializer<?>) obj;
+        return Objects.equals(recordSerializer, that.recordSerializer);
+    }
+
+    @Override
+    public int hashCode() {
+        return recordSerializer != null ? recordSerializer.hashCode() : 0;
+    }
+
+    @Override
+    public TypeSerializerSnapshot<CacheElement<T>> snapshotConfiguration() {
+        return new CacheElementSerializerSnapshot<>();
+    }
+
+    /** The serializer snapshot class for {@link CacheElementSerializer}. */
+    private static final class CacheElementSerializerSnapshot<T>
+            extends CompositeTypeSerializerSnapshot<CacheElement<T>, 
CacheElementSerializer<T>> {
+
+        private static final int CURRENT_VERSION = 1;
+
+        public CacheElementSerializerSnapshot() {
+            super(CacheElementSerializer.class);
+        }
+
+        @Override
+        protected int getCurrentOuterSnapshotVersion() {
+            return CURRENT_VERSION;
+        }
+
+        @Override
+        protected TypeSerializer<?>[] getNestedSerializers(
+                CacheElementSerializer<T> tIterationRecordSerializer) {
+            return new TypeSerializer[] 
{tIterationRecordSerializer.recordSerializer};
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        protected CacheElementSerializer<T> 
createOuterSerializerWithNestedSerializers(
+                TypeSerializer<?>[] typeSerializers) {
+            TypeSerializer<T> elementSerializer = (TypeSerializer<T>) 
typeSerializers[0];
+            return new CacheElementSerializer<>(elementSerializer);
+        }
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementTypeInfo.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementTypeInfo.java
new file mode 100644
index 0000000..55dc4de
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/typeinfo/CacheElementTypeInfo.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+
+import java.util.Objects;
+
+/**
+ * TypeInformation for {@link CacheElement}.
+ *
+ * <p>It is described as a single, non-basic, non-tuple, non-key field whose serializer is a
+ * {@link CacheElementSerializer} built around the serializer of the record type.
+ *
+ * @param <T> the record type.
+ */
+public class CacheElementTypeInfo<T> extends TypeInformation<CacheElement<T>> {
+
+    /** Type information of the record payload. */
+    private final TypeInformation<T> recordTypeInfo;
+
+    /**
+     * Creates the type information.
+     *
+     * @param recordTypeInfo type information of the record payload.
+     */
+    public CacheElementTypeInfo(TypeInformation<T> recordTypeInfo) {
+        this.recordTypeInfo = recordTypeInfo;
+    }
+
+    @Override
+    public boolean isBasicType() {
+        return false;
+    }
+
+    @Override
+    public boolean isTupleType() {
+        return false;
+    }
+
+    @Override
+    public int getArity() {
+        return 1;
+    }
+
+    @Override
+    public int getTotalFields() {
+        return 1;
+    }
+
+    /** Returns {@code CacheElement.class} through a raw cast. */
+    @Override
+    public Class<CacheElement<T>> getTypeClass() {
+        return (Class) CacheElement.class;
+    }
+
+    @Override
+    public boolean isKeyType() {
+        return false;
+    }
+
+    /** Wraps the record type's serializer, created with {@code config}, in the element serializer. */
+    @Override
+    public TypeSerializer<CacheElement<T>> createSerializer(ExecutionConfig config) {
+        final TypeSerializer<T> recordSerializer = recordTypeInfo.createSerializer(config);
+        return new CacheElementSerializer<>(recordSerializer);
+    }
+
+    @Override
+    public String toString() {
+        return "RecordOrWatermark Type";
+    }
+
+    /** Two instances are equal when they are of the same class and have equal record types. */
+    @Override
+    public boolean equals(Object obj) {
+        if (obj == this) {
+            return true;
+        }
+
+        if (obj == null || obj.getClass() != getClass()) {
+            return false;
+        }
+
+        final CacheElementTypeInfo<?> other = (CacheElementTypeInfo<?>) obj;
+        return Objects.equals(recordTypeInfo, other.recordTypeInfo);
+    }
+
+    /** Hash code of the record type information, or 0 when it is null. */
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(recordTypeInfo);
+    }
+
+    @Override
+    public boolean canEqual(Object obj) {
+        return obj instanceof CacheElementTypeInfo;
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/BroadcastUtilsTest.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/BroadcastUtilsTest.java
new file mode 100644
index 0000000..08f9498
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/BroadcastUtilsTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.ml.common.broadcast.operator.TestOneInputOp;
+import org.apache.flink.ml.common.broadcast.operator.TestTwoInputOp;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.streaming.api.CheckpointingMode;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/** Tests the {@link BroadcastUtils}. */
+public class BroadcastUtilsTest {
+
+    @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final int NUM_RECORDS_PER_PARTITION = 10;
+
+    private static final int NUM_TM = 2;
+
+    private static final int NUM_SLOT = 2;
+
+    /** Number of values in {@link #BROADCAST_INPUT}. */
+    private static final int NUM_TOTAL_RECORDS = NUM_TM * NUM_SLOT * NUM_RECORDS_PER_PARTITION;
+
+    private static final String[] BROADCAST_NAMES = new String[] {"source1", "source2"};
+
+    /** The integers [0, NUM_TOTAL_RECORDS), expected as the content of each broadcast variable. */
+    private static final List<Integer> BROADCAST_INPUT =
+            IntStream.range(0, NUM_TOTAL_RECORDS).boxed().collect(Collectors.toList());
+
+    /**
+     * Builds the configuration of the mini cluster the jobs run on: {@link #NUM_TM} task managers
+     * with {@link #NUM_SLOT} slots each, a fresh temporary folder as the data cache path and
+     * checkpoints enabled after tasks finish.
+     */
+    private MiniClusterConfiguration createMiniClusterConfiguration() throws IOException {
+        Configuration configuration = new Configuration();
+        // NOTE(review): a fixed REST port may collide with other processes on the build machine.
+        configuration.set(RestOptions.PORT, 18082);
+        configuration.set(
+                IterationOptions.DATA_CACHE_PATH,
+                "file://" + tempFolder.newFolder().getAbsolutePath());
+        configuration.set(
+                ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        return new MiniClusterConfiguration.Builder()
+                .setConfiguration(configuration)
+                .setNumTaskManagers(NUM_TM)
+                .setNumSlotsPerTaskManager(NUM_SLOT)
+                .build();
+    }
+
+    @Test
+    public void testOneInputGraph() throws Exception {
+        executeJob(1);
+    }
+
+    @Test
+    public void testTwoInputGraph() throws Exception {
+        executeJob(2);
+    }
+
+    /** Starts a mini cluster and runs the job with the given number of inputs to completion. */
+    private void executeJob(int numNonBroadcastInputs) throws Exception {
+        try (MiniCluster miniCluster = new MiniCluster(createMiniClusterConfiguration())) {
+            miniCluster.start();
+            miniCluster.executeJobBlocking(getJobGraph(numNonBroadcastInputs));
+        }
+    }
+
+    /**
+     * Builds a job in which two {@link TestSource}s are broadcast to an operator with {@code
+     * numNonBroadcastInputs} regular inputs, and whose output is verified by a {@link TestSink}.
+     *
+     * @param numNonBroadcastInputs number of non-broadcast inputs; 1 or 2 are supported
+     */
+    private JobGraph getJobGraph(int numNonBroadcastInputs) {
+        // A plain Configuration avoids the anonymous subclass of double-brace initialization.
+        Configuration envConfig = new Configuration();
+        envConfig.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env =
+                StreamExecutionEnvironment.getExecutionEnvironment(envConfig);
+        env.enableCheckpointing(500, CheckpointingMode.EXACTLY_ONCE);
+        env.setParallelism(NUM_SLOT * NUM_TM);
+
+        DataStream<Integer> source1 = env.addSource(new TestSource(NUM_RECORDS_PER_PARTITION));
+        DataStream<Integer> source2 = env.addSource(new TestSource(NUM_RECORDS_PER_PARTITION));
+        HashMap<String, DataStream<?>> bcStreamsMap = new HashMap<>();
+        bcStreamsMap.put(BROADCAST_NAMES[0], source1);
+        bcStreamsMap.put(BROADCAST_NAMES[1], source2);
+
+        List<DataStream<?>> inputList = new ArrayList<>(numNonBroadcastInputs);
+        // source1 is deliberately used both as a broadcast input and as a regular input; the
+        // original author's note describes this as creating a (potential) deadlock.
+        inputList.add(source1);
+        for (int i = 0; i < numNonBroadcastInputs - 1; i++) {
+            inputList.add(env.addSource(new TestSource(NUM_RECORDS_PER_PARTITION)));
+        }
+
+        Function<List<DataStream<?>>, DataStream<Integer>> func = getFunc(numNonBroadcastInputs);
+
+        DataStream<Integer> result =
+                BroadcastUtils.withBroadcastStream(inputList, bcStreamsMap, func);
+
+        // The test operators forward every record of every regular input, so each value is
+        // expected once per non-broadcast input.
+        List<Integer> expectedNumSequence =
+                new ArrayList<>(NUM_TOTAL_RECORDS * numNonBroadcastInputs);
+        for (int i = 0; i < NUM_TOTAL_RECORDS; i++) {
+            for (int j = 0; j < numNonBroadcastInputs; j++) {
+                expectedNumSequence.add(i);
+            }
+        }
+        result.addSink(new TestSink(expectedNumSequence)).setParallelism(1);
+
+        return env.getStreamGraph().getJobGraph();
+    }
+
+    /**
+     * Returns the function that attaches the test operator to the non-broadcast inputs.
+     *
+     * @param numInputs number of non-broadcast inputs, either 1 or 2
+     * @throws IllegalArgumentException if {@code numInputs} is neither 1 nor 2
+     */
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private static Function<List<DataStream<?>>, DataStream<Integer>> getFunc(int numInputs) {
+        if (numInputs == 1) {
+            return dataStreams -> {
+                DataStream input = dataStreams.get(0);
+                return input.transform(
+                                "one-input",
+                                BasicTypeInfo.INT_TYPE_INFO,
+                                new TestOneInputOp(
+                                        new AbstractRichFunction() {},
+                                        BROADCAST_NAMES,
+                                        Arrays.asList(BROADCAST_INPUT, BROADCAST_INPUT)))
+                        .name("broadcast");
+            };
+        }
+        if (numInputs == 2) {
+            return dataStreams -> {
+                DataStream input1 = dataStreams.get(0);
+                DataStream input2 = dataStreams.get(1);
+                return input1.connect(input2)
+                        .transform(
+                                "two-input",
+                                BasicTypeInfo.INT_TYPE_INFO,
+                                new TestTwoInputOp(
+                                        new AbstractRichFunction() {},
+                                        BROADCAST_NAMES,
+                                        Arrays.asList(BROADCAST_INPUT, BROADCAST_INPUT)))
+                        .name("broadcast");
+            };
+        }
+        // Fail fast instead of handing a null function to withBroadcastStream.
+        throw new IllegalArgumentException(
+                "Unsupported number of non-broadcast inputs: " + numInputs);
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSink.java 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSink.java
new file mode 100644
index 0000000..e1043cc
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSink.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import 
org.apache.flink.ml.common.broadcast.operator.BroadcastVariableReceiverOperatorTest;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.List;
+
+/**
+ * Sink used by the tests to verify the records it receives. When the sink finishes, the received
+ * records are compared with the expected ones, ignoring order, and a mismatch fails the check.
+ */
+public class TestSink extends RichSinkFunction<Integer> implements CheckpointedFunction {
+
+    /** Records the sink must have received by the time {@link #finish()} is called. */
+    private final List<Integer> expectReceivedRecords;
+
+    /** Records received so far; rebuilt from {@link #receivedRecordsState} on initialization. */
+    private List<Integer> receivedRecords;
+
+    /** Operator list state that holds the received records across checkpoints. */
+    private ListState<Integer> receivedRecordsState;
+
+    public TestSink(List<Integer> expectReceivedRecords) {
+        this.expectReceivedRecords = expectReceivedRecords;
+    }
+
+    @Override
+    public void initializeState(FunctionInitializationContext functionInitializationContext)
+            throws Exception {
+        ListStateDescriptor<Integer> descriptor =
+                new ListStateDescriptor<>("receivedRecords", BasicTypeInfo.INT_TYPE_INFO);
+        receivedRecordsState =
+                functionInitializationContext.getOperatorStateStore().getListState(descriptor);
+        // Start from whatever the state currently holds.
+        receivedRecords = IteratorUtils.toList(receivedRecordsState.get().iterator());
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+    }
+
+    @Override
+    public void invoke(Integer value, Context context) {
+        receivedRecords.add(value);
+    }
+
+    @Override
+    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+        // Replace the stored records with the current in-memory list.
+        receivedRecordsState.clear();
+        receivedRecordsState.addAll(receivedRecords);
+    }
+
+    @Override
+    public void finish() {
+        BroadcastVariableReceiverOperatorTest.compareLists(expectReceivedRecords, receivedRecords);
+    }
+
+    @Override
+    public void close() {}
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSource.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSource.java
new file mode 100644
index 0000000..46e1880
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/TestSource.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+
+import java.util.Iterator;
+
+/**
+ * Parallel test source that emits a fixed integer sequence and throws an exception once to
+ * exercise failover. With {@code p} parallel subtasks, subtask {@code i} emits {@code k * p + i}
+ * for {@code k} in [0, numElementsPerPartition), so that all subtasks together cover the range
+ * [0, p * numElementsPerPartition).
+ *
+ * <p>For example, when the parallelism is 2 and {@code numElementsPerPartition} is 5, this class
+ * generates {0,1,2,3,4,5,6,7,8,9}.
+ */
+public class TestSource extends RichParallelSourceFunction<Integer>
+        implements CheckpointedFunction {
+
+    /** Static, hence shared by all instances in the JVM: the failure is injected only once. */
+    private static volatile boolean hasThrown = false;
+
+    /** Number of elements each subtask emits. */
+    private final int numElementsPerPartition;
+
+    /** Operator list state that holds {@link #currentIdx} across checkpoints. */
+    private ListState<Integer> currentIdxState;
+
+    /** Number of elements this subtask has emitted so far. */
+    private int currentIdx;
+
+    /** Index of this subtask, set in {@link #open}. */
+    private int mod;
+
+    /** Number of parallel subtasks, set in {@link #open}. */
+    private int numPartitions;
+
+    private transient volatile boolean running = true;
+
+    public TestSource(int numElementsPerPartition) {
+        this.numElementsPerPartition = numElementsPerPartition;
+    }
+
+    @Override
+    public void open(Configuration parameters) {
+        this.mod = getRuntimeContext().getIndexOfThisSubtask();
+        this.numPartitions = getRuntimeContext().getNumberOfParallelSubtasks();
+        // The field initializer of a transient field does not run on deserialization.
+        running = true;
+    }
+
+    @Override
+    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+        this.currentIdxState.clear();
+        this.currentIdxState.add(currentIdx);
+    }
+
+    @Override
+    public void initializeState(FunctionInitializationContext functionInitializationContext)
+            throws Exception {
+        currentIdxState =
+                functionInitializationContext
+                        .getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "currentIdx", BasicTypeInfo.INT_TYPE_INFO));
+        // Resume from the stored index if there is one, otherwise start from the beginning.
+        Iterator<Integer> iterator = currentIdxState.get().iterator();
+        currentIdx = iterator.hasNext() ? iterator.next() : 0;
+    }
+
+    @Override
+    public void run(SourceContext<Integer> sourceContext) throws Exception {
+        while (running && currentIdx < numElementsPerPartition) {
+            // Emit and advance under the checkpoint lock so both happen as one step.
+            synchronized (sourceContext.getCheckpointLock()) {
+                sourceContext.collect(currentIdx * numPartitions + mod);
+                currentIdx++;
+            }
+            Thread.sleep(1);
+            if (currentIdx == numElementsPerPartition / 2 && claimFailure()) {
+                throw new RuntimeException("Failing source");
+            }
+        }
+    }
+
+    @Override
+    public void cancel() {
+        running = false;
+    }
+
+    /**
+     * Atomically claims the one-time injected failure.
+     *
+     * @return {@code true} for exactly one caller, {@code false} for every other call
+     */
+    private static synchronized boolean claimFailure() {
+        if (hasThrown) {
+            return false;
+        }
+        hasThrown = true;
+        return true;
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
new file mode 100644
index 0000000..2457471
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.MultipleInputStreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
+import 
org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** Tests the {@link BroadcastVariableReceiverOperator}. */
+public class BroadcastVariableReceiverOperatorTest {
+
+    private static final String[] BROADCAST_NAMES = new String[] {"source1", "source2"};
+
+    private static final TypeInformation<?>[] TYPE_INFORMATIONS =
+            new TypeInformation[] {BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO};
+
+    /**
+     * Feeds records into both inputs of the receiver operator and checks that the broadcast
+     * variables are not reported as finished while the task runs, and hold exactly the records
+     * of the matching input once it completes.
+     */
+    @Test
+    public void testCacheStreamOperator() throws Exception {
+        // Keys under which this test looks up the two broadcast variables.
+        final String firstKey = BROADCAST_NAMES[0] + "-" + 0;
+        final String secondKey = BROADCAST_NAMES[1] + "-" + 0;
+        OperatorID operatorId = new OperatorID();
+
+        try (StreamTaskMailboxTestHarness<Integer> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                MultipleInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .setupOutputForSingletonOperatorChain(
+                                new BroadcastVariableReceiverOperatorFactory<>(
+                                        BROADCAST_NAMES, TYPE_INFORMATIONS),
+                                operatorId)
+                        .build()) {
+            // Two records on input 0, three records on input 1.
+            harness.processElement(new StreamRecord<>(1, 2), 0);
+            harness.processElement(new StreamRecord<>(2, 3), 0);
+            harness.processElement(new StreamRecord<>(3, 2), 1);
+            harness.processElement(new StreamRecord<>(4, 2), 1);
+            harness.processElement(new StreamRecord<>(5, 3), 1);
+
+            // check broadcast inputs before task finishes.
+            boolean firstFinished = BroadcastContext.isCacheFinished(firstKey);
+            boolean secondFinished = BroadcastContext.isCacheFinished(secondKey);
+            Assert.assertFalse(firstFinished || secondFinished);
+
+            harness.waitForTaskCompletion();
+
+            // check broadcast inputs after task finishes.
+            List<?> firstCache = BroadcastContext.getBroadcastVariable(firstKey);
+            List<?> secondCache = BroadcastContext.getBroadcastVariable(secondKey);
+            compareLists(Arrays.asList(1, 2), firstCache);
+            compareLists(Arrays.asList(3, 4, 5), secondCache);
+        }
+    }
+
+    /**
+     * Asserts that both lists contain the same integers, ignoring order.
+     *
+     * @param expected the expected integers
+     * @param actual the list to check; every element is cast to {@link Integer}
+     */
+    public static void compareLists(List<Integer> expected, List<?> actual) {
+        int[] sortedActual = actual.stream().mapToInt(x -> (Integer) x).sorted().toArray();
+        int[] sortedExpected = expected.stream().mapToInt(Integer::intValue).sorted().toArray();
+        Assert.assertArrayEquals(sortedExpected, sortedActual);
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapperOperatorFactory.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapperOperatorFactory.java
new file mode 100644
index 0000000..40a56b0
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapperOperatorFactory.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+
+/** Factory class for {@link AbstractBroadcastWrapperOperator}. */
+class BroadcastWrapperOperatorFactory<OUT> extends AbstractStreamOperatorFactory<OUT> {
+
+    /** Factory of the operator that gets wrapped. */
+    private final StreamOperatorFactory<OUT> operatorFactory;
+
+    /** Wrapper that turns the operator factory into the created operator. */
+    private final BroadcastWrapper<OUT> wrapper;
+
+    public BroadcastWrapperOperatorFactory(
+            StreamOperatorFactory<OUT> operatorFactory, BroadcastWrapper<OUT> wrapper) {
+        this.operatorFactory = operatorFactory;
+        this.wrapper = wrapper;
+    }
+
+    /** Delegates the creation to {@link BroadcastWrapper#wrap}. */
+    @Override
+    @SuppressWarnings("unchecked")
+    public <T extends StreamOperator<OUT>> T createStreamOperator(
+            StreamOperatorParameters<OUT> parameters) {
+        // Unchecked: the caller chooses T, the wrapper decides the concrete operator type.
+        return (T) wrapper.wrap(parameters, operatorFactory);
+    }
+
+    @Override
+    @SuppressWarnings("rawtypes")
+    public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader classLoader) {
+        return AbstractBroadcastWrapperOperator.class;
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperatorTest.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperatorTest.java
new file mode 100644
index 0000000..7d64746
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperatorTest.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
+import 
org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.streaming.util.TestHarnessUtil;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+/** Tests the {@link OneInputBroadcastWrapperOperator}. */
+public class OneInputBroadcastWrapperOperatorTest {
+
+    @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final String[] BROADCAST_NAMES = new String[] {"source1", "source2"};
+
+    private static final TypeInformation<?>[] TYPE_INFORMATIONS =
+            new TypeInformation[] {BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO};
+
+    private static final List<Integer> SOURCE_1 = Collections.singletonList(1);
+
+    private static final List<Integer> SOURCE_2 = Arrays.asList(1, 2, 3);
+
+    private static class MyRichFunction extends AbstractRichFunction {}
+
+    /**
+     * Wraps a {@link TestOneInputOp}, registers both broadcast variables and checks that every
+     * processed record shows up unchanged in the output.
+     */
+    @Test
+    public void testProcessElements() throws Exception {
+        OneInputStreamOperator<Integer, Integer> inputOp =
+                new TestOneInputOp(
+                        new MyRichFunction(), BROADCAST_NAMES, Arrays.asList(SOURCE_1, SOURCE_2));
+        BroadcastWrapper<Integer> broadcastWrapper =
+                new BroadcastWrapper<>(BROADCAST_NAMES, TYPE_INFORMATIONS);
+        BroadcastWrapperOperatorFactory<Integer> wrapperFactory =
+                new BroadcastWrapperOperatorFactory<>(
+                        SimpleOperatorFactory.of(inputOp), broadcastWrapper);
+        OperatorID operatorId = new OperatorID();
+
+        try (StreamTaskMailboxTestHarness<Integer> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
+                        .buildUnrestored()) {
+            // The harness is built unrestored so that the data cache path can be set on the
+            // task manager configuration before the task is restored.
+            harness.getStreamTask()
+                    .getEnvironment()
+                    .getTaskManagerInfo()
+                    .getConfiguration()
+                    .set(
+                            IterationOptions.DATA_CACHE_PATH,
+                            "file://" + tempFolder.newFolder().getAbsolutePath());
+            harness.getStreamTask().restore();
+
+            // Register both broadcast variables; the Tuple2 flag is set to true for each.
+            BroadcastContext.putBroadcastVariable(
+                    BROADCAST_NAMES[0] + "-" + 0, Tuple2.of(true, SOURCE_1));
+            BroadcastContext.putBroadcastVariable(
+                    BROADCAST_NAMES[1] + "-" + 0, Tuple2.of(true, SOURCE_2));
+
+            Queue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+            for (int value = 0; value < 5; value++) {
+                harness.processElement(new StreamRecord<>(value, 1000), 0);
+                expectedOutput.add(new StreamRecord<>(value, 1000));
+            }
+            TestHarnessUtil.assertOutputEquals(
+                    "Output was not correct.", expectedOutput, harness.getOutput());
+        }
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestOneInputOp.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestOneInputOp.java
new file mode 100644
index 0000000..69ccbd4
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestOneInputOp.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.util.List;
+
+/** Utility class used for testing {@link OneInputBroadcastWrapperOperator}. */
+public class TestOneInputOp extends AbstractUdfStreamOperator<Integer, 
RichFunction>
+        implements OneInputStreamOperator<Integer, Integer> {
+
+    private final String[] broadcastNames;
+
+    private final List<List<Integer>> expectedBroadcastInputs;
+
+    public TestOneInputOp(
+            RichFunction func,
+            String[] broadcastNames,
+            List<List<Integer>> expectedBroadcastInputs) {
+        super(func);
+        this.broadcastNames = broadcastNames;
+        this.expectedBroadcastInputs = expectedBroadcastInputs;
+    }
+
+    @Override
+    public void processElement(StreamRecord<Integer> streamRecord) {
+        for (int i = 0; i < broadcastNames.length; i++) {
+            BroadcastVariableReceiverOperatorTest.compareLists(
+                    expectedBroadcastInputs.get(i),
+                    
userFunction.getRuntimeContext().getBroadcastVariable(broadcastNames[i]));
+        }
+        output.collect(streamRecord);
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestTwoInputOp.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestTwoInputOp.java
new file mode 100644
index 0000000..bf7b102
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TestTwoInputOp.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.util.List;
+
+/**
+ * Test helper operator for {@link TwoInputBroadcastWrapperOperator}.
+ *
+ * <p>For every record arriving on either input it hands each broadcast variable obtained through
+ * the user function's runtime context, together with the list expected for that variable, to
+ * {@link BroadcastVariableReceiverOperatorTest#compareLists}, and then forwards the record
+ * unchanged.
+ */
+public class TestTwoInputOp extends AbstractUdfStreamOperator<Integer, RichFunction>
+        implements TwoInputStreamOperator<Integer, Integer, Integer> {
+
+    /** Names of the broadcast variables looked up on every record. */
+    private final String[] broadcastNames;
+
+    /** Expected lists; entry {@code i} is compared with variable {@code broadcastNames[i]}. */
+    private final List<List<Integer>> expectedBroadcastInputs;
+
+    /**
+     * @param func user function whose runtime context is used to read the broadcast variables
+     * @param broadcastNames names of the broadcast variables to check
+     * @param expectedBroadcastInputs expected lists, indexed like {@code broadcastNames}
+     */
+    public TestTwoInputOp(
+            RichFunction func,
+            String[] broadcastNames,
+            List<List<Integer>> expectedBroadcastInputs) {
+        super(func);
+        this.broadcastNames = broadcastNames;
+        this.expectedBroadcastInputs = expectedBroadcastInputs;
+    }
+
+    @Override
+    public void processElement1(StreamRecord<Integer> streamRecord) {
+        verifyBroadcastVariables();
+        output.collect(streamRecord);
+    }
+
+    @Override
+    public void processElement2(StreamRecord<Integer> streamRecord) {
+        verifyBroadcastVariables();
+        output.collect(streamRecord);
+    }
+
+    /** Runs the expected-versus-actual comparison for every configured broadcast variable. */
+    private void verifyBroadcastVariables() {
+        for (int i = 0; i < broadcastNames.length; i++) {
+            BroadcastVariableReceiverOperatorTest.compareLists(
+                    expectedBroadcastInputs.get(i),
+                    userFunction.getRuntimeContext().getBroadcastVariable(broadcastNames[i]));
+        }
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperatorTest.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperatorTest.java
new file mode 100644
index 0000000..73bcbbc
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperatorTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
+import 
org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask;
+import org.apache.flink.streaming.util.TestHarnessUtil;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+/** Tests the {@link TwoInputBroadcastWrapperOperator}. */
+public class TwoInputBroadcastWrapperOperatorTest {
+
+    /** Supplies a fresh temporary folder that the test uses as the data cache path. */
+    @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
+    /** Names of the two broadcast variables registered by the test. */
+    private static final String[] BROADCAST_NAMES = new String[] {"source1", 
"source2"};
+
+    /** Type information handed to the {@link BroadcastWrapper} together with the names. */
+    private static final TypeInformation<?>[] TYPE_INFORMATIONS =
+            new TypeInformation[] {BasicTypeInfo.INT_TYPE_INFO, 
BasicTypeInfo.INT_TYPE_INFO};
+
+    /** Content registered for (and expected from) the first broadcast variable. */
+    private static final List<Integer> SOURCE_1 = Collections.singletonList(1);
+
+    /** Content registered for (and expected from) the second broadcast variable. */
+    private static final List<Integer> SOURCE_2 = Arrays.asList(1, 2, 3);
+
+    /** Empty rich function passed to {@link TestTwoInputOp} as its user function. */
+    private static class MyRichFunction extends AbstractRichFunction {}
+
+    /**
+     * Wraps a {@link TestTwoInputOp} with a {@link BroadcastWrapper}, registers both broadcast
+     * variables, feeds five records into each of the two inputs and asserts that every record
+     * shows up in the harness output in the order it was fed.
+     */
+    @Test
+    public void testProcessElements() throws Exception {
+        TwoInputStreamOperator<Integer, Integer, Integer> inputOp =
+                new TestTwoInputOp(
+                        new MyRichFunction(), BROADCAST_NAMES, 
Arrays.asList(SOURCE_1, SOURCE_2));
+        BroadcastWrapper<Integer> broadcastWrapper =
+                new BroadcastWrapper<>(BROADCAST_NAMES, TYPE_INFORMATIONS);
+        BroadcastWrapperOperatorFactory<Integer> wrapperFactory =
+                new BroadcastWrapperOperatorFactory<>(
+                        SimpleOperatorFactory.of(inputOp), broadcastWrapper);
+        OperatorID operatorId = new OperatorID();
+
+        // The harness is built unrestored; restore() is called explicitly further down, after
+        // the data cache path has been written into the configuration.
+        try (StreamTaskMailboxTestHarness<Integer> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                TwoInputStreamTask::new, 
BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .setupOutputForSingletonOperatorChain(wrapperFactory, 
operatorId)
+                        .buildUnrestored()) {
+            // NOTE(review): presumably the wrapper operator reads DATA_CACHE_PATH while the task
+            // is being restored, which is why it is set first - confirm against the operator.
+            harness.getStreamTask()
+                    .getEnvironment()
+                    .getTaskManagerInfo()
+                    .getConfiguration()
+                    .set(
+                            IterationOptions.DATA_CACHE_PATH,
+                            "file://" + 
tempFolder.newFolder().getAbsolutePath());
+            harness.getStreamTask().restore();
+            // Register both variables under the key "<name>-0".
+            // NOTE(review): the "-0" suffix looks like a subtask index and the boolean like a
+            // "ready"/"cached" flag - confirm against BroadcastContext.
+            BroadcastContext.putBroadcastVariable(
+                    BROADCAST_NAMES[0] + "-" + 0, Tuple2.of(true, SOURCE_1));
+            BroadcastContext.putBroadcastVariable(
+                    BROADCAST_NAMES[1] + "-" + 0, Tuple2.of(true, SOURCE_2));
+
+            // Each value goes to input 0 and then input 1, so it is expected twice in a row.
+            Queue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+            for (int i = 0; i < 5; ++i) {
+                harness.processElement(new StreamRecord<>(i, 1000), 0);
+                harness.processElement(new StreamRecord<>(i, 1000), 1);
+                expectedOutput.add(new StreamRecord<>(i, 1000));
+                expectedOutput.add(new StreamRecord<>(i, 1000));
+            }
+            TestHarnessUtil.assertOutputEquals(
+                    "Output was not correct", expectedOutput, 
harness.getOutput());
+        }
+    }
+}
diff --git a/flink-ml-lib/src/test/resources/log4j2-test.properties 
b/flink-ml-lib/src/test/resources/log4j2-test.properties
new file mode 100644
index 0000000..835c2ec
--- /dev/null
+++ b/flink-ml-lib/src/test/resources/log4j2-test.properties
@@ -0,0 +1,28 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Set the root logger level to OFF so that test runs do not flood the build logs.
+# Set it manually to INFO for debugging purposes.
+rootLogger.level = OFF
+rootLogger.appenderRef.test.ref = TestLogger
+
+appender.testlogger.name = TestLogger
+appender.testlogger.type = CONSOLE
+appender.testlogger.target = SYSTEM_ERR
+appender.testlogger.layout.type = PatternLayout
+appender.testlogger.layout.pattern = %-4r [%t] %-5p %c %x - %m%n

Reply via email to