This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new e843b300 [FLINK-31160] Support join/cogroup in BroadcastUtils.withBroadcastStream
e843b300 is described below

commit e843b300b47a1ee3446296a359528a6b39566eed
Author: Zhipeng Zhang <[email protected]>
AuthorDate: Wed Apr 19 17:05:11 2023 +0800

    [FLINK-31160] Support join/cogroup in BroadcastUtils.withBroadcastStream
    
    This closes #215.
---
 .../flink/ml/common/broadcast/BroadcastUtils.java  |  31 +--
 .../operator/AbstractBroadcastWrapperOperator.java | 264 ++++++++++++---------
 .../BroadcastVariableReceiverOperator.java         |  12 +-
 .../BroadcastVariableReceiverOperatorFactory.java  |   4 +-
 .../broadcast/operator/BroadcastWrapper.java       |  25 +-
 .../operator/OneInputBroadcastWrapperOperator.java |  22 +-
 .../operator/TwoInputBroadcastWrapperOperator.java |  31 ++-
 .../ml/common/broadcast/BroadcastUtilsTest.java    |  77 +++++-
 .../OneInputBroadcastWrapperOperatorTest.java      |   3 +-
 .../TwoInputBroadcastWrapperOperatorTest.java      |   7 +-
 10 files changed, 289 insertions(+), 187 deletions(-)

diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
index b6c7f7ca..9b4c7b82 100644
--- 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
@@ -32,7 +32,6 @@ import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
@@ -41,7 +40,7 @@ import java.util.function.Function;
 /** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
     /**
-     * supports withBroadcastStream in DataStream API. Broadcast data streams 
are available at all
+     * Supports withBroadcastStream in DataStream API. Broadcast data streams 
are available at all
      * parallel instances of an operator that extends {@code
      * org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator<OUT, 
? extends
      * org.apache.flink.api.common.functions.RichFunction>}. Users can access 
the broadcast
@@ -57,8 +56,10 @@ public class BroadcastUtils {
      * @param bcStreams map of the broadcast data streams, where the key is 
the name and the value
      *     is the corresponding data stream.
      * @param userDefinedFunction the user defined logic in which users can 
access the broadcast
-     *     data streams and produce the output data stream. Note that users 
can add only one
-     *     operator in this function, otherwise it raises an exception.
+     *     data streams and produce the output data stream. Note that although users can add more
+     *     than one operator in this logic, only the operator that generates the result stream can
+     *     contain a rich function and access the broadcast variables. Other operators will
+     *     encounter an NPE when accessing the broadcast variables.
      * @return the output data stream.
      */
     public static <OUT> DataStream<OUT> withBroadcastStream(
@@ -116,9 +117,9 @@ public class BroadcastUtils {
     }
 
     /**
-     * caches all broadcast iput data streams in static variables and returns 
the result multi-input
-     * stream operator. The result multi-input stream operator emits nothing 
and the only
-     * functionality of this operator is to cache all the input records in 
${@link
+     * Caches all broadcast input data streams in static variables and returns 
the result
+     * multi-input stream operator. The result multi-input stream operator 
emits nothing and the
+     * only functionality of this operator is to cache all the input records in {@link
      * BroadcastContext}.
      *
      * @param env execution environment.
@@ -152,7 +153,7 @@ public class BroadcastUtils {
     }
 
     /**
-     * uses {@link DraftExecutionEnvironment} to execute the 
userDefinedFunction and returns the
+     * Uses {@link DraftExecutionEnvironment} to execute the 
userDefinedFunction and returns the
      * resultStream.
      *
      * @param env execution environment.
@@ -167,25 +168,13 @@ public class BroadcastUtils {
             List<DataStream<?>> inputList,
             String[] broadcastStreamNames,
             Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
-        TypeInformation<?>[] inTypes = new TypeInformation[inputList.size()];
-        for (int i = 0; i < inputList.size(); i++) {
-            inTypes[i] = inputList.get(i).getType();
-        }
-        // do not block all non-broadcast input edges by default.
-        boolean[] isBlocked = new boolean[inputList.size()];
-        Arrays.fill(isBlocked, false);
         DraftExecutionEnvironment draftEnv =
-                new DraftExecutionEnvironment(
-                        env, new BroadcastWrapper<>(broadcastStreamNames, 
inTypes, isBlocked));
-
+                new DraftExecutionEnvironment(env, new 
BroadcastWrapper<>(broadcastStreamNames));
         List<DataStream<?>> draftSources = new ArrayList<>();
         for (DataStream<?> dataStream : inputList) {
             draftSources.add(draftEnv.addDraftSource(dataStream, 
dataStream.getType()));
         }
         DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
-        Preconditions.checkState(
-                draftEnv.getStreamGraph(false).getStreamNodes().size() == 1 + 
inputList.size(),
-                "cannot add more than one operator in withBroadcastStream's 
lambda function.");
         draftEnv.copyToActualEnvironment();
         return draftEnv.getActualStream(draftOutStream.getId());
     }
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
index 70fa91d6..33061408 100644
--- 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
@@ -20,7 +20,6 @@ package org.apache.flink.ml.common.broadcast.operator;
 
 import org.apache.flink.api.common.functions.RichFunction;
 import org.apache.flink.api.common.operators.MailboxExecutor;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.memory.ManagedMemoryUseCase;
@@ -34,7 +33,7 @@ import org.apache.flink.metrics.groups.OperatorMetricGroup;
 import org.apache.flink.ml.common.broadcast.BroadcastContext;
 import org.apache.flink.ml.common.broadcast.BroadcastStreamingRuntimeContext;
 import org.apache.flink.ml.common.broadcast.typeinfo.CacheElement;
-import org.apache.flink.ml.common.broadcast.typeinfo.CacheElementTypeInfo;
+import org.apache.flink.ml.common.broadcast.typeinfo.CacheElementSerializer;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -48,6 +47,8 @@ import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
 import org.apache.flink.runtime.util.NonClosingOutputStreamDecorator;
 import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.graph.StreamConfig.InputConfig;
+import org.apache.flink.streaming.api.graph.StreamConfig.NetworkInputConfig;
 import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
 import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
@@ -80,7 +81,14 @@ import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
 
-/** Base class for the broadcast wrapper operators. */
+/**
+ * Base class for the broadcast wrapper operators.
+ *
+ * <p>Note that not all instances of {@link AbstractBroadcastWrapperOperator} 
need to access the
+ * broadcast variables. If one instance of {@link 
AbstractBroadcastWrapperOperator} does not contain
+ * a rich function, then it can directly process the input without waiting for 
the broadcast
+ * variables.
+ */
 public abstract class AbstractBroadcastWrapperOperator<T, S extends 
StreamOperator<T>>
         implements StreamOperator<T>, 
StreamOperatorStateHandler.CheckpointedStreamOperator {
 
@@ -105,54 +113,57 @@ public abstract class AbstractBroadcastWrapperOperator<T, 
S extends StreamOperat
 
     protected transient InternalTimeServiceManager<?> timeServiceManager;
 
-    protected final MailboxExecutor mailboxExecutor;
+    // ---------------- context info for rich function ----------------
+    private MailboxExecutor mailboxExecutor;
 
-    /** variables specific for withBroadcast functionality. */
-    protected final String[] broadcastStreamNames;
+    private String[] broadcastStreamNames;
 
     /**
-     * whether each input is blocked. Inputs with broadcast variables can only 
process their input
+     * Whether each input is blocked. Inputs with broadcast variables can only 
process their input
      * records after broadcast variables are ready. One input is non-blocked 
if it can consume its
      * inputs (by caching) when broadcast variables are not ready. Otherwise 
it has to block the
      * processing and wait until the broadcast variables are ready to be 
accessed.
      */
-    protected final boolean[] isBlocked;
-
-    /** type information of each input. */
-    protected final TypeInformation<?>[] inTypes;
+    private boolean[] isBlocked;
 
-    /** whether all broadcast variables of this operator are ready. */
-    protected boolean broadcastVariablesReady;
+    /** Type serializer of each input. */
+    private TypeSerializer<?>[] inTypeSerializers;
 
-    /** index of this subtask. */
-    protected final transient int indexOfSubtask;
+    /** Whether all broadcast variables of this operator are ready. */
+    private boolean broadcastVariablesReady;
+    /** Index of this subtask. */
+    protected transient int indexOfSubtask;
 
-    /** number of the inputs of this operator. */
-    protected final int numInputs;
+    /** Number of the inputs of this operator. */
+    protected int numInputs;
 
-    /** runtimeContext of the rich function in wrapped operator. */
-    BroadcastStreamingRuntimeContext wrappedOperatorRuntimeContext;
+    /** RuntimeContext of the rich function in wrapped operator. */
+    private BroadcastStreamingRuntimeContext wrappedOperatorRuntimeContext;
 
     /**
-     * path of the file used to stored the cached records. It could be local 
file system or remote
+     * Path of the file used to store the cached records. It could be local 
file system or remote
      * file system.
      */
-    private final Path basePath;
+    private Path basePath;
 
     /** DataCacheWriter for each input. */
     @SuppressWarnings("rawtypes")
-    protected DataCacheWriter[] dataCacheWriters;
+    private DataCacheWriter[] dataCacheWriters;
 
-    /** whether each input has pending elements. */
-    protected boolean[] hasPendingElements;
+    /** Whether each input has pending elements. */
+    private boolean[] hasPendingElements;
+
+    /**
+     * Whether this operator has a rich function and needs to access the broadcast variables. If
+     * yes, it cannot process elements until the broadcast variables are ready.
+     */
+    private final boolean hasRichFunction;
 
     @SuppressWarnings({"unchecked", "rawtypes"})
     AbstractBroadcastWrapperOperator(
             StreamOperatorParameters<T> parameters,
             StreamOperatorFactory<T> operatorFactory,
-            String[] broadcastStreamNames,
-            TypeInformation<?>[] inTypes,
-            boolean[] isBlocked) {
+            String[] broadcastStreamNames) {
         this.parameters = Objects.requireNonNull(parameters);
         this.streamConfig = 
Objects.requireNonNull(parameters.getStreamConfig());
         this.containingTask = 
Objects.requireNonNull(parameters.getContainingTask());
@@ -169,13 +180,13 @@ public abstract class AbstractBroadcastWrapperOperator<T, 
S extends StreamOperat
                                         
parameters.getOperatorEventDispatcher())
                                 .f0;
 
-        boolean hasRichFunction =
+        this.hasRichFunction =
                 wrappedOperator instanceof AbstractUdfStreamOperator
                         && ((AbstractUdfStreamOperator) 
wrappedOperator).getUserFunction()
                                 instanceof RichFunction;
 
         if (hasRichFunction) {
-            wrappedOperatorRuntimeContext =
+            this.wrappedOperatorRuntimeContext =
                     new BroadcastStreamingRuntimeContext(
                             containingTask.getEnvironment(),
                             
containingTask.getEnvironment().getAccumulatorRegistry().getUserMap(),
@@ -188,43 +199,58 @@ public abstract class AbstractBroadcastWrapperOperator<T, 
S extends StreamOperat
 
             ((RichFunction) ((AbstractUdfStreamOperator) 
wrappedOperator).getUserFunction())
                     .setRuntimeContext(wrappedOperatorRuntimeContext);
-        } else {
-            throw new RuntimeException(
-                    "The operator is not a instance of "
-                            + AbstractUdfStreamOperator.class.getSimpleName()
-                            + " that contains a "
-                            + RichFunction.class.getSimpleName());
-        }
 
-        this.mailboxExecutor =
-                
containingTask.getMailboxExecutorFactory().createExecutor(TaskMailbox.MIN_PRIORITY);
-        // variables specific for withBroadcast functionality.
-        this.broadcastStreamNames = broadcastStreamNames;
-        this.isBlocked = isBlocked;
-        this.inTypes = inTypes;
-        this.broadcastVariablesReady = false;
-        this.indexOfSubtask = containingTask.getIndexInSubtaskGroup();
-        this.numInputs = inTypes.length;
-
-        // puts in mailboxExecutor
-        for (String name : broadcastStreamNames) {
-            BroadcastContext.putMailBoxExecutor(name + "-" + indexOfSubtask, 
mailboxExecutor);
-        }
+            this.mailboxExecutor =
+                    containingTask
+                            .getMailboxExecutorFactory()
+                            .createExecutor(TaskMailbox.MIN_PRIORITY);
+
+            this.indexOfSubtask = containingTask.getIndexInSubtaskGroup();
+
+            // Registers the mailbox executor under each broadcast stream name for this subtask.
+            for (String name : broadcastStreamNames) {
+                BroadcastContext.putMailBoxExecutor(name + "-" + 
indexOfSubtask, mailboxExecutor);
+            }
+
+            this.broadcastStreamNames = broadcastStreamNames;
+
+            InputConfig[] inputConfigs =
+                    
streamConfig.getInputs(containingTask.getUserCodeClassLoader());
+
+            int numNetworkInputs = 0;
+            while (numNetworkInputs < inputConfigs.length
+                    && inputConfigs[numNetworkInputs] instanceof 
NetworkInputConfig) {
+                numNetworkInputs++;
+            }
+            this.numInputs = numNetworkInputs;
+
+            this.isBlocked = new boolean[numInputs];
+            Arrays.fill(isBlocked, false);
 
-        basePath =
-                OperatorUtils.getDataCachePath(
-                        
containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
-                        containingTask
-                                .getEnvironment()
-                                .getIOManager()
-                                .getSpillingDirectoriesPaths());
-        dataCacheWriters = new DataCacheWriter[numInputs];
-        hasPendingElements = new boolean[numInputs];
-        Arrays.fill(hasPendingElements, true);
+            this.inTypeSerializers = new TypeSerializer[numInputs];
+            for (int i = 0; i < numInputs; i++) {
+                inTypeSerializers[i] =
+                        streamConfig.getTypeSerializerIn(
+                                i, containingTask.getUserCodeClassLoader());
+            }
+
+            this.broadcastVariablesReady = false;
+
+            this.basePath =
+                    OperatorUtils.getDataCachePath(
+                            
containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                            containingTask
+                                    .getEnvironment()
+                                    .getIOManager()
+                                    .getSpillingDirectoriesPaths());
+            this.dataCacheWriters = new DataCacheWriter[numInputs];
+            this.hasPendingElements = new boolean[numInputs];
+            Arrays.fill(hasPendingElements, true);
+        }
     }
 
     /**
-     * checks whether all of broadcast variables are ready. Besides it 
maintains a state
+     * Checks whether all broadcast variables are ready. Besides, it maintains 
a state
      * {broadcastVariablesReady} to avoiding invoking {@code 
BroadcastContext.isCacheFinished(...)}
      * repeatedly. Finally, it sets broadcast variables for 
{wrappedOperatorRuntimeContext} if the
      * broadcast variables are ready.
@@ -269,7 +295,7 @@ public abstract class AbstractBroadcastWrapperOperator<T, S 
extends StreamOperat
     }
 
     /**
-     * extracts common processing logic in subclasses' processing elements.
+     * Extracts common processing logic in subclasses' processing elements.
      *
      * @param streamRecord the input record.
      * @param inputIndex input id, starts from zero.
@@ -277,6 +303,8 @@ public abstract class AbstractBroadcastWrapperOperator<T, S 
extends StreamOperat
      *     operator.processElement(...).
      * @param watermarkConsumer the consumer function of WaterMark, i.e.,
      *     operator.processWatermark(...).
+     * @param keyContextSetter the consumer function of setting key context, 
i.e.,
+     *     operator.setKeyContext(...).
      * @throws Exception possible exception.
      */
     @SuppressWarnings({"rawtypes", "unchecked"})
@@ -284,32 +312,31 @@ public abstract class AbstractBroadcastWrapperOperator<T, 
S extends StreamOperat
             StreamRecord streamRecord,
             int inputIndex,
             ThrowingConsumer<StreamRecord, Exception> elementConsumer,
-            ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+            ThrowingConsumer<Watermark, Exception> watermarkConsumer,
+            ThrowingConsumer<StreamRecord, Exception> keyContextSetter)
             throws Exception {
-        if (!isBlocked[inputIndex]) {
-            if (areBroadcastVariablesReady()) {
-                if (hasPendingElements[inputIndex]) {
-                    processPendingElementsAndWatermarks(
-                            inputIndex, elementConsumer, watermarkConsumer);
-                    hasPendingElements[inputIndex] = false;
-                }
-                elementConsumer.accept(streamRecord);
-
-            } else {
-                dataCacheWriters[inputIndex].addRecord(
-                        CacheElement.newRecord(streamRecord.getValue()));
-            }
-
-        } else {
+        if (!hasRichFunction) {
+            elementConsumer.accept(streamRecord);
+        } else if (isBlocked[inputIndex]) {
             while (!areBroadcastVariablesReady()) {
                 mailboxExecutor.yield();
             }
             elementConsumer.accept(streamRecord);
+        } else if (!areBroadcastVariablesReady()) {
+            
dataCacheWriters[inputIndex].addRecord(CacheElement.newRecord(streamRecord.getValue()));
+        } else {
+            if (hasPendingElements[inputIndex]) {
+                processPendingElementsAndWatermarks(
+                        inputIndex, elementConsumer, watermarkConsumer, 
keyContextSetter);
+                hasPendingElements[inputIndex] = false;
+            }
+            keyContextSetter.accept(streamRecord);
+            elementConsumer.accept(streamRecord);
         }
     }
 
     /**
-     * extracts common processing logic in subclasses' processing watermarks.
+     * Extracts common processing logic in subclasses' processing watermarks.
      *
      * @param watermark the input watermark.
      * @param inputIndex input id, starts from zero.
@@ -317,6 +344,8 @@ public abstract class AbstractBroadcastWrapperOperator<T, S 
extends StreamOperat
      *     operator.processElement(...).
      * @param watermarkConsumer the consumer function of WaterMark, i.e.,
      *     operator.processWatermark(...).
+     * @param keyContextSetter the consumer function of setting key context, 
i.e.,
+     *     operator.setKeyContext(...).
      * @throws Exception possible exception.
      */
     @SuppressWarnings({"rawtypes", "unchecked"})
@@ -324,83 +353,93 @@ public abstract class AbstractBroadcastWrapperOperator<T, 
S extends StreamOperat
             Watermark watermark,
             int inputIndex,
             ThrowingConsumer<StreamRecord, Exception> elementConsumer,
-            ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+            ThrowingConsumer<Watermark, Exception> watermarkConsumer,
+            ThrowingConsumer<StreamRecord, Exception> keyContextSetter)
             throws Exception {
-        if (!isBlocked[inputIndex]) {
-            if (areBroadcastVariablesReady()) {
-                if (hasPendingElements[inputIndex]) {
-                    processPendingElementsAndWatermarks(
-                            inputIndex, elementConsumer, watermarkConsumer);
-                    hasPendingElements[inputIndex] = false;
-                }
-                watermarkConsumer.accept(watermark);
-
-            } else {
-                dataCacheWriters[inputIndex].addRecord(
-                        CacheElement.newWatermark(watermark.getTimestamp()));
-            }
-
-        } else {
+        if (!hasRichFunction) {
+            watermarkConsumer.accept(watermark);
+        } else if (isBlocked[inputIndex]) {
             while (!areBroadcastVariablesReady()) {
                 mailboxExecutor.yield();
             }
             watermarkConsumer.accept(watermark);
+        } else if (!areBroadcastVariablesReady()) {
+            dataCacheWriters[inputIndex].addRecord(
+                    CacheElement.newWatermark(watermark.getTimestamp()));
+        } else {
+            if (hasPendingElements[inputIndex]) {
+                processPendingElementsAndWatermarks(
+                        inputIndex, elementConsumer, watermarkConsumer, 
keyContextSetter);
+                hasPendingElements[inputIndex] = false;
+            }
+            watermarkConsumer.accept(watermark);
         }
     }
 
     /**
-     * extracts common processing logic in subclasses' endInput(...).
+     * Extracts common processing logic in subclasses' endInput(...).
      *
      * @param inputIndex input id, starts from zero.
      * @param elementConsumer the consumer function of StreamRecord, i.e.,
      *     operator.processElement(...).
      * @param watermarkConsumer the consumer function of WaterMark, i.e.,
      *     operator.processWatermark(...).
+     * @param keyContextSetter the consumer function of setting key context, 
i.e.,
+     *     operator.setKeyContext(...).
      * @throws Exception possible exception.
      */
     @SuppressWarnings("rawtypes")
     protected void endInputX(
             int inputIndex,
             ThrowingConsumer<StreamRecord, Exception> elementConsumer,
-            ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+            ThrowingConsumer<Watermark, Exception> watermarkConsumer,
+            ThrowingConsumer<StreamRecord, Exception> keyContextSetter)
             throws Exception {
+        if (!hasRichFunction) {
+            return;
+        }
         while (!areBroadcastVariablesReady()) {
             mailboxExecutor.yield();
         }
         if (hasPendingElements[inputIndex]) {
-            processPendingElementsAndWatermarks(inputIndex, elementConsumer, 
watermarkConsumer);
+            processPendingElementsAndWatermarks(
+                    inputIndex, elementConsumer, watermarkConsumer, 
keyContextSetter);
             hasPendingElements[inputIndex] = false;
         }
     }
 
     /**
-     * processes the pending elements that are cached by {@link 
DataCacheWriter}.
+     * Processes the pending elements that are cached by {@link 
DataCacheWriter}.
      *
      * @param inputIndex input id, starts from zero.
      * @param elementConsumer the consumer function of StreamRecord, i.e.,
      *     operator.processElement(...).
      * @param watermarkConsumer the consumer function of WaterMark, i.e.,
      *     operator.processWatermark(...).
+     * @param keyContextSetter the consumer function of setting key context, 
i.e.,
+     *     operator.setKeyContext(...).
      * @throws Exception possible exception.
      */
     @SuppressWarnings({"rawtypes", "unchecked"})
     private void processPendingElementsAndWatermarks(
             int inputIndex,
             ThrowingConsumer<StreamRecord, Exception> elementConsumer,
-            ThrowingConsumer<Watermark, Exception> watermarkConsumer)
+            ThrowingConsumer<Watermark, Exception> watermarkConsumer,
+            ThrowingConsumer<StreamRecord, Exception> keyContextSetter)
             throws Exception {
         List<Segment> pendingSegments = 
dataCacheWriters[inputIndex].getSegments();
         if (pendingSegments.size() != 0) {
             DataCacheReader dataCacheReader =
                     new DataCacheReader<>(
-                            new CacheElementTypeInfo<>(inTypes[inputIndex])
-                                    
.createSerializer(containingTask.getExecutionConfig()),
+                            new 
CacheElementSerializer<>(inTypeSerializers[inputIndex]),
                             pendingSegments);
             while (dataCacheReader.hasNext()) {
                 CacheElement cacheElement = (CacheElement) 
dataCacheReader.next();
                 switch (cacheElement.getType()) {
                     case RECORD:
-                        elementConsumer.accept(new 
StreamRecord(cacheElement.getRecord()));
+                        StreamRecord record = new 
StreamRecord(cacheElement.getRecord());
+                        keyContextSetter.accept(record);
+                        elementConsumer.accept(record);
                         break;
                     case WATERMARK:
                         watermarkConsumer.accept(new 
Watermark(cacheElement.getWatermark()));
@@ -410,6 +449,7 @@ public abstract class AbstractBroadcastWrapperOperator<T, S 
extends StreamOperat
                                 "Unsupported CacheElement type: " + 
cacheElement.getType());
                 }
             }
+            dataCacheWriters[inputIndex].clear();
         }
     }
 
