[FLINK-3888] allow registering a custom convergence criterion in delta 
iterations

- cleanups in iterations and aggregators code
- add delta convergence criterion in the CollectionExecutor
- add ITCases for delta custom convergence

This closes #2606


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8085aa98
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8085aa98
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8085aa98

Branch: refs/heads/master
Commit: 8085aa98333a052553f1155f65b5cc2728eb5ff8
Parents: 3ab97ae
Author: vasia <[email protected]>
Authored: Wed Oct 5 13:49:20 2016 +0200
Committer: vasia <[email protected]>
Committed: Fri Oct 21 12:33:50 2016 +0200

----------------------------------------------------------------------
 .../common/aggregators/AggregatorRegistry.java  |  10 +-
 .../common/operators/CollectionExecutor.java    |  11 ++
 .../api/java/operators/DeltaIteration.java      |  32 +++-
 .../plantranslate/JobGraphGenerator.java        |  13 +-
 .../task/IterationSynchronizationSinkTask.java  |  43 +++--
 .../runtime/operators/util/TaskConfig.java      |  54 +++++-
 .../AggregatorConvergenceITCase.java            | 182 +++++++++----------
 .../aggregators/AggregatorsITCase.java          |  82 ++++++---
 8 files changed, 272 insertions(+), 155 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java
----------------------------------------------------------------------
diff --git 
a/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java
 
b/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java
index 1d5c358..19663d1 100644
--- 
a/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java
+++ 
b/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java
@@ -49,18 +49,14 @@ public class AggregatorRegistry {
                }
                this.registry.put(name, aggregator);
        }
-       
-       public Aggregator<?> unregisterAggregator(String name) {
-               return this.registry.remove(name);
-       }
-       
+
        public Collection<AggregatorWithName<?>> getAllRegisteredAggregators() {
                ArrayList<AggregatorWithName<?>> list = new 
ArrayList<AggregatorWithName<?>>(this.registry.size());
                
                for (Map.Entry<String, Aggregator<?>> entry : 
this.registry.entrySet()) {
                        @SuppressWarnings("unchecked")
                        Aggregator<Value> valAgg = (Aggregator<Value>) 
entry.getValue();
-                       list.add(new AggregatorWithName<Value>(entry.getKey(), 
valAgg));
+                       list.add(new AggregatorWithName<>(entry.getKey(), 
valAgg));
                }
                return list;
        }
