guoweiM commented on a change in pull request #13:
URL: https://github.com/apache/flink-ml/pull/13#discussion_r735348117



##########
File path: 
flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationFactory.java
##########
@@ -109,11 +111,78 @@ public static DataStreamList createIteration(
                 mayHaveCriteria || 
iterationBodyResult.getTerminationCriteria() == null,
                 "The current iteration type does not support the termination 
criteria.");
 
-        // TODO: will consider the termination criteria in the next.
+        if (iterationBodyResult.getTerminationCriteria() != null) {
+            addCriteriaStream(
+                    iterationBodyResult.getTerminationCriteria(),
+                    iterationId,
+                    env,
+                    draftEnv,
+                    initVariableStreams,
+                    headStreams,
+                    totalInitVariableParallelism);
+        }
 
         return 
addOutputs(getActualDataStreams(iterationBodyResult.getOutputStreams(), 
draftEnv));
     }
 
+    private static void addCriteriaStream(
+            DataStream<?> draftCriteriaStream,
+            IterationID iterationId,
+            StreamExecutionEnvironment env,
+            DraftExecutionEnvironment draftEnv,
+            DataStreamList initVariableStreams,
+            DataStreamList headStreams,
+            int totalInitVariableParallelism) {
+        // deal with the criteria streams
+        DataStream<?> terminationCriteria = 
draftEnv.getActualStream(draftCriteriaStream.getId());
+        // It should always has the IterationRecordTypeInfo
+        checkState(
+                
terminationCriteria.getType().getClass().equals(IterationRecordTypeInfo.class),
+                "The termination criteria should always returns 
IterationRecord.");
+        TypeInformation<?> innerType =
+                ((IterationRecordTypeInfo<?>) 
terminationCriteria.getType()).getInnerTypeInfo();
+
+        DataStream<?> emptyCriteriaSource =
+                env.addSource(new DraftExecutionEnvironment.EmptySource())
+                        .returns(innerType)
+                        
.name(terminationCriteria.getTransformation().getName())
+                        .setParallelism(terminationCriteria.getParallelism());
+        DataStreamList criteriaSources = 
DataStreamList.of(emptyCriteriaSource);
+        DataStreamList criteriaInputs = addInputs(criteriaSources, false);
+        DataStreamList criteriaHeaders =
+                addHeads(
+                        criteriaSources,
+                        criteriaInputs,
+                        iterationId,
+                        totalInitVariableParallelism,
+                        true,
+                        initVariableStreams.size());
+        DataStreamList criteriaTails =
+                addTails(
+                        DataStreamList.of(terminationCriteria),
+                        iterationId,
+                        initVariableStreams.size());
+
+        String coLocationGroupKey = "co-" + iterationId.toHexString() + "-cri";
+        
criteriaHeaders.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
+        
criteriaTails.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey);
+
+        // Since co-located task must be in the same region, we will have to 
add a fake op.
+        ((SingleOutputStreamOperator<?>) criteriaHeaders.get(0))
+                .getSideOutput(new OutputTag<IterationRecord<Integer>>("fake") 
{})

Review comment:
       Maybe we could extract the "fake" tag name into a named `static final` constant instead of using a string literal here?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to