@@ -421,6 +461,9 @@ public abstract class AbstractBroadcastWrapperOperator<T, S 
extends StreamOperat
     @Override
     public void close() throws Exception {
         wrappedOperator.close();
+        if (!hasRichFunction) {
+            return;
+        }
         for (String name : broadcastStreamNames) {
             BroadcastContext.remove(name + "-" + indexOfSubtask);
         }
@@ -468,8 +511,6 @@ public abstract class AbstractBroadcastWrapperOperator<T, S 
extends StreamOperat
 
         timeServiceManager = 
streamOperatorStateContext.internalTimerServiceManager();
 
-        broadcastVariablesReady = false;
-
         wrappedOperator.initializeState(
                 (operatorID,
                         operatorClassName,
@@ -514,12 +555,14 @@ public abstract class AbstractBroadcastWrapperOperator<T, 
S extends StreamOperat
                         
stateInitializationContext.getRawOperatorStateInputs().iterator());
         Preconditions.checkState(
                 inputs.size() < 2, "The input from raw operator state should 
be one or zero.");
+        if (!hasRichFunction) {
+            return;
+        }
         if (inputs.size() == 0) {
             for (int i = 0; i < numInputs; i++) {
                 dataCacheWriters[i] =
                         new DataCacheWriter(
-                                new CacheElementTypeInfo<>(inTypes[i])
-                                        
.createSerializer(containingTask.getExecutionConfig()),
+                                new 
CacheElementSerializer(inTypeSerializers[i]),
                                 basePath.getFileSystem(),
                                 OperatorUtils.createDataCacheFileGenerator(
                                         basePath, "cache", 
streamConfig.getOperatorID()));
@@ -538,8 +581,7 @@ public abstract class AbstractBroadcastWrapperOperator<T, S 
extends StreamOperat
                                         basePath, "cache", 
streamConfig.getOperatorID()));
                 dataCacheWriters[i] =
                         new DataCacheWriter(
-                                new CacheElementTypeInfo<>(inTypes[i])
-                                        
.createSerializer(containingTask.getExecutionConfig()),
+                                new 
CacheElementSerializer(inTypeSerializers[i]),
                                 basePath.getFileSystem(),
                                 OperatorUtils.createDataCacheFileGenerator(
                                         basePath, "cache", 
streamConfig.getOperatorID()),
@@ -555,6 +597,10 @@ public abstract class AbstractBroadcastWrapperOperator<T, 
S extends StreamOperat
             ((CheckpointedStreamOperator) 
wrappedOperator).snapshotState(stateSnapshotContext);
         }
 
+        if (!hasRichFunction) {
+            return;
+        }
+
         OperatorStateCheckpointOutputStream checkpointOutputStream =
                 stateSnapshotContext.getRawOperatorStateOutput();
         checkpointOutputStream.startNewPartition();
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
index e101b8fa..7a65ed02 100644
--- 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
@@ -45,24 +45,24 @@ import java.util.List;
 public class BroadcastVariableReceiverOperator<OUT> extends 
AbstractStreamOperatorV2<OUT>
         implements MultipleInputStreamOperator<OUT>, BoundedMultiInput, 
Serializable {
 
-    /** names of the broadcast data streams. */
+    /** Names of the broadcast data streams. */
     private final String[] broadcastStreamNames;
 
-    /** output types of input data streams. */
+    /** Output types of input data streams. */
     private final TypeInformation<?>[] inTypes;
 
-    /** input list of the multi-input operator. */
+    /** Input list of the multi-input operator. */
     @SuppressWarnings("rawtypes")
     private final List<Input> inputList;
 
-    /** whether each broadcast input has finished. */
+    /** Whether each broadcast input has finished. */
     private boolean[] cachesReady;
 
-    /** state storage of the broadcast inputs. */
+    /** State storage of the broadcast inputs. */
     @SuppressWarnings("rawtypes")
     private ListState[] cacheStates;
 
-    /** cacheReady state storage of the broadcast inputs. */
+    /** CacheReady state storage of the broadcast inputs. */
     private ListState<Boolean>[] cacheReadyStates;
 
     @SuppressWarnings({"rawtypes", "unchecked"})
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorFactory.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorFactory.java
index 75721c77..2ff7762e 100644
--- 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorFactory.java
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorFactory.java
@@ -29,10 +29,10 @@ import java.io.Serializable;
 public class BroadcastVariableReceiverOperatorFactory<OUT>
         extends AbstractStreamOperatorFactory<OUT> implements Serializable {
 
-    /** names of the broadcast data streams. */
+    /** Names of the broadcast data streams. */
     private final String[] broadcastNames;
 
-    /** types of the broadcast data streams. */
+    /** Types of the broadcast data streams. */
     private final TypeInformation<?>[] inTypes;
 
     public BroadcastVariableReceiverOperatorFactory(
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
index 2a18c855..bf17cb8d 100644
--- 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.ml.common.broadcast.operator;
 
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.iteration.operator.OperatorWrapper;
@@ -29,31 +28,15 @@ import 
org.apache.flink.streaming.api.operators.StreamOperatorParameters;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.util.OutputTag;
-import org.apache.flink.util.Preconditions;
 
 /** The operator wrapper for {@link AbstractBroadcastWrapperOperator}. */
 public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
 
-    /** names of the broadcast data streams. */
+    /** Names of the broadcast data streams. */
     private final String[] broadcastStreamNames;
 
-    /** types of input data streams. */
-    private final TypeInformation<?>[] inTypes;
-
-    /** whether each input is blocked or not. */
-    private final boolean[] isBlocked;
-
-    @VisibleForTesting
-    public BroadcastWrapper(String[] broadcastStreamNames, 
TypeInformation<?>[] inTypes) {
-        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
-    }
-
-    public BroadcastWrapper(
-            String[] broadcastStreamNames, TypeInformation<?>[] inTypes, 
boolean[] isBlocked) {
-        Preconditions.checkArgument(inTypes.length == isBlocked.length);
+    public BroadcastWrapper(String[] broadcastStreamNames) {
         this.broadcastStreamNames = broadcastStreamNames;
-        this.inTypes = inTypes;
-        this.isBlocked = isBlocked;
     }
 
     @Override
@@ -64,10 +47,10 @@ public class BroadcastWrapper<T> implements 
OperatorWrapper<T, T> {
                 
operatorFactory.getStreamOperatorClass(getClass().getClassLoader());
         if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
             return new OneInputBroadcastWrapperOperator<>(
-                    operatorParameters, operatorFactory, broadcastStreamNames, 
inTypes, isBlocked);
+                    operatorParameters, operatorFactory, broadcastStreamNames);
         } else if 
(TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
             return new TwoInputBroadcastWrapperOperator<>(
-                    operatorParameters, operatorFactory, broadcastStreamNames, 
inTypes, isBlocked);
+                    operatorParameters, operatorFactory, broadcastStreamNames);
         } else {
             throw new UnsupportedOperationException(
                     "Unsupported operator class for with-broadcast wrapper: " 
+ operatorClass);
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
index f1ffe000..8039005e 100644
--- 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.ml.common.broadcast.operator;
 
-import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
@@ -37,10 +36,8 @@ public class OneInputBroadcastWrapperOperator<IN, OUT>
     OneInputBroadcastWrapperOperator(
             StreamOperatorParameters<OUT> parameters,
             StreamOperatorFactory<OUT> operatorFactory,
-            String[] broadcastStreamNames,
-            TypeInformation<?>[] inTypes,
-            boolean[] isBlocking) {
-        super(parameters, operatorFactory, broadcastStreamNames, inTypes, 
isBlocking);
+            String[] broadcastStreamNames) {
+        super(parameters, operatorFactory, broadcastStreamNames);
     }
 
     @Override
@@ -49,12 +46,17 @@ public class OneInputBroadcastWrapperOperator<IN, OUT>
                 streamRecord,
                 0,
                 wrappedOperator::processElement,
-                wrappedOperator::processWatermark);
+                wrappedOperator::processWatermark,
+                wrappedOperator::setKeyContextElement);
     }
 
     @Override
     public void endInput() throws Exception {
-        endInputX(0, wrappedOperator::processElement, 
wrappedOperator::processWatermark);
+        endInputX(
+                0,
+                wrappedOperator::processElement,
+                wrappedOperator::processWatermark,
+                wrappedOperator::setKeyContextElement);
         OperatorUtils.processOperatorOrUdfIfSatisfy(
                 wrappedOperator, BoundedOneInput.class, 
BoundedOneInput::endInput);
     }
@@ -62,7 +64,11 @@ public class OneInputBroadcastWrapperOperator<IN, OUT>
     @Override
     public void processWatermark(Watermark watermark) throws Exception {
         processWatermarkX(
-                watermark, 0, wrappedOperator::processElement, 
wrappedOperator::processWatermark);
+                watermark,
+                0,
+                wrappedOperator::processElement,
+                wrappedOperator::processWatermark,
+                wrappedOperator::setKeyContextElement);
     }
 
     @Override
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperator.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperator.java
index 07871d47..4d4d468c 100644
--- 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperator.java
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperator.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.ml.common.broadcast.operator;
 
-import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
@@ -37,10 +36,8 @@ public class TwoInputBroadcastWrapperOperator<IN1, IN2, OUT>
     TwoInputBroadcastWrapperOperator(
             StreamOperatorParameters<OUT> parameters,
             StreamOperatorFactory<OUT> operatorFactory,
-            String[] broadcastStreamNames,
-            TypeInformation<?>[] inTypes,
-            boolean[] isBlocking) {
-        super(parameters, operatorFactory, broadcastStreamNames, inTypes, 
isBlocking);
+            String[] broadcastStreamNames) {
+        super(parameters, operatorFactory, broadcastStreamNames);
     }
 
     @Override
@@ -49,7 +46,8 @@ public class TwoInputBroadcastWrapperOperator<IN1, IN2, OUT>
                 streamRecord,
                 0,
                 wrappedOperator::processElement1,
-                wrappedOperator::processWatermark1);
+                wrappedOperator::processWatermark1,
+                wrappedOperator::setKeyContextElement1);
     }
 
     @Override