@@ -72,7 +68,7 @@ public class AggregatorRegistry {
                        throw new IllegalArgumentException("Name, aggregator, 
or convergence criterion must not be null");
                }
                
-               Aggregator<?> genAgg = (Aggregator<?>) aggregator;
+               Aggregator<?> genAgg = aggregator;
                
                Aggregator<?> previous = this.registry.get(name);
                if (previous != null && previous != genAgg) {

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
----------------------------------------------------------------------
diff --git 
a/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
 
b/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
index d9240fe..a6fc17e 100644
--- 
a/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
+++ 
b/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
@@ -412,6 +412,9 @@ public class CollectionExecutor {
                        aggregators.put(a.getName(), a.getAggregator());
                }
 
+               String convCriterionAggName = 
iteration.getAggregators().getConvergenceCriterionAggregatorName();
+               ConvergenceCriterion<Value> convCriterion = 
(ConvergenceCriterion<Value>) 
iteration.getAggregators().getConvergenceCriterion();
+
                final int maxIterations = 
iteration.getMaximumNumberOfIterations();
 
                for (int superstep = 1; superstep <= maxIterations; 
superstep++) {
@@ -442,6 +445,14 @@ public class CollectionExecutor {
                                break;
                        }
 
+                       // evaluate the aggregator convergence criterion
+                       if (convCriterion != null && convCriterionAggName != 
null) {
+                               Value v = 
aggregators.get(convCriterionAggName).getAggregate();
+                               if (convCriterion.isConverged(superstep, v)) {
+                                       break;
+                               }
+                       }
+
                        // clear the dynamic results
                        for (Operator<?> o : dynamics) {
                                intermediateResults.remove(o);

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java
----------------------------------------------------------------------
diff --git 
a/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java
 
b/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java
index d53b499..b97a9de 100644
--- 
a/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java
+++ 
b/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java
@@ -26,10 +26,12 @@ import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.InvalidProgramException;
 import org.apache.flink.api.common.aggregators.Aggregator;
 import org.apache.flink.api.common.aggregators.AggregatorRegistry;
+import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
 import org.apache.flink.api.common.operators.Keys;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.types.Value;
 import org.apache.flink.util.Preconditions;
 
 /**
@@ -62,13 +64,13 @@ public class DeltaIteration<ST, WT> {
        private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT;
        
        private boolean solutionSetUnManaged;
-       
-       
+
+
        public DeltaIteration(ExecutionEnvironment context, TypeInformation<ST> 
type, DataSet<ST> solutionSet, DataSet<WT> workset, Keys<ST> keys, int 
maxIterations) {
                initialSolutionSet = solutionSet;
                initialWorkset = workset;
-               solutionSetPlaceholder = new 
SolutionSetPlaceHolder<ST>(context, solutionSet.getType(), this);
-               worksetPlaceholder = new WorksetPlaceHolder<WT>(context, 
workset.getType());
+               solutionSetPlaceholder = new SolutionSetPlaceHolder<>(context, 
solutionSet.getType(), this);
+               worksetPlaceholder = new WorksetPlaceHolder<>(context, 
workset.getType());
                this.keys = keys;
                this.maxIterations = maxIterations;
        }
@@ -210,6 +212,28 @@ public class DeltaIteration<ST, WT> {
                this.aggregators.registerAggregator(name, aggregator);
                return this;
        }
+
+       /**
+        * Registers an {@link Aggregator} for the iteration together with a {@link ConvergenceCriterion}. For a general description
+        * of aggregators, see {@link #registerAggregator(String, Aggregator)} and {@link Aggregator}.
+        * At the end of each superstep, the convergence criterion takes the aggregator's global aggregate value and decides whether
+        * the iteration should terminate. A typical use case is to have an aggregator that sums up the total error of change
+        * in an iteration step and to have a convergence criterion that signals termination as soon as the aggregate value
+        * is below a certain threshold. The iteration still also terminates when the workset becomes empty.
+        *
+        * @param name The name under which the aggregator is registered.
+        * @param aggregator The aggregator instance whose global aggregate is passed to the convergence criterion.
+        * @param convergenceCheck The convergence criterion.
+        *
+        * @return The DeltaIteration itself, to allow chaining function calls.
+        */
+       @PublicEvolving
+       public <X extends Value> DeltaIteration<ST, WT> registerAggregationConvergenceCriterion(
+                       String name, Aggregator<X> aggregator, ConvergenceCriterion<X> convergenceCheck)
+       {
+               this.aggregators.registerAggregationConvergenceCriterion(name, aggregator, convergenceCheck);
+               return this;
+       }
        
        /**
         * Gets the registry for aggregators for the iteration.

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
----------------------------------------------------------------------
diff --git 
a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
 
b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
index 5ab1fbf..4ccfae3 100644
--- 
a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
+++ 
b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
@@ -1513,14 +1513,21 @@ public class JobGraphGenerator implements 
Visitor<PlanNode> {
                
                String convAggName = 
aggs.getConvergenceCriterionAggregatorName();
                ConvergenceCriterion<?> convCriterion = 
aggs.getConvergenceCriterion();
-               
+
                if (convCriterion != null || convAggName != null) {
-                       throw new CompilerException("Error: Cannot use custom 
convergence criterion with workset iteration. Workset iterations have implicit 
convergence criterion where workset is empty.");
+                       if (convCriterion == null) {
+                               throw new CompilerException("Error: Convergence 
criterion aggregator set, but criterion is null.");
+                       }
+                       if (convAggName == null) {
+                               throw new CompilerException("Error: Aggregator 
convergence criterion set, but aggregator is null.");
+                       }
+
+                       syncConfig.setConvergenceCriterion(convAggName, 
convCriterion);
                }
                
                
headConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME,
 new LongSumAggregator());
                
syncConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME,
 new LongSumAggregator());
-               
syncConfig.setConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME,
 new WorksetEmptyConvergenceCriterion());
+               
syncConfig.setImplicitConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME,
 new WorksetEmptyConvergenceCriterion());
        }
        
        private String getDescriptionForUserCode(UserCodeWrapper<?> wrapper) {

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
index 66fb45b..11a8cfa 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
@@ -56,11 +56,15 @@ public class IterationSynchronizationSinkTask extends 
AbstractInvokable implemen
        private SyncEventHandler eventHandler;
 
        private ConvergenceCriterion<Value> convergenceCriterion;
+
+       private ConvergenceCriterion<Value> implicitConvergenceCriterion;
        
        private Map<String, Aggregator<?>> aggregators;
 
        private String convergenceAggregatorName;
 
+       private String implicitConvergenceAggregatorName;
+
        private int currentIteration = 1;
        
        private int maxNumberOfIterations;
@@ -71,14 +75,14 @@ public class IterationSynchronizationSinkTask extends 
AbstractInvokable implemen
        
        @Override
        public void invoke() throws Exception {
-               this.headEventReader = new MutableRecordReader<IntValue>(
+               this.headEventReader = new MutableRecordReader<>(
                                getEnvironment().getInputGate(0),
                                
getEnvironment().getTaskManagerInfo().getTmpDirectories());
 
                TaskConfig taskConfig = new TaskConfig(getTaskConfiguration());
                
                // store all aggregators
-               this.aggregators = new HashMap<String, Aggregator<?>>();
+               this.aggregators = new HashMap<>();
                for (AggregatorWithName<?> aggWithName : 
taskConfig.getIterationAggregators(getUserCodeClassLoader())) {
                        aggregators.put(aggWithName.getName(), 
aggWithName.getAggregator());
                }
@@ -89,6 +93,13 @@ public class IterationSynchronizationSinkTask extends 
AbstractInvokable implemen
                        convergenceAggregatorName = 
taskConfig.getConvergenceCriterionAggregatorName();
                        Preconditions.checkNotNull(convergenceAggregatorName);
                }
+
+               // store the default aggregator convergence criterion
+               if (taskConfig.usesImplicitConvergenceCriterion()) {
+                       implicitConvergenceCriterion = 
taskConfig.getImplicitConvergenceCriterion(getUserCodeClassLoader());
+                       implicitConvergenceAggregatorName = 
taskConfig.getImplicitConvergenceCriterionAggregatorName();
+                       
Preconditions.checkNotNull(implicitConvergenceAggregatorName);
+               }
                
                maxNumberOfIterations = taskConfig.getNumberOfIterations();
                
@@ -102,7 +113,6 @@ public class IterationSynchronizationSinkTask extends 
AbstractInvokable implemen
                
                while (!terminationRequested()) {
 
-//                     notifyMonitor(IterationMonitoring.Event.SYNC_STARTING, 
currentIteration);
                        if (log.isInfoEnabled()) {
                                log.info(formatLogString("starting iteration [" 
+ currentIteration + "]"));
                        }
@@ -122,7 +132,6 @@ public class IterationSynchronizationSinkTask extends 
AbstractInvokable implemen
 
                                requestTermination();
                                sendToAllWorkers(new TerminationEvent());
-//                             
notifyMonitor(IterationMonitoring.Event.SYNC_FINISHED, currentIteration);
                        } else {
                                if (log.isInfoEnabled()) {
                                        log.info(formatLogString("signaling 
that all workers are done in iteration [" + currentIteration
@@ -136,19 +145,11 @@ public class IterationSynchronizationSinkTask extends 
AbstractInvokable implemen
                                for (Aggregator<?> agg : aggregators.values()) {
                                        agg.reset();
                                }
-                               
-//                             
notifyMonitor(IterationMonitoring.Event.SYNC_FINISHED, currentIteration);
                                currentIteration++;
                        }
                }
        }
 
-//     protected void notifyMonitor(IterationMonitoring.Event event, int 
currentIteration) {
-//             if (log.isInfoEnabled()) {
-//                     
log.info(IterationMonitoring.logLine(getEnvironment().getJobID(), event, 
currentIteration, 1));
-//             }
-//     }
-
        private boolean checkForConvergence() {
                if (maxNumberOfIterations == currentIteration) {
                        if (log.isInfoEnabled()) {
@@ -175,6 +176,24 @@ public class IterationSynchronizationSinkTask extends 
AbstractInvokable implemen
                                return true;
                        }
                }
+
+               if (implicitConvergenceAggregatorName != null) {
+                       @SuppressWarnings("unchecked")
+                       Aggregator<Value> aggregator = (Aggregator<Value>) 
aggregators.get(implicitConvergenceAggregatorName);
+                       if (aggregator == null) {
+                               throw new RuntimeException("Error: Aggregator 
for default convergence criterion was null.");
+                       }
+
+                       Value aggregate = aggregator.getAggregate();
+
+                       if 
(implicitConvergenceCriterion.isConverged(currentIteration, aggregate)) {
+                               if (log.isInfoEnabled()) {
+                                       log.info(formatLogString("empty workset 
convergence reached after [" + currentIteration
+                                                       + "] iterations, 
terminating..."));
+                               }
+                               return true;
+                       }
+               }
                
                return false;
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
index b598523..71c0405 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
@@ -37,6 +37,7 @@ import 
org.apache.flink.api.common.operators.util.UserCodeWrapper;
 import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
 import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
 import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
+import org.apache.flink.api.java.operators.DeltaIteration;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.DelegatingConfiguration;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
@@ -197,6 +198,10 @@ public class TaskConfig implements Serializable {
        private static final String ITERATION_CONVERGENCE_CRITERION = 
"iterative.terminationCriterion";
        
        private static final String ITERATION_CONVERGENCE_CRITERION_AGG_NAME = 
"iterative.terminationCriterion.agg.name";
+
+       private static final String ITERATION_IMPLICIT_CONVERGENCE_CRITERION = 
"iterative.implicit.terminationCriterion";
+
+       private static final String 
ITERATION_IMPLICIT_CONVERGENCE_CRITERION_AGG_NAME = 
"iterative.implicit.terminationCriterion.agg.name";
        
        private static final String ITERATION_NUM_AGGREGATORS = 
"iterative.num-aggs";
        
@@ -992,16 +997,31 @@ public class TaskConfig implements Serializable {
                this.config.setString(ITERATION_CONVERGENCE_CRITERION_AGG_NAME, 
aggregatorName);
        }
 
+       /**
+        * Sets the default convergence criterion of a {@link DeltaIteration}
+        *
+        * @param aggregatorName
+        * @param convCriterion
+        */
+       public void setImplicitConvergenceCriterion(String aggregatorName, 
ConvergenceCriterion<?> convCriterion) {
+               try {
+                       InstantiationUtil.writeObjectToConfig(convCriterion, 
this.config, ITERATION_IMPLICIT_CONVERGENCE_CRITERION);
+               } catch (IOException e) {
+                       throw new RuntimeException("Error while writing the 
implicit convergence criterion object to the task configuration.");
+               }
+               
this.config.setString(ITERATION_IMPLICIT_CONVERGENCE_CRITERION_AGG_NAME, 
aggregatorName);
+       }
+
        @SuppressWarnings("unchecked")
        public <T extends Value> ConvergenceCriterion<T> 
getConvergenceCriterion(ClassLoader cl) {
-               ConvergenceCriterion<T> convCriterionObj = null;
+               ConvergenceCriterion<T> convCriterionObj;
                try {
-                       convCriterionObj = (ConvergenceCriterion<T>) 
InstantiationUtil.readObjectFromConfig(
+                       convCriterionObj = 
InstantiationUtil.readObjectFromConfig(
                        this.config, ITERATION_CONVERGENCE_CRITERION, cl);
                } catch (IOException e) {
-                       throw new RuntimeException("Error while reading the 
covergence criterion object from the task configuration.");
+                       throw new RuntimeException("Error while reading the 
convergence criterion object from the task configuration.");
                } catch (ClassNotFoundException e) {
-                       throw new RuntimeException("Error while reading the 
covergence criterion object from the task configuration. " +
+                       throw new RuntimeException("Error while reading the 
convergence criterion object from the task configuration. " +
                                        "ConvergenceCriterion class not 
found.");
                }
                if (convCriterionObj == null) {
@@ -1017,6 +1037,32 @@ public class TaskConfig implements Serializable {
        public String getConvergenceCriterionAggregatorName() {
                return 
this.config.getString(ITERATION_CONVERGENCE_CRITERION_AGG_NAME, null);
        }
+
+       @SuppressWarnings("unchecked")
+       public <T extends Value> ConvergenceCriterion<T> 
getImplicitConvergenceCriterion(ClassLoader cl) {
+               ConvergenceCriterion<T> convCriterionObj;
+               try {
+                       convCriterionObj = 
InstantiationUtil.readObjectFromConfig(
+                                       this.config, 
ITERATION_IMPLICIT_CONVERGENCE_CRITERION, cl);
+               } catch (IOException e) {
+                       throw new RuntimeException("Error while reading the 
default convergence criterion object from the task configuration.");
+               } catch (ClassNotFoundException e) {
+                       throw new RuntimeException("Error while reading the 
default convergence criterion object from the task configuration. " +
+                                       "ConvergenceCriterion class not 
found.");
+               }
+               if (convCriterionObj == null) {
+                       throw new NullPointerException();
+               }
+               return convCriterionObj;
+       }
+
+       public boolean usesImplicitConvergenceCriterion() {
+               return 
config.getBytes(ITERATION_IMPLICIT_CONVERGENCE_CRITERION, null) != null;
+       }
+
+       public String getImplicitConvergenceCriterionAggregatorName() {
+               return 
this.config.getString(ITERATION_IMPLICIT_CONVERGENCE_CRITERION_AGG_NAME, null);
+       }
        
        public void setIsSolutionSetUpdate() {
                this.config.setBoolean(ITERATION_SOLUTION_SET_UPDATE, true);

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java
index 941b31b..7bade80 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java
@@ -26,6 +26,7 @@ import 
org.apache.flink.api.common.aggregators.ConvergenceCriterion;
 import org.apache.flink.api.common.aggregators.LongSumAggregator;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.functions.RichJoinFunction;
+import org.apache.flink.api.java.operators.DeltaIteration;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.test.util.JavaProgramTestBase;
@@ -52,47 +53,59 @@ public class AggregatorConvergenceITCase extends 
MultipleProgramsTestBase {
        public AggregatorConvergenceITCase(TestExecutionMode mode) {
                super(mode);
        }
-       
+
+       final List<Tuple2<Long, Long>> verticesInput = Arrays.asList( // (vertex id, component id); each vertex starts in its own component
+                       new Tuple2<>(1L,1L),
+                       new Tuple2<>(2L,2L),
+                       new Tuple2<>(3L,3L),
+                       new Tuple2<>(4L,4L),
+                       new Tuple2<>(5L,5L),
+                       new Tuple2<>(6L,6L),
+                       new Tuple2<>(7L,7L),
+                       new Tuple2<>(8L,8L),
+                       new Tuple2<>(9L,9L)
+       );
+
+       final List<Tuple2<Long, Long>> edgesInput = Arrays.asList( // edges as pairs of vertex ids
+                       new Tuple2<>(1L,2L),
+                       new Tuple2<>(1L,3L),
+                       new Tuple2<>(2L,3L),
+                       new Tuple2<>(2L,4L),
+                       new Tuple2<>(2L,1L),
+                       new Tuple2<>(3L,1L),
+                       new Tuple2<>(3L,2L),
+                       new Tuple2<>(4L,2L),
+                       new Tuple2<>(4L,6L),
+                       new Tuple2<>(5L,6L),
+                       new Tuple2<>(6L,4L),
+                       new Tuple2<>(6L,5L),
+                       new Tuple2<>(7L,8L),
+                       new Tuple2<>(7L,9L),
+                       new Tuple2<>(8L,7L),
+                       new Tuple2<>(8L,9L),
+                       new Tuple2<>(9L,7L),
+                       new Tuple2<>(9L,8L)
+       );
+
+       final List<Tuple2<Long, Long>> expectedResult = Arrays.asList( // (vertex id, component id), in ascending vertex order
+                       new Tuple2<>(1L,1L),
+                       new Tuple2<>(2L,1L),
+                       new Tuple2<>(3L,1L),
+                       new Tuple2<>(4L,1L),
+                       new Tuple2<>(5L,2L),
+                       new Tuple2<>(6L,1L),
+                       new Tuple2<>(7L,7L),
+                       new Tuple2<>(8L,7L),
+                       new Tuple2<>(9L,7L)
+       );
+
        @Test
-       public void testConnectedComponentsWithParametrizableConvergence() {
-               try {
-                       List<Tuple2<Long, Long>> verticesInput = Arrays.asList(
-                                       new Tuple2<Long, Long>(1l,1l),
-                                       new Tuple2<Long, Long>(2l,2l),
-                                       new Tuple2<Long, Long>(3l,3l),
-                                       new Tuple2<Long, Long>(4l,4l),
-                                       new Tuple2<Long, Long>(5l,5l),
-                                       new Tuple2<Long, Long>(6l,6l),
-                                       new Tuple2<Long, Long>(7l,7l),
-                                       new Tuple2<Long, Long>(8l,8l),
-                                       new Tuple2<Long, Long>(9l,9l)
-                       );
-                       
-                       List<Tuple2<Long, Long>> edgesInput = Arrays.asList(
-                                       new Tuple2<Long, Long>(1l,2l),
-                                       new Tuple2<Long, Long>(1l,3l),
-                                       new Tuple2<Long, Long>(2l,3l),
-                                       new Tuple2<Long, Long>(2l,4l),
-                                       new Tuple2<Long, Long>(2l,1l),
-                                       new Tuple2<Long, Long>(3l,1l),
-                                       new Tuple2<Long, Long>(3l,2l),
-                                       new Tuple2<Long, Long>(4l,2l),
-                                       new Tuple2<Long, Long>(4l,6l),
-                                       new Tuple2<Long, Long>(5l,6l),
-                                       new Tuple2<Long, Long>(6l,4l),
-                                       new Tuple2<Long, Long>(6l,5l),
-                                       new Tuple2<Long, Long>(7l,8l),
-                                       new Tuple2<Long, Long>(7l,9l),
-                                       new Tuple2<Long, Long>(8l,7l),
-                                       new Tuple2<Long, Long>(8l,9l),
-                                       new Tuple2<Long, Long>(9l,7l),
-                                       new Tuple2<Long, Long>(9l,8l)
-                       );
+       public void testConnectedComponentsWithParametrizableConvergence() 
throws Exception {
 
                        // name of the aggregator that checks for convergence
                        final String UPDATED_ELEMENTS = "updated.elements.aggr";
 
-                       // the iteration stops if less than this number os 
elements change value
+                       // the iteration stops if less than this number of 
elements change value
                        final long convergence_threshold = 3;
 
                        final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
@@ -100,8 +113,7 @@ public class AggregatorConvergenceITCase extends 
MultipleProgramsTestBase {
                        DataSet<Tuple2<Long, Long>> initialSolutionSet = 
env.fromCollection(verticesInput);
                        DataSet<Tuple2<Long, Long>> edges = 
env.fromCollection(edgesInput);
 
-                       IterativeDataSet<Tuple2<Long, Long>> iteration =
-                                       initialSolutionSet.iterate(10);
+                       IterativeDataSet<Tuple2<Long, Long>> iteration = 
initialSolutionSet.iterate(10);
 
                        // register the convergence criterion
                        
iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS,
@@ -117,62 +129,47 @@ public class AggregatorConvergenceITCase extends 
MultipleProgramsTestBase {
 
                        List<Tuple2<Long, Long>> result = 
iteration.closeWith(updatedComponentId).collect();
                        Collections.sort(result, new 
JavaProgramTestBase.TupleComparator<Tuple2<Long, Long>>());
-
-                       List<Tuple2<Long, Long>> expectedResult = Arrays.asList(
-                                       new Tuple2<Long, Long>(1L,1L),
-                                       new Tuple2<Long, Long>(2L,1L),
-                                       new Tuple2<Long, Long>(3L,1L),
-                                       new Tuple2<Long, Long>(4L,1L),
-                                       new Tuple2<Long, Long>(5L,2L),
-                                       new Tuple2<Long, Long>(6L,1L),
-                                       new Tuple2<Long, Long>(7L,7L),
-                                       new Tuple2<Long, Long>(8L,7L),
-                                       new Tuple2<Long, Long>(9L,7L)
-                       );
                        
                        assertEquals(expectedResult, result);
-               }
-               catch (Exception e) {
-                       e.printStackTrace();
-                       fail(e.getMessage());
-               }
+       }
+
+       @Test
+       public void testDeltaConnectedComponentsWithParametrizableConvergence() 
throws Exception {
+
+                       // name of the aggregator that checks for convergence
+                       final String UPDATED_ELEMENTS = "updated.elements.aggr";
+
+                       // the iteration stops if less than this number of 
elements change value
+                       final long convergence_threshold = 3;
+
+                       final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+
+                       DataSet<Tuple2<Long, Long>> initialSolutionSet = 
env.fromCollection(verticesInput);
+                       DataSet<Tuple2<Long, Long>> edges = 
env.fromCollection(edgesInput);
+
+                       DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> 
iteration =
+                                       
initialSolutionSet.iterateDelta(initialSolutionSet, 10, 0);
+
+                       // register the convergence criterion
+                       
iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS,
+                                       new LongSumAggregator(), new 
UpdatedElementsConvergenceCriterion(convergence_threshold));
+
+                       DataSet<Tuple2<Long, Long>> verticesWithNewComponents = 
iteration.getWorkset().join(edges).where(0).equalTo(0)
+                                       .with(new NeighborWithComponentIDJoin())
+                                       .groupBy(0).min(1);
+
+                       DataSet<Tuple2<Long, Long>> updatedComponentId =
+                                       
verticesWithNewComponents.join(iteration.getSolutionSet()).where(0).equalTo(0)
+                                                       .flatMap(new 
MinimumIdFilter(UPDATED_ELEMENTS));
+
+                       List<Tuple2<Long, Long>> result = 
iteration.closeWith(updatedComponentId, updatedComponentId).collect();
+                       Collections.sort(result, new 
JavaProgramTestBase.TupleComparator<Tuple2<Long, Long>>());
+
+                       assertEquals(expectedResult, result);
        }
        
        @Test
-       public void testParameterizableAggregator() {
-               try {
-                       List<Tuple2<Long, Long>> verticesInput = Arrays.asList(
-                               new Tuple2<Long, Long>(1l,1l),
-                               new Tuple2<Long, Long>(2l,2l),
-                               new Tuple2<Long, Long>(3l,3l),
-                               new Tuple2<Long, Long>(4l,4l),
-                               new Tuple2<Long, Long>(5l,5l),
-                               new Tuple2<Long, Long>(6l,6l),
-                               new Tuple2<Long, Long>(7l,7l),
-                               new Tuple2<Long, Long>(8l,8l),
-                               new Tuple2<Long, Long>(9l,9l)
-                       );
-                       
-                       List<Tuple2<Long, Long>> edgesInput = Arrays.asList(
-                                       new Tuple2<>(1l,2l),
-                                       new Tuple2<>(1l,3l),
-                                       new Tuple2<>(2l,3l),
-                                       new Tuple2<>(2l,4l),
-                                       new Tuple2<>(2l,1l),
-                                       new Tuple2<>(3l,1l),
-                                       new Tuple2<>(3l,2l),
-                                       new Tuple2<>(4l,2l),
-                                       new Tuple2<>(4l,6l),
-                                       new Tuple2<>(5l,6l),
-                                       new Tuple2<>(6l,4l),
-                                       new Tuple2<>(6l,5l),
-                                       new Tuple2<>(7l,8l),
-                                       new Tuple2<>(7l,9l),
-                                       new Tuple2<>(8l,7l),
-                                       new Tuple2<>(8l,9l),
-                                       new Tuple2<>(9l,7l),
-                                       new Tuple2<>(9l,8l)
-                       );
+       public void testParameterizableAggregator() throws Exception {
 
                        final int MAX_ITERATIONS = 5;
                        final String AGGREGATOR_NAME = 
"elements.in.component.aggregator";
@@ -213,7 +210,7 @@ public class AggregatorConvergenceITCase extends 
MultipleProgramsTestBase {
                                        new Tuple2<>(9L,7L)
                        );
 
-                       // checkpogram result
+                       // check program result
                        assertEquals(expectedResult, result);
 
                        // check aggregators
@@ -226,11 +223,6 @@ public class AggregatorConvergenceITCase extends 
MultipleProgramsTestBase {
                        assertEquals(4, aggr_values[1]);
                        assertEquals(5, aggr_values[2]);
                        assertEquals(6, aggr_values[3]);
-               }
-               catch (Exception e) {
-                       e.printStackTrace();
-                       fail(e.getMessage());
-               }
        }
        
        // 
------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java
index 4c5e955..042617d 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java
@@ -272,6 +272,44 @@ public class AggregatorsITCase extends 
MultipleProgramsTestBase {
                                + "5\n" + "5\n" + "5\n" + "5\n" + "5\n";
        }
 
+       @Test
+       public void testConvergenceCriterionWithParameterForIterateDelta() 
throws Exception {
+               /*
+                * Test convergence criterion with parameter for iterate delta
+                */
+
+               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(parallelism);
+
+               DataSet<Tuple2<Integer, Integer>> initialSolutionSet = 
CollectionDataSets.getIntegerDataSet(env).map(new TupleMakerMap());
+
+               DeltaIteration<Tuple2<Integer, Integer>, Tuple2<Integer, 
Integer>> iteration = initialSolutionSet.iterateDelta(
+                               initialSolutionSet, MAX_ITERATIONS, 0);
+
+               // register aggregator
+               LongSumAggregator aggr = new LongSumAggregator();
+               iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr);
+
+               // register convergence criterion
+               
iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr,
+                               new 
NegativeElementsConvergenceCriterionWithParam(3));
+
+               DataSet<Tuple2<Integer, Integer>> updatedDs = 
iteration.getWorkset().map(new AggregateAndSubtractOneDelta());
+
+               DataSet<Tuple2<Integer, Integer>> newElements = 
updatedDs.join(iteration.getSolutionSet())
+                               .where(0).equalTo(0).projectFirst(0, 1);
+
+               DataSet<Tuple2<Integer, Integer>> iterationRes = 
iteration.closeWith(newElements, newElements);
+               DataSet<Integer> result = iterationRes.map(new 
ProjectSecondMapper());
+               result.writeAsText(resultPath);
+
+               env.execute();
+
+               expected = "-3\n" + "-2\n" + "-2\n" + "-1\n" + "-1\n"
+                               + "-1\n" + "0\n" + "0\n" + "0\n" + "0\n"
+                               + "1\n" + "1\n" + "1\n" + "1\n" + "1\n";
+       }
+
        @SuppressWarnings("serial")
        public static final class NegativeElementsConvergenceCriterion 
implements ConvergenceCriterion<LongValue> {
 
@@ -313,9 +351,9 @@ public class AggregatorsITCase extends 
MultipleProgramsTestBase {
 
                @Override
                public Integer map(Integer value) {
-                       Integer newValue = Integer.valueOf(value.intValue() - 
1);
+                       Integer newValue = value - 1;
                        // count negative numbers
-                       if (newValue.intValue() < 0) {
+                       if (newValue < 0) {
                                aggr.aggregate(1l);
                        }
                        return newValue;
@@ -334,9 +372,9 @@ public class AggregatorsITCase extends 
MultipleProgramsTestBase {
 
                @Override
                public Integer map(Integer value) {
-                       Integer newValue = Integer.valueOf(value.intValue() - 
1);
-                       // count numbers less then the aggregator parameter
-                       if ( newValue.intValue() < aggr.getValue() ) {
+                       Integer newValue = value - 1;
+                       // count numbers less than the aggregator parameter
+                       if ( newValue < aggr.getValue() ) {
                                aggr.aggregate(1l);
                        }
                        return newValue;
@@ -369,8 +407,8 @@ public class AggregatorsITCase extends 
MultipleProgramsTestBase {
 
                @Override
                public Tuple2<Integer, Integer> map(Integer value) {
-                       Integer nodeId = Integer.valueOf(rnd.nextInt(100000));
-                       return new Tuple2<Integer, Integer>(nodeId, value);
+                       Integer nodeId = rnd.nextInt(100000);
+                       return new Tuple2<>(nodeId, value);
                }
 
        }
@@ -398,7 +436,7 @@ public class AggregatorsITCase extends 
MultipleProgramsTestBase {
                @Override
                public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> 
value) {
                        // count the elements that are equal to the superstep 
number
-                       if (value.f1.intValue() == superstep) {
+                       if (value.f1 == superstep) {
                                aggr.aggregate(1l);
                        }
                        return value;
@@ -436,48 +474,32 @@ public class AggregatorsITCase extends 
MultipleProgramsTestBase {
        }
 
        @SuppressWarnings("serial")
-       public static final class AggregateMapDeltaWithParam extends 
RichMapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
+       public static final class AggregateAndSubtractOneDelta extends 
RichMapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
 
-               private LongSumAggregatorWithParameter aggr;
+               private LongSumAggregator aggr;
                private LongValue previousAggr;
                private int superstep;
 
                @Override
                public void open(Configuration conf) {
-
                        aggr = 
getIterationRuntimeContext().getIterationAggregator(NEGATIVE_ELEMENTS_AGGR);
                        superstep = 
getIterationRuntimeContext().getSuperstepNumber();
 
                        if (superstep > 1) {
                                previousAggr = 
getIterationRuntimeContext().getPreviousIterationAggregate(NEGATIVE_ELEMENTS_AGGR);
-
                                // check previous aggregator value
-                               switch(superstep) {
-                                       case 2: {
-                                               Assert.assertEquals(6, 
previousAggr.getValue());
-                                       }
-                                       case 3: {
-                                               Assert.assertEquals(5, 
previousAggr.getValue());
-                                       }
-                                       case 4: {
-                                               Assert.assertEquals(3, 
previousAggr.getValue());
-                                       }
-                                       case 5: {
-                                               Assert.assertEquals(0, 
previousAggr.getValue());
-                                       }
-                                       default:
-                               }
-                               Assert.assertEquals(superstep-1, 
previousAggr.getValue());
+                               Assert.assertEquals(superstep - 1, 
previousAggr.getValue());
                        }
 
                }
 
                @Override
                public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> 
value) {
-                       // count the elements that are equal to the superstep 
number
-                       if (value.f1.intValue() < aggr.getValue()) {
+                       // count the ones
+                       if (value.f1 == 1) {
                                aggr.aggregate(1l);
                        }
+                       value.f1--;
                        return value;
                }
        }

Reply via email to