gaoyunhaii commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r732431833



##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
##########
@@ -96,31 +97,36 @@
 
     protected final StreamOperatorFactory<T> operatorFactory;
 
-    /** Metric group for the operator. */
     protected final OperatorMetricGroup metrics;
 
     protected final S wrappedOperator;
 
-    /** variables for withBroadcast operators. */
-    protected final MailboxExecutor mailboxExecutor;
-
-    protected final String[] broadcastStreamNames;
+    protected transient StreamOperatorStateHandler stateHandler;
 
-    protected final boolean[] isBlocking;
+    protected transient InternalTimeServiceManager<?> timeServiceManager;
 
+    protected final MailboxExecutor mailboxExecutor;
+    /** variables specific for withBroadcast functionality. */

Review comment:
       In general, leave one empty line before each instance variable.
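       For example, a minimal illustrative sketch of the intended layout (placeholder field values, just to show the blank-line convention and where the section comment sits):

```java
// Illustrative only: one blank line before each instance variable, and the
// "withBroadcast" section comment attached to the first field it describes.
class FieldLayoutExample {

    protected final int indexOfSubtask = 0;

    /** Variables specific to the withBroadcast functionality. */
    protected final String[] broadcastStreamNames = {};

    protected final boolean[] isBlocked = {};
}
```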

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
##########
@@ -196,17 +203,25 @@ public AbstractBroadcastWrapperOperator(
     }
 
     /**
-     * check whether all of broadcast variables are ready.
+     * checks whether all of broadcast variables are ready. Besides it 
maintains a state
+     * {broadcastVariablesReady} to avoiding invoking {@code 
BroadcastContext.isCacheFinished(...)}
+     * repeatedly. Finally, it sets broadcast variables for ${@link 
HasBroadcastVariable} if the
+     * broadcast variables are ready.
      *
-     * @return
+     * @return true if all broadcast variables are ready, false otherwise.
      */
     protected boolean areBroadcastVariablesReady() {
         if (broadcastVariablesReady) {
             return true;
         }
         for (String name : broadcastStreamNames) {
-            if (!BroadcastContext.isCacheFinished(Tuple2.of(name, 
indexOfSubtask))) {
+            if (!BroadcastContext.isCacheFinished(name + "-" + 
indexOfSubtask)) {
                 return false;
+            } else if (wrappedOperator instanceof HasBroadcastVariable) {
+                String key = name + "-" + indexOfSubtask;
+                String userKey = name.substring(name.indexOf('-') + 1);
+                ((HasBroadcastVariable) wrappedOperator)

Review comment:
       Use `OperatorUtils#processOperatorOrUdfIfSatisfy` instead, since we may need to handle both the operator and its wrapped UDF if either of them implements the interface.
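       For example, something along these lines (a hedged sketch: it assumes `processOperatorOrUdfIfSatisfy(operator, targetInterface, consumer)` applies the consumer to whichever of the operator or its wrapped UDF implements the target interface, and that the setter call mirrors the existing cast-based code):

```java
// Sketch only; assumes OperatorUtils#processOperatorOrUdfIfSatisfy(operator, clazz, consumer)
// invokes the consumer on the operator and/or its wrapped UDF when it implements clazz.
String key = name + "-" + indexOfSubtask;
String userKey = name.substring(name.indexOf('-') + 1);
OperatorUtils.processOperatorOrUdfIfSatisfy(
        wrappedOperator,
        HasBroadcastVariable.class,
        op -> op.setBroadcastVariable(userKey, BroadcastContext.getBroadcastVariable(key)));
```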

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -22,121 +22,168 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
 import 
org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
 import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import 
org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
 
+/** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
+    /**
+     * supports withBroadcastStream in DataStream API. Broadcast data streams 
are available at all
+     * parallel instances of an operator that implements ${@link 
HasBroadcastVariable}. An operator
+     * that wants to access broadcast variables must implement ${@link 
HasBroadcastVariable}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first 
and further set by
+     * {@code HasBroadcastVariable.setBroadcastVariable(...)}. For now the 
non-broadcast input are
+     * cached by default to avoid the possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is 
the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can 
access the broadcast
+     *     data streams and produce the output data stream. Note that users 
can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> 
userDefinedFunction) {
+        Preconditions.checkArgument(inputList.size() > 0);
+
+        StreamExecutionEnvironment env = 
inputList.get(0).getExecutionEnvironment();
+        String[] broadcastNames = new String[bcStreams.size()];
+        DataStream<?>[] broadcastInputs = new DataStream[bcStreams.size()];
+        TypeInformation<?>[] broadcastInTypes = new 
TypeInformation[bcStreams.size()];
+        int idx = 0;
+        final String broadcastId = new AbstractID().toHexString();
+        for (String name : bcStreams.keySet()) {
+            broadcastNames[idx] = broadcastId + "-" + name;
+            broadcastInputs[idx] = bcStreams.get(name);
+            broadcastInTypes[idx] = broadcastInputs[idx].getType();
+            idx++;
+        }
 
+        DataStream<OUT> resultStream =
+                getResultStream(env, inputList, broadcastNames, 
userDefinedFunction);
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = "broadcast-co-location-" + 
UUID.randomUUID();
+        DataStream<OUT> cachedBroadcastInputs =
+                cacheBroadcastVariables(
+                        env,
+                        broadcastNames,
+                        broadcastInputs,
+                        broadcastInTypes,
+                        resultStream.getParallelism(),
+                        outType);
+
+        boolean canCoLocate =
+                cachedBroadcastInputs.getTransformation() instanceof 
PhysicalTransformation
+                        && resultStream.getTransformation() instanceof 
PhysicalTransformation;
+        if (canCoLocate) {
+            ((PhysicalTransformation) 
cachedBroadcastInputs.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+            ((PhysicalTransformation) resultStream.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+        } else {
+            throw new UnsupportedOperationException(
+                    "cannot set chaining strategy on "
+                            + cachedBroadcastInputs.getTransformation()
+                            + " and "
+                            + resultStream.getTransformation()
+                            + ".");
+        }
+        
cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
+        resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
+
+        return cachedBroadcastInputs.union(resultStream);
+    }
+
+    /**
+     * caches all broadcast iput data streams in static variables and returns 
the result multi-input
+     * stream operator. The result multi-input stream operator emits nothing 
and the only
+     * functionality of this operator is to cache all the input records in 
${@link
+     * BroadcastContext}.
+     *
+     * @param env execution environment.
+     * @param broadcastInputNames names of the broadcast input data streams.
+     * @param broadcastInputs list of the broadcast data streams.
+     * @param broadcastInTypes output types of the broadcast input data 
streams.
+     * @param parallelism parallelism.
+     * @param outType output type.
+     * @param <OUT>
+     * @return the result multi-input stream operator.
+     */
     private static <OUT> DataStream<OUT> cacheBroadcastVariables(
             StreamExecutionEnvironment env,
-            Map<String, DataStream<?>> bcStreams,
+            String[] broadcastInputNames,
+            DataStream<?>[] broadcastInputs,
+            TypeInformation<?>[] broadcastInTypes,
+            int parallelism,
             TypeInformation<OUT> outType) {
-        int numBroadcastInput = bcStreams.size();
-        String[] broadcastInputNames = bcStreams.keySet().toArray(new 
String[0]);
-        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new 
DataStream<?>[0]);
-        TypeInformation<?>[] broadcastInTypes = new 
TypeInformation[numBroadcastInput];
-        for (int i = 0; i < numBroadcastInput; i++) {
-            broadcastInTypes[i] = broadcastInputs[i].getType();
-        }
-
         MultipleInputTransformation<OUT> transformation =
                 new MultipleInputTransformation<OUT>(
                         "broadcastInputs",
                         new 
CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
                         outType,
-                        env.getParallelism());
-        for (DataStream<?> dataStream : bcStreams.values()) {
+                        parallelism);
+        for (DataStream<?> dataStream : broadcastInputs) {
             
transformation.addInput(dataStream.broadcast().getTransformation());
         }
         env.addOperator(transformation);
         return new MultipleConnectedStreams(env).transform(transformation);
     }
 
-    private static String getCoLocationKey(String[] broadcastNames) {
-        StringBuilder sb = new StringBuilder();
-        sb.append("Flink-ML-broadcast-co-location");
-        for (String name : broadcastNames) {
-            sb.append(name);
-        }
-        return sb.toString();
-    }
-
-    private static <OUT> DataStream<OUT> buildGraph(
+    /**
+     * uses {@link DraftExecutionEnvironment} to execute the 
userDefinedFunction and returns the
+     * resultStream.
+     *
+     * @param env execution environment.
+     * @param inputList non-broadcast input list.
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param graphBuilder user-defined logic.
+     * @param <OUT> output type of the result stream.
+     * @return the result stream by applying user-defined logic on the input 
list.
+     */
+    private static <OUT> DataStream<OUT> getResultStream(
             StreamExecutionEnvironment env,
             List<DataStream<?>> inputList,
             String[] broadcastStreamNames,
             Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
         TypeInformation[] inTypes = new TypeInformation[inputList.size()];
         for (int i = 0; i < inputList.size(); i++) {
-            TypeInformation type = inputList.get(i).getType();
-            inTypes[i] = type;
+            inTypes[i] = inputList.get(i).getType();
         }
-        // blocking all non-broadcast input edges by default.
-        boolean[] isBlocking = new boolean[inTypes.length];
-        Arrays.fill(isBlocking, true);
+        // do not block all non-broadcast input edges by default.
+        boolean[] isBlocked = new boolean[inputList.size()];

Review comment:
       Do local variables always need explicit initialization in Java?
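       For reference, a small self-contained check: array elements are zero-initialized in Java even when the array is a local variable, so no explicit `Arrays.fill(isBlocked, false)` is needed; only plain local scalars must be definitely assigned before use.

```java
public class ArrayDefaultCheck {
    public static void main(String[] args) {
        // Elements of a newly created array are zero-initialized (false for boolean),
        // regardless of whether the array reference is a field or a local variable.
        boolean[] isBlocked = new boolean[3];
        for (boolean blocked : isBlocked) {
            System.out.println(blocked); // prints "false" three times
        }
    }
}
```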

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/MultipleInputBroadcastWrapperOperator.java
##########
@@ -32,143 +31,73 @@
 import java.util.ArrayList;
 import java.util.List;
 
-/** Wrapper for WithBroadcastMultipleInputStreamOperator. */
+/** Wrapper for {@link MultipleInputStreamOperator} that implements {@link 
HasBroadcastVariable}. */
 public class MultipleInputBroadcastWrapperOperator<OUT>
         extends AbstractBroadcastWrapperOperator<OUT, 
MultipleInputStreamOperator<OUT>>
         implements MultipleInputStreamOperator<OUT> {
 
-    public MultipleInputBroadcastWrapperOperator(
+    private final List<Input> inputList;
+
+    MultipleInputBroadcastWrapperOperator(
             StreamOperatorParameters<OUT> parameters,
             StreamOperatorFactory<OUT> operatorFactory,
             String[] broadcastStreamNames,
             TypeInformation[] inTypes,
-            boolean[] isBlocking) {
-        super(parameters, operatorFactory, broadcastStreamNames, inTypes, 
isBlocking);
-    }
-
-    @Override
-    public List<Input> getInputs() {
-        List<Input> proxyInputs = new ArrayList<>();
+            boolean[] isBlocked) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, 
isBlocked);
+        inputList = new ArrayList<>();
         for (int i = 0; i < wrappedOperator.getInputs().size(); i++) {
-            proxyInputs.add(new ProxyInput(i));
+            inputList.add(new ProxyInput(i));
         }
-        return proxyInputs;
-    }
-
-    private <IN> void processElement(StreamRecord streamRecord, Input<IN> 
input) throws Exception {
-        input.processElement(streamRecord);
-    }
-
-    private <IN> void processWatermark(Watermark watermark, Input<IN> input) 
throws Exception {
-        input.processWatermark(watermark);
     }
 
-    private <IN> void processLatencyMarker(LatencyMarker latencyMarker, 
Input<IN> input)
-            throws Exception {
-        input.processLatencyMarker(latencyMarker);
-    }
-
-    private <IN> void setKeyContextElement(StreamRecord streamRecord, 
Input<IN> input)
-            throws Exception {
-        input.setKeyContextElement(streamRecord);
-    }
-
-    private <IN> void processWatermarkStatus(WatermarkStatus watermarkStatus, 
Input<IN> input)
-            throws Exception {
-        input.processWatermarkStatus(watermarkStatus);
+    @Override
+    public List<Input> getInputs() {
+        return inputList;
     }
 
     @Override
     public void endInput(int inputId) throws Exception {
-        ((ProxyInput) (getInputs().get(inputId - 1))).endInput();
+        endInputX(inputId - 1, x -> wrappedOperator.getInputs().get(inputId - 
1).processElement(x));

Review comment:
       nit: `wrappedOperator.getInputs().get(inputId - 1)::processElement` ?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
##########
@@ -232,6 +247,78 @@ private OperatorMetricGroup createOperatorMetricGroup(
         }
     }
 
+    /**
+     * extracts common processing logic in subclasses' processing elements.
+     *
+     * @param streamRecord the input record.
+     * @param inputIndex input id, starts from zero.
+     * @param consumer the consumer function.
+     * @throws Exception
+     */
+    protected void processElementX(
+            StreamRecord streamRecord,
+            int inputIndex,
+            ThrowingConsumer<StreamRecord, Exception> consumer)
+            throws Exception {
+        if (!isBlocked[inputIndex]) {
+            if (areBroadcastVariablesReady()) {
+                
dataCacheWriters[inputIndex].finishCurrentSegmentAndStartNewSegment();

Review comment:
       Perhaps we could also extract a method that processes the pending cached records, to eliminate the duplication with `endInputX`?
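       A hedged sketch of what the shared helper could look like (the name `processPendingElements` is hypothetical; the fields and the `DataCacheReader` construction are the ones already used in this class), which both `processElementX` and `endInputX` could then call:

```java
// Sketch only: drains the records cached for the given input and feeds them to the consumer.
private void processPendingElements(
        int inputIndex, ThrowingConsumer<StreamRecord, Exception> consumer) throws Exception {
    dataCacheWriters[inputIndex].finishCurrentSegmentAndStartNewSegment();
    segmentLists[inputIndex].addAll(dataCacheWriters[inputIndex].getNewlyFinishedSegments());
    if (!segmentLists[inputIndex].isEmpty()) {
        DataCacheReader dataCacheReader =
                new DataCacheReader<>(
                        inTypes[inputIndex].createSerializer(containingTask.getExecutionConfig()),
                        fileSystem,
                        segmentLists[inputIndex]);
        while (dataCacheReader.hasNext()) {
            consumer.accept(new StreamRecord(dataCacheReader.next()));
        }
        segmentLists[inputIndex].clear();
    }
}
```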

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/MultipleInputBroadcastWrapperOperator.java
##########
@@ -32,143 +31,73 @@
 import java.util.ArrayList;
 import java.util.List;
 
-/** Wrapper for WithBroadcastMultipleInputStreamOperator. */
+/** Wrapper for {@link MultipleInputStreamOperator} that implements {@link 
HasBroadcastVariable}. */
 public class MultipleInputBroadcastWrapperOperator<OUT>
         extends AbstractBroadcastWrapperOperator<OUT, 
MultipleInputStreamOperator<OUT>>
         implements MultipleInputStreamOperator<OUT> {
 
-    public MultipleInputBroadcastWrapperOperator(
+    private final List<Input> inputList;
+
+    MultipleInputBroadcastWrapperOperator(
             StreamOperatorParameters<OUT> parameters,
             StreamOperatorFactory<OUT> operatorFactory,
             String[] broadcastStreamNames,
             TypeInformation[] inTypes,
-            boolean[] isBlocking) {
-        super(parameters, operatorFactory, broadcastStreamNames, inTypes, 
isBlocking);
-    }
-
-    @Override
-    public List<Input> getInputs() {
-        List<Input> proxyInputs = new ArrayList<>();
+            boolean[] isBlocked) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, 
isBlocked);
+        inputList = new ArrayList<>();
         for (int i = 0; i < wrappedOperator.getInputs().size(); i++) {
-            proxyInputs.add(new ProxyInput(i));
+            inputList.add(new ProxyInput(i));
         }
-        return proxyInputs;
-    }
-
-    private <IN> void processElement(StreamRecord streamRecord, Input<IN> 
input) throws Exception {
-        input.processElement(streamRecord);
-    }
-
-    private <IN> void processWatermark(Watermark watermark, Input<IN> input) 
throws Exception {
-        input.processWatermark(watermark);
     }
 
-    private <IN> void processLatencyMarker(LatencyMarker latencyMarker, 
Input<IN> input)
-            throws Exception {
-        input.processLatencyMarker(latencyMarker);
-    }
-
-    private <IN> void setKeyContextElement(StreamRecord streamRecord, 
Input<IN> input)
-            throws Exception {
-        input.setKeyContextElement(streamRecord);
-    }
-
-    private <IN> void processWatermarkStatus(WatermarkStatus watermarkStatus, 
Input<IN> input)
-            throws Exception {
-        input.processWatermarkStatus(watermarkStatus);
+    @Override
+    public List<Input> getInputs() {
+        return inputList;
     }
 
     @Override
     public void endInput(int inputId) throws Exception {
-        ((ProxyInput) (getInputs().get(inputId - 1))).endInput();
+        endInputX(inputId - 1, x -> wrappedOperator.getInputs().get(inputId - 
1).processElement(x));
+        super.endInput(inputId);
     }
 
     private class ProxyInput<IN> implements Input<IN> {
 
-        private final int inputIdMinusOne;
+        /** input index of this input. */
+        private final int inputIndex;
 
         private final Input<IN> input;
 
-        public ProxyInput(int inputIdMinusOne) {
-            this.inputIdMinusOne = inputIdMinusOne;
-            this.input = wrappedOperator.getInputs().get(inputIdMinusOne);
+        public ProxyInput(int inputIndex) {
+            this.inputIndex = inputIndex;
+            this.input = wrappedOperator.getInputs().get(inputIndex);
         }
 
         @Override
         public void processElement(StreamRecord<IN> streamRecord) throws 
Exception {
-            if (isBlocking[inputIdMinusOne]) {
-                if (areBroadcastVariablesReady()) {
-                    
dataCacheWriters[inputIdMinusOne].finishCurrentSegmentAndStartNewSegment();
-                    segmentLists[inputIdMinusOne].addAll(
-                            
dataCacheWriters[inputIdMinusOne].getNewlyFinishedSegments());
-                    if (segmentLists[inputIdMinusOne].size() != 0) {
-                        DataCacheReader dataCacheReader =
-                                new DataCacheReader<>(
-                                        
inTypes[inputIdMinusOne].createSerializer(
-                                                
containingTask.getExecutionConfig()),
-                                        fileSystem,
-                                        segmentLists[inputIdMinusOne]);
-                        while (dataCacheReader.hasNext()) {
-                            
MultipleInputBroadcastWrapperOperator.this.processElement(
-                                    new StreamRecord(dataCacheReader.next()), 
input);
-                        }
-                    }
-                    segmentLists[inputIdMinusOne].clear();
-                    
MultipleInputBroadcastWrapperOperator.this.processElement(streamRecord, input);
-
-                } else {
-                    
dataCacheWriters[inputIdMinusOne].addRecord(streamRecord.getValue());
-                }
-
-            } else {
-                while (!areBroadcastVariablesReady()) {
-                    mailboxExecutor.yield();
-                }
-                
MultipleInputBroadcastWrapperOperator.this.processElement(streamRecord, input);
-            }
+            MultipleInputBroadcastWrapperOperator.this.processElementX(
+                    streamRecord, inputIndex, x -> input.processElement(x));

Review comment:
       Similarly, this might be simplified to `input::processElement`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -18,106 +18,54 @@
 
 package org.apache.flink.ml.common.broadcast;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.java.tuple.Tuple2;
 
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.concurrent.ConcurrentHashMap;
 
 public class BroadcastContext {
     /**
-     * Store broadcast DataStreams in a Map. The key is (broadcastName, 
partitionId) and the value
-     * is (isBroaddcastVariableReady, cacheList).
+     * stores broadcast data streams in a map. The key is 
broadcastName-partitionId and the value is
+     * (isBroadcastVariableReady, cacheList).
      */
-    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> 
broadcastVariables =
-            new HashMap<>();
-    /**
-     * We use lock because we want to enable `getBroadcastVariable(String)` in 
a TM with multiple
-     * slots here. Note that using ConcurrentHashMap is not enough since we 
need "contains and get
-     * in an atomic operation".
-     */
-    private static ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
+    private static final Map<String, Tuple2<Boolean, List<?>>> 
BROADCAST_VARIABLES =
+            new ConcurrentHashMap<>();
 
-    public static void putBroadcastVariable(
-            Tuple2<String, Integer> key, Tuple2<Boolean, List<?>> variable) {
-        lock.writeLock().lock();
-        try {
-            broadcastVariables.put(key, variable);
-        } finally {
-            lock.writeLock().unlock();
-        }
+    @VisibleForTesting
+    public static void putBroadcastVariable(String key, Tuple2<Boolean, 
List<?>> variable) {
+        BROADCAST_VARIABLES.put(key, variable);
     }
 
     /**
-     * get the cached list with the given key.
-     *
-     * @param key
-     * @param <T>
-     * @return the cache list.
-     */
-    public static <T> List<T> getBroadcastVariable(Tuple2<String, Integer> 
key) {
-        lock.readLock().lock();
-        List<?> result = null;
-        try {
-            result = broadcastVariables.get(key).f1;
-        } finally {
-            lock.readLock().unlock();
-        }
-        return (List<T>) result;
-    }
-
-    /**
-     * get broadcast variables by name
+     * gets broadcast variables by name if this broadcast variable is fully 
cached.
      *
      * @param name
      * @param <T>
-     * @return
+     * @return the cache broadcast variable. Return null if it is not fully 
cached.
      */
+    @VisibleForTesting
     public static <T> List<T> getBroadcastVariable(String name) {
-        lock.readLock().lock();
-        List<?> result = null;
-        try {
-            for (Tuple2<String, Integer> nameAndPartitionId : 
broadcastVariables.keySet()) {
-                if (name.equals(nameAndPartitionId.f0) && 
isCacheFinished(nameAndPartitionId)) {
-                    result = broadcastVariables.get(nameAndPartitionId).f1;
-                    break;
-                }
-            }
-        } finally {
-            lock.readLock().unlock();
+        Tuple2<Boolean, List<?>> cacheReadyAndList = 
BROADCAST_VARIABLES.get(name);
+        if (cacheReadyAndList.f0) {
+            return (List<T>) cacheReadyAndList.f1;
         }
-        return (List<T>) result;
+        return null;
     }
 
-    public static void remove(Tuple2<String, Integer> key) {
-        lock.writeLock().lock();
-        try {
-            broadcastVariables.remove(key);
-        } finally {
-            lock.writeLock().unlock();
-        }
+    @VisibleForTesting
+    public static void remove(String key) {
+        BROADCAST_VARIABLES.remove(key);
     }
 
-    public static void markCacheFinished(Tuple2<String, Integer> key) {
-        lock.writeLock().lock();
-        try {
-            broadcastVariables.get(key).f0 = true;
-        } finally {
-            lock.writeLock().unlock();
-        }
+    @VisibleForTesting
+    public static void markCacheFinished(String key) {

Review comment:
       Should we explicitly notify the wrapper operator by emitting a mail here? Otherwise the wrapper operator may stall in `endInputX`.
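       A hedged sketch of one possible shape (the notifier registry is hypothetical and not part of the current `BroadcastContext`; the wrapper operator would register a callback, e.g. one that enqueues a no-op mail on its `mailboxExecutor`, before it starts yielding):

```java
// Sketch only: lets markCacheFinished wake up a wrapper operator that is blocked
// in mailboxExecutor.yield() inside endInputX / processElementX.
private static final Map<String, Runnable> CACHE_FINISHED_NOTIFIERS = new ConcurrentHashMap<>();

public static void registerCacheFinishedNotifier(String key, Runnable notifier) {
    CACHE_FINISHED_NOTIFIERS.put(key, notifier);
}

@VisibleForTesting
public static void markCacheFinished(String key) {
    BROADCAST_VARIABLES.get(key).f0 = true;
    Runnable notifier = CACHE_FINISHED_NOTIFIERS.remove(key);
    if (notifier != null) {
        // e.g. a callback like () -> mailboxExecutor.execute(() -> {}, "broadcast cache finished")
        notifier.run();
    }
}
```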

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -22,121 +22,168 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
 import 
org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
 import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import 
org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
 
+/** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
+    /**
+     * supports withBroadcastStream in DataStream API. Broadcast data streams 
are available at all
+     * parallel instances of an operator that implements ${@link 
HasBroadcastVariable}. An operator
+     * that wants to access broadcast variables must implement ${@link 
HasBroadcastVariable}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first 
and further set by
+     * {@code HasBroadcastVariable.setBroadcastVariable(...)}. For now the 
non-broadcast input are
+     * cached by default to avoid the possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is 
the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can 
access the broadcast
+     *     data streams and produce the output data stream. Note that users 
can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(

Review comment:
       `@SuppressWarnings({"rawtypes", "unchecked"})`

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] 
inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, 
boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);
+        this.broadcastStreamNames = broadcastStreamNames;

Review comment:
       This applies to the following assignments, not to the first line.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -22,121 +22,168 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
 import 
org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
 import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import 
org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
 
+/** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
+    /**
+     * supports withBroadcastStream in DataStream API. Broadcast data streams 
are available at all
+     * parallel instances of an operator that implements ${@link 
HasBroadcastVariable}. An operator
+     * that wants to access broadcast variables must implement ${@link 
HasBroadcastVariable}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first 
and further set by
+     * {@code HasBroadcastVariable.setBroadcastVariable(...)}. For now the 
non-broadcast input are
+     * cached by default to avoid the possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is 
the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can 
access the broadcast
+     *     data streams and produce the output data stream. Note that users 
can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> 
userDefinedFunction) {
+        Preconditions.checkArgument(inputList.size() > 0);
+
+        StreamExecutionEnvironment env = 
inputList.get(0).getExecutionEnvironment();
+        String[] broadcastNames = new String[bcStreams.size()];
+        DataStream<?>[] broadcastInputs = new DataStream[bcStreams.size()];
+        TypeInformation<?>[] broadcastInTypes = new 
TypeInformation[bcStreams.size()];
+        int idx = 0;
+        final String broadcastId = new AbstractID().toHexString();
+        for (String name : bcStreams.keySet()) {
+            broadcastNames[idx] = broadcastId + "-" + name;
+            broadcastInputs[idx] = bcStreams.get(name);
+            broadcastInTypes[idx] = broadcastInputs[idx].getType();
+            idx++;
+        }
 
+        DataStream<OUT> resultStream =
+                getResultStream(env, inputList, broadcastNames, 
userDefinedFunction);
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = "broadcast-co-location-" + 
UUID.randomUUID();
+        DataStream<OUT> cachedBroadcastInputs =
+                cacheBroadcastVariables(
+                        env,
+                        broadcastNames,
+                        broadcastInputs,
+                        broadcastInTypes,
+                        resultStream.getParallelism(),
+                        outType);
+
+        boolean canCoLocate =
+                cachedBroadcastInputs.getTransformation() instanceof 
PhysicalTransformation
+                        && resultStream.getTransformation() instanceof 
PhysicalTransformation;
+        if (canCoLocate) {
+            ((PhysicalTransformation) 
cachedBroadcastInputs.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+            ((PhysicalTransformation) resultStream.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+        } else {
+            throw new UnsupportedOperationException(
+                    "cannot set chaining strategy on "
+                            + cachedBroadcastInputs.getTransformation()
+                            + " and "
+                            + resultStream.getTransformation()
+                            + ".");
+        }
+        
cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
+        resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
+
+        return cachedBroadcastInputs.union(resultStream);
+    }
+
+    /**
+     * caches all broadcast iput data streams in static variables and returns 
the result multi-input
+     * stream operator. The result multi-input stream operator emits nothing 
and the only
+     * functionality of this operator is to cache all the input records in 
${@link
+     * BroadcastContext}.
+     *
+     * @param env execution environment.
+     * @param broadcastInputNames names of the broadcast input data streams.
+     * @param broadcastInputs list of the broadcast data streams.
+     * @param broadcastInTypes output types of the broadcast input data 
streams.
+     * @param parallelism parallelism.
+     * @param outType output type.
+     * @param <OUT>
+     * @return the result multi-input stream operator.
+     */
     private static <OUT> DataStream<OUT> cacheBroadcastVariables(
             StreamExecutionEnvironment env,
-            Map<String, DataStream<?>> bcStreams,
+            String[] broadcastInputNames,
+            DataStream<?>[] broadcastInputs,
+            TypeInformation<?>[] broadcastInTypes,
+            int parallelism,
             TypeInformation<OUT> outType) {
-        int numBroadcastInput = bcStreams.size();
-        String[] broadcastInputNames = bcStreams.keySet().toArray(new 
String[0]);
-        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new 
DataStream<?>[0]);
-        TypeInformation<?>[] broadcastInTypes = new 
TypeInformation[numBroadcastInput];
-        for (int i = 0; i < numBroadcastInput; i++) {
-            broadcastInTypes[i] = broadcastInputs[i].getType();
-        }
-
         MultipleInputTransformation<OUT> transformation =
                 new MultipleInputTransformation<OUT>(

Review comment:
       `<OUT>` -> `<>`

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -22,121 +22,168 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
 import 
org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
 import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import 
org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
 
+/** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
+    /**
+     * supports withBroadcastStream in DataStream API. Broadcast data streams 
are available at all
+     * parallel instances of an operator that implements ${@link 
HasBroadcastVariable}. An operator
+     * that wants to access broadcast variables must implement ${@link 
HasBroadcastVariable}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first 
and further set by
+     * {@code HasBroadcastVariable.setBroadcastVariable(...)}. For now the 
non-broadcast input are
+     * cached by default to avoid the possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is 
the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can 
access the broadcast
+     *     data streams and produce the output data stream. Note that users 
can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> 
userDefinedFunction) {
+        Preconditions.checkArgument(inputList.size() > 0);
+
+        StreamExecutionEnvironment env = 
inputList.get(0).getExecutionEnvironment();
+        String[] broadcastNames = new String[bcStreams.size()];
+        DataStream<?>[] broadcastInputs = new DataStream[bcStreams.size()];
+        TypeInformation<?>[] broadcastInTypes = new 
TypeInformation[bcStreams.size()];
+        int idx = 0;
+        final String broadcastId = new AbstractID().toHexString();
+        for (String name : bcStreams.keySet()) {
+            broadcastNames[idx] = broadcastId + "-" + name;
+            broadcastInputs[idx] = bcStreams.get(name);
+            broadcastInTypes[idx] = broadcastInputs[idx].getType();
+            idx++;
+        }
 
+        DataStream<OUT> resultStream =
+                getResultStream(env, inputList, broadcastNames, 
userDefinedFunction);
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = "broadcast-co-location-" + 
UUID.randomUUID();
+        DataStream<OUT> cachedBroadcastInputs =
+                cacheBroadcastVariables(
+                        env,
+                        broadcastNames,
+                        broadcastInputs,
+                        broadcastInTypes,
+                        resultStream.getParallelism(),
+                        outType);
+
+        boolean canCoLocate =
+                cachedBroadcastInputs.getTransformation() instanceof 
PhysicalTransformation
+                        && resultStream.getTransformation() instanceof 
PhysicalTransformation;
+        if (canCoLocate) {
+            ((PhysicalTransformation) 
cachedBroadcastInputs.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+            ((PhysicalTransformation) resultStream.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+        } else {
+            throw new UnsupportedOperationException(
+                    "cannot set chaining strategy on "
+                            + cachedBroadcastInputs.getTransformation()
+                            + " and "
+                            + resultStream.getTransformation()
+                            + ".");
+        }
+        
cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
+        resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
+
+        return cachedBroadcastInputs.union(resultStream);
+    }
+
+    /**
+     * caches all broadcast iput data streams in static variables and returns 
the result multi-input
+     * stream operator. The result multi-input stream operator emits nothing 
and the only
+     * functionality of this operator is to cache all the input records in 
${@link
+     * BroadcastContext}.
+     *
+     * @param env execution environment.
+     * @param broadcastInputNames names of the broadcast input data streams.
+     * @param broadcastInputs list of the broadcast data streams.
+     * @param broadcastInTypes output types of the broadcast input data 
streams.
+     * @param parallelism parallelism.
+     * @param outType output type.
+     * @param <OUT>
+     * @return the result multi-input stream operator.
+     */
     private static <OUT> DataStream<OUT> cacheBroadcastVariables(
             StreamExecutionEnvironment env,
-            Map<String, DataStream<?>> bcStreams,
+            String[] broadcastInputNames,
+            DataStream<?>[] broadcastInputs,
+            TypeInformation<?>[] broadcastInTypes,
+            int parallelism,
             TypeInformation<OUT> outType) {
-        int numBroadcastInput = bcStreams.size();
-        String[] broadcastInputNames = bcStreams.keySet().toArray(new 
String[0]);
-        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new 
DataStream<?>[0]);
-        TypeInformation<?>[] broadcastInTypes = new 
TypeInformation[numBroadcastInput];
-        for (int i = 0; i < numBroadcastInput; i++) {
-            broadcastInTypes[i] = broadcastInputs[i].getType();
-        }
-
         MultipleInputTransformation<OUT> transformation =
                 new MultipleInputTransformation<OUT>(
                         "broadcastInputs",
                         new 
CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
                         outType,
-                        env.getParallelism());
-        for (DataStream<?> dataStream : bcStreams.values()) {
+                        parallelism);
+        for (DataStream<?> dataStream : broadcastInputs) {
             
transformation.addInput(dataStream.broadcast().getTransformation());
         }
         env.addOperator(transformation);
         return new MultipleConnectedStreams(env).transform(transformation);
     }
 
-    private static String getCoLocationKey(String[] broadcastNames) {
-        StringBuilder sb = new StringBuilder();
-        sb.append("Flink-ML-broadcast-co-location");
-        for (String name : broadcastNames) {
-            sb.append(name);
-        }
-        return sb.toString();
-    }
-
-    private static <OUT> DataStream<OUT> buildGraph(
+    /**
+     * uses {@link DraftExecutionEnvironment} to execute the 
userDefinedFunction and returns the
+     * resultStream.
+     *
+     * @param env execution environment.
+     * @param inputList non-broadcast input list.
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param graphBuilder user-defined logic.
+     * @param <OUT> output type of the result stream.
+     * @return the result stream by applying user-defined logic on the input 
list.
+     */
+    private static <OUT> DataStream<OUT> getResultStream(
             StreamExecutionEnvironment env,
             List<DataStream<?>> inputList,
             String[] broadcastStreamNames,
             Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
         TypeInformation[] inTypes = new TypeInformation[inputList.size()];
         for (int i = 0; i < inputList.size(); i++) {
-            TypeInformation type = inputList.get(i).getType();
-            inTypes[i] = type;
+            inTypes[i] = inputList.get(i).getType();
         }
-        // blocking all non-broadcast input edges by default.
-        boolean[] isBlocking = new boolean[inTypes.length];
-        Arrays.fill(isBlocking, true);
+        // do not block all non-broadcast input edges by default.
+        boolean[] isBlocked = new boolean[inputList.size()];
         DraftExecutionEnvironment draftEnv =
                 new DraftExecutionEnvironment(
-                        env, new BroadcastWrapper<>(broadcastStreamNames, 
inTypes, isBlocking));
+                        env, new BroadcastWrapper<>(broadcastStreamNames, 
inTypes, isBlocked));
 
         List<DataStream<?>> draftSources = new ArrayList<>();
         for (int i = 0; i < inputList.size(); i++) {
             draftSources.add(draftEnv.addDraftSource(inputList.get(i), 
inputList.get(i).getType()));
         }
         DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
-
+        Preconditions.checkState(
+                draftEnv.getStreamGraph(false).getStreamNodes().size() == 1 + 
inputList.size(),

Review comment:
       Are you sure this would work? When calling `getStreamGraph`, the list of `transformations` would be cleared from the env.