@@ -58,7 +56,8 @@ public class TwoInputBroadcastWrapperOperator<IN1, IN2, OUT>
                 streamRecord,
                 1,
                 wrappedOperator::processElement2,
-                wrappedOperator::processWatermark2);
+                wrappedOperator::processWatermark2,
+                wrappedOperator::setKeyContextElement2);
     }
 
     @Override
@@ -67,12 +66,14 @@ public class TwoInputBroadcastWrapperOperator<IN1, IN2, OUT>
             endInputX(
                     inputId - 1,
                     wrappedOperator::processElement1,
-                    wrappedOperator::processWatermark1);
+                    wrappedOperator::processWatermark1,
+                    wrappedOperator::setKeyContextElement1);
         } else {
             endInputX(
                     inputId - 1,
                     wrappedOperator::processElement2,
-                    wrappedOperator::processWatermark2);
+                    wrappedOperator::processWatermark2,
+                    wrappedOperator::setKeyContextElement2);
         }
         OperatorUtils.processOperatorOrUdfIfSatisfy(
                 wrappedOperator,
@@ -83,13 +84,21 @@ public class TwoInputBroadcastWrapperOperator<IN1, IN2, OUT>
     @Override
     public void processWatermark1(Watermark watermark) throws Exception {
         processWatermarkX(
-                watermark, 0, wrappedOperator::processElement1, 
wrappedOperator::processWatermark1);
+                watermark,
+                0,
+                wrappedOperator::processElement1,
+                wrappedOperator::processWatermark1,
+                wrappedOperator::setKeyContextElement1);
     }
 
     @Override
     public void processWatermark2(Watermark watermark) throws Exception {
         processWatermarkX(
-                watermark, 1, wrappedOperator::processElement2, 
wrappedOperator::processWatermark2);
+                watermark,
+                1,
+                wrappedOperator::processElement2,
+                wrappedOperator::processWatermark2,
+                wrappedOperator::setKeyContextElement2);
     }
 
     @Override
diff --git 
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/BroadcastUtilsTest.java
 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/BroadcastUtilsTest.java
index 1eb8b78f..e6fbbfbb 100644
--- 
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/BroadcastUtilsTest.java
+++ 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/BroadcastUtilsTest.java
@@ -19,13 +19,17 @@
 package org.apache.flink.ml.common.broadcast;
 
 import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.functions.RichJoinFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.RestOptions;
 import org.apache.flink.iteration.config.IterationOptions;
+import 
org.apache.flink.ml.common.broadcast.operator.BroadcastVariableReceiverOperatorTest;
 import org.apache.flink.ml.common.broadcast.operator.TestOneInputOp;
 import org.apache.flink.ml.common.broadcast.operator.TestTwoInputOp;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.minicluster.MiniCluster;
@@ -42,6 +46,7 @@ import org.junit.rules.TemporaryFolder;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.function.Function;
@@ -99,6 +104,52 @@ public class BroadcastUtilsTest {
         }
     }
 
+    @Test
+    public void testBroadcastWithJoin() throws Exception {
+        try (MiniCluster miniCluster = new 
MiniCluster(createMiniClusterConfiguration())) {
+            miniCluster.start();
+            JobGraph jobGraph = getBroadcastWithJoinJobGraph();
+            miniCluster.executeJobBlocking(jobGraph);
+        }
+    }
+
+    private JobGraph getBroadcastWithJoinJobGraph() {
+        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
+        env.setRestartStrategy(RestartStrategies.fallBackRestart());
+        env.enableCheckpointing(500, CheckpointingMode.EXACTLY_ONCE);
+        env.setParallelism(NUM_SLOT * NUM_TM);
+
+        DataStream<Integer> source1 = env.addSource(new 
TestSource(NUM_RECORDS_PER_PARTITION));
+        DataStream<Integer> source2 = env.addSource(new 
TestSource(NUM_RECORDS_PER_PARTITION));
+
+        List<Integer> expectedNumSequence = new ArrayList<>(NUM_TM * NUM_SLOT 
* 10);
+        for (int i = 0; i < NUM_TM * NUM_SLOT * NUM_RECORDS_PER_PARTITION; 
i++) {
+            expectedNumSequence.add(i);
+        }
+
+        List<Integer> expectedBroadcastVariable = expectedNumSequence;
+
+        List<DataStream<?>> inputList = Arrays.asList(source1, source2);
+        DataStream<Integer> result =
+                BroadcastUtils.withBroadcastStream(
+                        inputList,
+                        Collections.singletonMap(BROADCAST_NAMES[0], source1),
+                        inputs -> {
+                            DataStream<Integer> input1 = (DataStream<Integer>) 
inputs.get(0);
+                            DataStream<Integer> input2 = (DataStream<Integer>) 
inputs.get(1);
+                            return input1.join(input2)
+                                    .where((KeySelector<Integer, Integer>) x0 
-> x0)
+                                    .equalTo((KeySelector<Integer, Integer>) 
x1 -> x1)
+                                    .window(EndOfStreamWindows.get())
+                                    .apply(
+                                            new 
RichJoinFunctionWithBroadcastVariable(
+                                                    BROADCAST_NAMES[0], 
expectedBroadcastVariable));
+                        });
+
+        result.addSink(new 
TestSink(expectedNumSequence)).getTransformation().setParallelism(1);
+        return env.getStreamGraph().getJobGraph();
+    }
+
     private JobGraph getJobGraph(int numNonBroadcastInputs) {
         StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
         env.setRestartStrategy(RestartStrategies.fallBackRestart());
@@ -112,7 +163,7 @@ public class BroadcastUtilsTest {
         bcStreamsMap.put(BROADCAST_NAMES[1], source2);
 
         List<DataStream<?>> inputList = new ArrayList<>(1);
-        // create a deadlock.
+        // Creates a deadlock.
         inputList.add(source1);
         for (int i = 0; i < numNonBroadcastInputs - 1; i++) {
             inputList.add(env.addSource(new 
TestSource(NUM_RECORDS_PER_PARTITION)));
@@ -167,4 +218,28 @@ public class BroadcastUtilsTest {
         }
         return null;
     }
+
+    private static class RichJoinFunctionWithBroadcastVariable
+            extends RichJoinFunction<Integer, Integer, Integer> {
+        private final String broadcastVariableName;
+        private final List<Integer> expectedBroadcastVariable;
+        // Stores the received broadcast variable.
+        private List<Integer> broadcastVariable;
+
+        public RichJoinFunctionWithBroadcastVariable(
+                String broadcastVariableName, List<Integer> 
expectedBroadcastVariable) {
+            this.broadcastVariableName = broadcastVariableName;
+            this.expectedBroadcastVariable = expectedBroadcastVariable;
+        }
+
+        @Override
+        public Integer join(Integer first, Integer second) throws Exception {
+            if (broadcastVariable == null) {
+                broadcastVariable = 
getRuntimeContext().getBroadcastVariable(broadcastVariableName);
+                BroadcastVariableReceiverOperatorTest.compareLists(
+                        expectedBroadcastVariable, broadcastVariable);
+            }
+            return first;
+        }
+    }
 }
diff --git 
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperatorTest.java
 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperatorTest.java
index 7d64746e..57b47047 100644
--- 
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperatorTest.java
+++ 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperatorTest.java
@@ -64,8 +64,7 @@ public class OneInputBroadcastWrapperOperatorTest {
         OneInputStreamOperator<Integer, Integer> inputOp =
                 new TestOneInputOp(
                         new MyRichFunction(), BROADCAST_NAMES, 
Arrays.asList(SOURCE_1, SOURCE_2));
-        BroadcastWrapper<Integer> broadcastWrapper =
-                new BroadcastWrapper<>(BROADCAST_NAMES, TYPE_INFORMATIONS);
+        BroadcastWrapper<Integer> broadcastWrapper = new 
BroadcastWrapper<>(BROADCAST_NAMES);
         BroadcastWrapperOperatorFactory<Integer> wrapperFactory =
                 new BroadcastWrapperOperatorFactory<>(
                         SimpleOperatorFactory.of(inputOp), broadcastWrapper);
diff --git 
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperatorTest.java
 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperatorTest.java
index 73bcbbca..ab517d7c 100644
--- 
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperatorTest.java
+++ 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/broadcast/operator/TwoInputBroadcastWrapperOperatorTest.java
@@ -20,7 +20,6 @@ package org.apache.flink.ml.common.broadcast.operator;
 
 import org.apache.flink.api.common.functions.AbstractRichFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.iteration.config.IterationOptions;
 import org.apache.flink.ml.common.broadcast.BroadcastContext;
@@ -50,9 +49,6 @@ public class TwoInputBroadcastWrapperOperatorTest {
 
     private static final String[] BROADCAST_NAMES = new String[] {"source1", 
"source2"};
 
-    private static final TypeInformation<?>[] TYPE_INFORMATIONS =
-            new TypeInformation[] {BasicTypeInfo.INT_TYPE_INFO, 
BasicTypeInfo.INT_TYPE_INFO};
-
     private static final List<Integer> SOURCE_1 = Collections.singletonList(1);
 
     private static final List<Integer> SOURCE_2 = Arrays.asList(1, 2, 3);
@@ -64,8 +60,7 @@ public class TwoInputBroadcastWrapperOperatorTest {
         TwoInputStreamOperator<Integer, Integer, Integer> inputOp =
                 new TestTwoInputOp(
                         new MyRichFunction(), BROADCAST_NAMES, 
Arrays.asList(SOURCE_1, SOURCE_2));
-        BroadcastWrapper<Integer> broadcastWrapper =
-                new BroadcastWrapper<>(BROADCAST_NAMES, TYPE_INFORMATIONS);
+        BroadcastWrapper<Integer> broadcastWrapper = new 
BroadcastWrapper<>(BROADCAST_NAMES);
         BroadcastWrapperOperatorFactory<Integer> wrapperFactory =
                 new BroadcastWrapperOperatorFactory<>(
                         SimpleOperatorFactory.of(inputOp), broadcastWrapper);

Reply via email to