[FLINK-3888] allow registering a custom convergence criterion in delta iterations
- cleanups in iterations and aggregators code - add delta convergence criterion in the CollectionExecutor - add ITCases for delta custom convergence This closes #2606 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8085aa98 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8085aa98 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8085aa98 Branch: refs/heads/master Commit: 8085aa98333a052553f1155f65b5cc2728eb5ff8 Parents: 3ab97ae Author: vasia <[email protected]> Authored: Wed Oct 5 13:49:20 2016 +0200 Committer: vasia <[email protected]> Committed: Fri Oct 21 12:33:50 2016 +0200 ---------------------------------------------------------------------- .../common/aggregators/AggregatorRegistry.java | 10 +- .../common/operators/CollectionExecutor.java | 11 ++ .../api/java/operators/DeltaIteration.java | 32 +++- .../plantranslate/JobGraphGenerator.java | 13 +- .../task/IterationSynchronizationSinkTask.java | 43 +++-- .../runtime/operators/util/TaskConfig.java | 54 +++++- .../AggregatorConvergenceITCase.java | 182 +++++++++---------- .../aggregators/AggregatorsITCase.java | 82 ++++++--- 8 files changed, 272 insertions(+), 155 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java b/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java index 1d5c358..19663d1 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java @@ -49,18 +49,14 @@ public class AggregatorRegistry { } this.registry.put(name, aggregator); } - - public Aggregator<?> unregisterAggregator(String name) { - return this.registry.remove(name); - } - + public Collection<AggregatorWithName<?>> getAllRegisteredAggregators() { ArrayList<AggregatorWithName<?>> list = new ArrayList<AggregatorWithName<?>>(this.registry.size()); for (Map.Entry<String, Aggregator<?>> entry : this.registry.entrySet()) { @SuppressWarnings("unchecked") Aggregator<Value> valAgg = (Aggregator<Value>) entry.getValue(); - list.add(new AggregatorWithName<Value>(entry.getKey(), valAgg)); + list.add(new AggregatorWithName<>(entry.getKey(), valAgg)); } return list; } @@ -72,7 +68,7 @@ public class AggregatorRegistry { throw new IllegalArgumentException("Name, aggregator, or convergence criterion must not be null"); } - Aggregator<?> genAgg = (Aggregator<?>) aggregator; + Aggregator<?> genAgg = aggregator; Aggregator<?> previous = this.registry.get(name); if (previous != null && previous != genAgg) { http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java index d9240fe..a6fc17e 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java @@ -412,6 +412,9 @@ public class CollectionExecutor { aggregators.put(a.getName(), a.getAggregator()); } + String convCriterionAggName = iteration.getAggregators().getConvergenceCriterionAggregatorName(); + ConvergenceCriterion<Value> convCriterion = (ConvergenceCriterion<Value>) iteration.getAggregators().getConvergenceCriterion(); + final int maxIterations = iteration.getMaximumNumberOfIterations(); for (int superstep = 1; superstep <= maxIterations; superstep++) { @@ -442,6 +445,14 @@ public class CollectionExecutor { break; } + // evaluate the aggregator convergence criterion + if (convCriterion != null && convCriterionAggName != null) { + Value v = aggregators.get(convCriterionAggName).getAggregate(); + if (convCriterion.isConverged(superstep, v)) { + break; + } + } + // clear the dynamic results for (Operator<?> o : dynamics) { intermediateResults.remove(o); http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java index d53b499..b97a9de 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java @@ -26,10 +26,12 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.aggregators.Aggregator; import org.apache.flink.api.common.aggregators.AggregatorRegistry; +import org.apache.flink.api.common.aggregators.ConvergenceCriterion; import org.apache.flink.api.common.operators.Keys; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.types.Value; import org.apache.flink.util.Preconditions; /** @@ -62,13 +64,13 @@ public class DeltaIteration<ST, WT> { private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT; private boolean solutionSetUnManaged; - - + + public DeltaIteration(ExecutionEnvironment context, TypeInformation<ST> type, DataSet<ST> solutionSet, DataSet<WT> workset, Keys<ST> keys, int maxIterations) { initialSolutionSet = solutionSet; initialWorkset = workset; - solutionSetPlaceholder = new SolutionSetPlaceHolder<ST>(context, solutionSet.getType(), this); - worksetPlaceholder = new WorksetPlaceHolder<WT>(context, workset.getType()); + solutionSetPlaceholder = new SolutionSetPlaceHolder<>(context, solutionSet.getType(), this); + worksetPlaceholder = new WorksetPlaceHolder<>(context, workset.getType()); this.keys = keys; this.maxIterations = maxIterations; } @@ -210,6 +212,28 @@ public class DeltaIteration<ST, WT> { this.aggregators.registerAggregator(name, aggregator); return this; } + + /** + * Registers an {@link Aggregator} for the iteration together with a {@link ConvergenceCriterion}. For a general description + * of aggregators, see {@link #registerAggregator(String, Aggregator)} and {@link Aggregator}. + * At the end of each iteration, the convergence criterion takes the aggregator's global aggregate value and decides whether + * the iteration should terminate. A typical use case is to have an aggregator that sums up the total error of change + * in an iteration step and have to have a convergence criterion that signals termination as soon as the aggregate value + * is below a certain threshold. + * + * @param name The name under which the aggregator is registered. + * @param aggregator The aggregator class. + * @param convergenceCheck The convergence criterion. + * + * @return The DeltaIteration itself, to allow chaining function calls. + */ + @PublicEvolving + public <X extends Value> DeltaIteration<ST, WT> registerAggregationConvergenceCriterion( + String name, Aggregator<X> aggregator, ConvergenceCriterion<X> convergenceCheck) + { + this.aggregators.registerAggregationConvergenceCriterion(name, aggregator, convergenceCheck); + return this; + } /** * Gets the registry for aggregators for the iteration. http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java index 5ab1fbf..4ccfae3 100644 --- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java +++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java @@ -1513,14 +1513,21 @@ public class JobGraphGenerator implements Visitor<PlanNode> { String convAggName = aggs.getConvergenceCriterionAggregatorName(); ConvergenceCriterion<?> convCriterion = aggs.getConvergenceCriterion(); - + if (convCriterion != null || convAggName != null) { - throw new CompilerException("Error: Cannot use custom convergence criterion with workset iteration. Workset iterations have implicit convergence criterion where workset is empty."); + if (convCriterion == null) { + throw new CompilerException("Error: Convergence criterion aggregator set, but criterion is null."); + } + if (convAggName == null) { + throw new CompilerException("Error: Aggregator convergence criterion set, but aggregator is null."); + } + + syncConfig.setConvergenceCriterion(convAggName, convCriterion); } headConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator()); syncConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator()); - syncConfig.setConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new WorksetEmptyConvergenceCriterion()); + syncConfig.setImplicitConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new WorksetEmptyConvergenceCriterion()); } private String getDescriptionForUserCode(UserCodeWrapper<?> wrapper) { http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java index 66fb45b..11a8cfa 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java @@ -56,11 +56,15 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen private SyncEventHandler eventHandler; private ConvergenceCriterion<Value> convergenceCriterion; + + private ConvergenceCriterion<Value> implicitConvergenceCriterion; private Map<String, Aggregator<?>> aggregators; private String convergenceAggregatorName; + private String implicitConvergenceAggregatorName; + private int currentIteration = 1; private int maxNumberOfIterations; @@ -71,14 +75,14 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen @Override public void invoke() throws Exception { - this.headEventReader = new MutableRecordReader<IntValue>( + this.headEventReader = new MutableRecordReader<>( getEnvironment().getInputGate(0), getEnvironment().getTaskManagerInfo().getTmpDirectories()); TaskConfig taskConfig = new TaskConfig(getTaskConfiguration()); // store all aggregators - this.aggregators = new HashMap<String, Aggregator<?>>(); + this.aggregators = new HashMap<>(); for (AggregatorWithName<?> aggWithName : taskConfig.getIterationAggregators(getUserCodeClassLoader())) { aggregators.put(aggWithName.getName(), aggWithName.getAggregator()); } @@ -89,6 +93,13 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen convergenceAggregatorName = taskConfig.getConvergenceCriterionAggregatorName(); Preconditions.checkNotNull(convergenceAggregatorName); } + + // store the default aggregator convergence criterion + if (taskConfig.usesImplicitConvergenceCriterion()) { + implicitConvergenceCriterion = taskConfig.getImplicitConvergenceCriterion(getUserCodeClassLoader()); + implicitConvergenceAggregatorName = taskConfig.getImplicitConvergenceCriterionAggregatorName(); + Preconditions.checkNotNull(implicitConvergenceAggregatorName); + } maxNumberOfIterations = taskConfig.getNumberOfIterations(); @@ -102,7 +113,6 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen while (!terminationRequested()) { -// notifyMonitor(IterationMonitoring.Event.SYNC_STARTING, currentIteration); if (log.isInfoEnabled()) { log.info(formatLogString("starting iteration [" + currentIteration + "]")); } @@ -122,7 +132,6 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen requestTermination(); sendToAllWorkers(new TerminationEvent()); -// notifyMonitor(IterationMonitoring.Event.SYNC_FINISHED, currentIteration); } else { if (log.isInfoEnabled()) { log.info(formatLogString("signaling that all workers are done in iteration [" + currentIteration @@ -136,19 +145,11 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen for (Aggregator<?> agg : aggregators.values()) { agg.reset(); } - -// notifyMonitor(IterationMonitoring.Event.SYNC_FINISHED, currentIteration); currentIteration++; } } } -// protected void notifyMonitor(IterationMonitoring.Event event, int currentIteration) { -// if (log.isInfoEnabled()) { -// log.info(IterationMonitoring.logLine(getEnvironment().getJobID(), event, currentIteration, 1)); -// } -// } - private boolean checkForConvergence() { if (maxNumberOfIterations == currentIteration) { if (log.isInfoEnabled()) { @@ -175,6 +176,24 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen return true; } } + + if (implicitConvergenceAggregatorName != null) { + @SuppressWarnings("unchecked") + Aggregator<Value> aggregator = (Aggregator<Value>) aggregators.get(implicitConvergenceAggregatorName); + if (aggregator == null) { + throw new RuntimeException("Error: Aggregator for default convergence criterion was null."); + } + + Value aggregate = aggregator.getAggregate(); + + if (implicitConvergenceCriterion.isConverged(currentIteration, aggregate)) { + if (log.isInfoEnabled()) { + log.info(formatLogString("empty workset convergence reached after [" + currentIteration + + "] iterations, terminating...")); + } + return true; + } + } return false; } http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java index b598523..71c0405 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java @@ -37,6 +37,7 @@ import org.apache.flink.api.common.operators.util.UserCodeWrapper; import org.apache.flink.api.common.typeutils.TypeComparatorFactory; import org.apache.flink.api.common.typeutils.TypePairComparatorFactory; import org.apache.flink.api.common.typeutils.TypeSerializerFactory; +import org.apache.flink.api.java.operators.DeltaIteration; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.DelegatingConfiguration; import org.apache.flink.core.memory.DataInputViewStreamWrapper; @@ -197,6 +198,10 @@ public class TaskConfig implements Serializable { private static final String ITERATION_CONVERGENCE_CRITERION = "iterative.terminationCriterion"; private static final String ITERATION_CONVERGENCE_CRITERION_AGG_NAME = "iterative.terminationCriterion.agg.name"; + + private static final String ITERATION_IMPLICIT_CONVERGENCE_CRITERION = "iterative.implicit.terminationCriterion"; + + private static final String ITERATION_IMPLICIT_CONVERGENCE_CRITERION_AGG_NAME = "iterative.implicit.terminationCriterion.agg.name"; private static final String ITERATION_NUM_AGGREGATORS = "iterative.num-aggs"; @@ -992,16 +997,31 @@ public class TaskConfig implements Serializable { this.config.setString(ITERATION_CONVERGENCE_CRITERION_AGG_NAME, aggregatorName); } + /** + * Sets the default convergence criterion of a {@link DeltaIteration} + * + * @param aggregatorName + * @param convCriterion + */ + public void setImplicitConvergenceCriterion(String aggregatorName, ConvergenceCriterion<?> convCriterion) { + try { + InstantiationUtil.writeObjectToConfig(convCriterion, this.config, ITERATION_IMPLICIT_CONVERGENCE_CRITERION); + } catch (IOException e) { + throw new RuntimeException("Error while writing the implicit convergence criterion object to the task configuration."); + } + this.config.setString(ITERATION_IMPLICIT_CONVERGENCE_CRITERION_AGG_NAME, aggregatorName); + } + @SuppressWarnings("unchecked") public <T extends Value> ConvergenceCriterion<T> getConvergenceCriterion(ClassLoader cl) { - ConvergenceCriterion<T> convCriterionObj = null; + ConvergenceCriterion<T> convCriterionObj; try { - convCriterionObj = (ConvergenceCriterion<T>) InstantiationUtil.readObjectFromConfig( + convCriterionObj = InstantiationUtil.readObjectFromConfig( this.config, ITERATION_CONVERGENCE_CRITERION, cl); } catch (IOException e) { - throw new RuntimeException("Error while reading the covergence criterion object from the task configuration."); + throw new RuntimeException("Error while reading the convergence criterion object from the task configuration."); } catch (ClassNotFoundException e) { - throw new RuntimeException("Error while reading the covergence criterion object from the task configuration. " + + throw new RuntimeException("Error while reading the convergence criterion object from the task configuration. " + "ConvergenceCriterion class not found."); } if (convCriterionObj == null) { @@ -1017,6 +1037,32 @@ public class TaskConfig implements Serializable { public String getConvergenceCriterionAggregatorName() { return this.config.getString(ITERATION_CONVERGENCE_CRITERION_AGG_NAME, null); } + + @SuppressWarnings("unchecked") + public <T extends Value> ConvergenceCriterion<T> getImplicitConvergenceCriterion(ClassLoader cl) { + ConvergenceCriterion<T> convCriterionObj; + try { + convCriterionObj = InstantiationUtil.readObjectFromConfig( + this.config, ITERATION_IMPLICIT_CONVERGENCE_CRITERION, cl); + } catch (IOException e) { + throw new RuntimeException("Error while reading the default convergence criterion object from the task configuration."); + } catch (ClassNotFoundException e) { + throw new RuntimeException("Error while reading the default convergence criterion object from the task configuration. " + + "ConvergenceCriterion class not found."); + } + if (convCriterionObj == null) { + throw new NullPointerException(); + } + return convCriterionObj; + } + + public boolean usesImplicitConvergenceCriterion() { + return config.getBytes(ITERATION_IMPLICIT_CONVERGENCE_CRITERION, null) != null; + } + + public String getImplicitConvergenceCriterionAggregatorName() { + return this.config.getString(ITERATION_IMPLICIT_CONVERGENCE_CRITERION_AGG_NAME, null); + } public void setIsSolutionSetUpdate() { this.config.setBoolean(ITERATION_SOLUTION_SET_UPDATE, true); http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java index 941b31b..7bade80 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java @@ -26,6 +26,7 @@ import org.apache.flink.api.common.aggregators.ConvergenceCriterion; import org.apache.flink.api.common.aggregators.LongSumAggregator; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.functions.RichJoinFunction; +import org.apache.flink.api.java.operators.DeltaIteration; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.test.util.JavaProgramTestBase; @@ -52,47 +53,59 @@ public class AggregatorConvergenceITCase extends MultipleProgramsTestBase { public AggregatorConvergenceITCase(TestExecutionMode mode) { super(mode); } - + + final List<Tuple2<Long, Long>> verticesInput = Arrays.asList( + new Tuple2<>(1l,1l), + new Tuple2<>(2l,2l), + new Tuple2<>(3l,3l), + new Tuple2<>(4l,4l), + new Tuple2<>(5l,5l), + new Tuple2<>(6l,6l), + new Tuple2<>(7l,7l), + new Tuple2<>(8l,8l), + new Tuple2<>(9l,9l) + ); + + final List<Tuple2<Long, Long>> edgesInput = Arrays.asList( + new Tuple2<>(1l,2l), + new Tuple2<>(1l,3l), + new Tuple2<>(2l,3l), + new Tuple2<>(2l,4l), + new Tuple2<>(2l,1l), + new Tuple2<>(3l,1l), + new Tuple2<>(3l,2l), + new Tuple2<>(4l,2l), + new Tuple2<>(4l,6l), + new Tuple2<>(5l,6l), + new Tuple2<>(6l,4l), + new Tuple2<>(6l,5l), + new Tuple2<>(7l,8l), + new Tuple2<>(7l,9l), + new Tuple2<>(8l,7l), + new Tuple2<>(8l,9l), + new Tuple2<>(9l,7l), + new Tuple2<>(9l,8l) + ); + + final List<Tuple2<Long, Long>> expectedResult = Arrays.asList( + new Tuple2<>(1L,1L), + new Tuple2<>(2L,1L), + new Tuple2<>(3L,1L), + new Tuple2<>(4L,1L), + new Tuple2<>(5L,2L), + new Tuple2<>(6L,1L), + new Tuple2<>(7L,7L), + new Tuple2<>(8L,7L), + new Tuple2<>(9L,7L) + ); + @Test - public void testConnectedComponentsWithParametrizableConvergence() { - try { - List<Tuple2<Long, Long>> verticesInput = Arrays.asList( - new Tuple2<Long, Long>(1l,1l), - new Tuple2<Long, Long>(2l,2l), - new Tuple2<Long, Long>(3l,3l), - new Tuple2<Long, Long>(4l,4l), - new Tuple2<Long, Long>(5l,5l), - new Tuple2<Long, Long>(6l,6l), - new Tuple2<Long, Long>(7l,7l), - new Tuple2<Long, Long>(8l,8l), - new Tuple2<Long, Long>(9l,9l) - ); - - List<Tuple2<Long, Long>> edgesInput = Arrays.asList( - new Tuple2<Long, Long>(1l,2l), - new Tuple2<Long, Long>(1l,3l), - new Tuple2<Long, Long>(2l,3l), - new Tuple2<Long, Long>(2l,4l), - new Tuple2<Long, Long>(2l,1l), - new Tuple2<Long, Long>(3l,1l), - new Tuple2<Long, Long>(3l,2l), - new Tuple2<Long, Long>(4l,2l), - new Tuple2<Long, Long>(4l,6l), - new Tuple2<Long, Long>(5l,6l), - new Tuple2<Long, Long>(6l,4l), - new Tuple2<Long, Long>(6l,5l), - new Tuple2<Long, Long>(7l,8l), - new Tuple2<Long, Long>(7l,9l), - new Tuple2<Long, Long>(8l,7l), - new Tuple2<Long, Long>(8l,9l), - new Tuple2<Long, Long>(9l,7l), - new Tuple2<Long, Long>(9l,8l) - ); + public void testConnectedComponentsWithParametrizableConvergence() throws Exception { // name of the aggregator that checks for convergence final String UPDATED_ELEMENTS = "updated.elements.aggr"; - // the iteration stops if less than this number os elements change value + // the iteration stops if less than this number of elements change value final long convergence_threshold = 3; final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); @@ -100,8 +113,7 @@ public class AggregatorConvergenceITCase extends MultipleProgramsTestBase { DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput); DataSet<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput); - IterativeDataSet<Tuple2<Long, Long>> iteration = - initialSolutionSet.iterate(10); + IterativeDataSet<Tuple2<Long, Long>> iteration = initialSolutionSet.iterate(10); // register the convergence criterion iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS, @@ -117,62 +129,47 @@ public class AggregatorConvergenceITCase extends MultipleProgramsTestBase { List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentId).collect(); Collections.sort(result, new JavaProgramTestBase.TupleComparator<Tuple2<Long, Long>>()); - - List<Tuple2<Long, Long>> expectedResult = Arrays.asList( - new Tuple2<Long, Long>(1L,1L), - new Tuple2<Long, Long>(2L,1L), - new Tuple2<Long, Long>(3L,1L), - new Tuple2<Long, Long>(4L,1L), - new Tuple2<Long, Long>(5L,2L), - new Tuple2<Long, Long>(6L,1L), - new Tuple2<Long, Long>(7L,7L), - new Tuple2<Long, Long>(8L,7L), - new Tuple2<Long, Long>(9L,7L) - ); assertEquals(expectedResult, result); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + } + + @Test + public void testDeltaConnectedComponentsWithParametrizableConvergence() throws Exception { + + // name of the aggregator that checks for convergence + final String UPDATED_ELEMENTS = "updated.elements.aggr"; + + // the iteration stops if less than this number of elements change value + final long convergence_threshold = 3; + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput); + DataSet<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput); + + DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration = + initialSolutionSet.iterateDelta(initialSolutionSet, 10, 0); + + // register the convergence criterion + iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS, + new LongSumAggregator(), new UpdatedElementsConvergenceCriterion(convergence_threshold)); + + DataSet<Tuple2<Long, Long>> verticesWithNewComponents = iteration.getWorkset().join(edges).where(0).equalTo(0) + .with(new NeighborWithComponentIDJoin()) + .groupBy(0).min(1); + + DataSet<Tuple2<Long, Long>> updatedComponentId = + verticesWithNewComponents.join(iteration.getSolutionSet()).where(0).equalTo(0) + .flatMap(new MinimumIdFilter(UPDATED_ELEMENTS)); + + List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentId, updatedComponentId).collect(); + Collections.sort(result, new JavaProgramTestBase.TupleComparator<Tuple2<Long, Long>>()); + + assertEquals(expectedResult, result); } @Test - public void testParameterizableAggregator() { - try { - List<Tuple2<Long, Long>> verticesInput = Arrays.asList( - new Tuple2<Long, Long>(1l,1l), - new Tuple2<Long, Long>(2l,2l), - new Tuple2<Long, Long>(3l,3l), - new Tuple2<Long, Long>(4l,4l), - new Tuple2<Long, Long>(5l,5l), - new Tuple2<Long, Long>(6l,6l), - new Tuple2<Long, Long>(7l,7l), - new Tuple2<Long, Long>(8l,8l), - new Tuple2<Long, Long>(9l,9l) - ); - - List<Tuple2<Long, Long>> edgesInput = Arrays.asList( - new Tuple2<>(1l,2l), - new Tuple2<>(1l,3l), - new Tuple2<>(2l,3l), - new Tuple2<>(2l,4l), - new Tuple2<>(2l,1l), - new Tuple2<>(3l,1l), - new Tuple2<>(3l,2l), - new Tuple2<>(4l,2l), - new Tuple2<>(4l,6l), - new Tuple2<>(5l,6l), - new Tuple2<>(6l,4l), - new Tuple2<>(6l,5l), - new Tuple2<>(7l,8l), - new Tuple2<>(7l,9l), - new Tuple2<>(8l,7l), - new Tuple2<>(8l,9l), - new Tuple2<>(9l,7l), - new Tuple2<>(9l,8l) - ); + public void testParameterizableAggregator() throws Exception { final int MAX_ITERATIONS = 5; final String AGGREGATOR_NAME = "elements.in.component.aggregator"; @@ -213,7 +210,7 @@ public class AggregatorConvergenceITCase extends MultipleProgramsTestBase { new Tuple2<>(9L,7L) ); - // checkpogram result + // check program result assertEquals(expectedResult, result); // check aggregators @@ -226,11 +223,6 @@ public class AggregatorConvergenceITCase extends MultipleProgramsTestBase { assertEquals(4, aggr_values[1]); assertEquals(5, aggr_values[2]); assertEquals(6, aggr_values[3]); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } } // ------------------------------------------------------------------------ http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java index 4c5e955..042617d 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java @@ -272,6 +272,44 @@ public class AggregatorsITCase extends MultipleProgramsTestBase { + "5\n" + "5\n" + "5\n" + "5\n" + "5\n"; } + @Test + public void testConvergenceCriterionWithParameterForIterateDelta() throws Exception { + /* + * Test convergence criterion with parameter for iterate delta + */ + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(parallelism); + + DataSet<Tuple2<Integer, Integer>> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env).map(new TupleMakerMap()); + + DeltaIteration<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> iteration = initialSolutionSet.iterateDelta( + initialSolutionSet, MAX_ITERATIONS, 0); + + // register aggregator + LongSumAggregator aggr = new LongSumAggregator(); + iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr); + + // register convergence criterion + iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr, + new NegativeElementsConvergenceCriterionWithParam(3)); + + DataSet<Tuple2<Integer, Integer>> updatedDs = iteration.getWorkset().map(new AggregateAndSubtractOneDelta()); + + DataSet<Tuple2<Integer, Integer>> newElements = updatedDs.join(iteration.getSolutionSet()) + .where(0).equalTo(0).projectFirst(0, 1); + + DataSet<Tuple2<Integer, Integer>> iterationRes = iteration.closeWith(newElements, newElements); + DataSet<Integer> result = iterationRes.map(new ProjectSecondMapper()); + result.writeAsText(resultPath); + + env.execute(); + + expected = "-3\n" + "-2\n" + "-2\n" + "-1\n" + "-1\n" + + "-1\n" + "0\n" + "0\n" + "0\n" + "0\n" + + "1\n" + "1\n" + "1\n" + "1\n" + "1\n"; + } + @SuppressWarnings("serial") public static final class NegativeElementsConvergenceCriterion implements ConvergenceCriterion<LongValue> { @@ -313,9 +351,9 @@ public class AggregatorsITCase extends MultipleProgramsTestBase { @Override public Integer map(Integer value) { - Integer newValue = Integer.valueOf(value.intValue() - 1); + Integer newValue = value - 1; // count negative numbers - if (newValue.intValue() < 0) { + if (newValue < 0) { aggr.aggregate(1l); } return newValue; @@ -334,9 +372,9 @@ public class AggregatorsITCase extends MultipleProgramsTestBase { @Override public Integer map(Integer value) { - Integer newValue = Integer.valueOf(value.intValue() - 1); - // count numbers less then the aggregator parameter - if ( newValue.intValue() < aggr.getValue() ) { + Integer newValue = value - 1; + // count numbers less than the aggregator parameter + if ( newValue < aggr.getValue() ) { aggr.aggregate(1l); } return newValue; @@ -369,8 +407,8 @@ public class AggregatorsITCase extends MultipleProgramsTestBase { @Override public Tuple2<Integer, Integer> map(Integer value) { - Integer nodeId = Integer.valueOf(rnd.nextInt(100000)); - return new Tuple2<Integer, Integer>(nodeId, value); + Integer nodeId = rnd.nextInt(100000); + return new Tuple2<>(nodeId, value); } } @@ -398,7 +436,7 @@ public class AggregatorsITCase extends MultipleProgramsTestBase { @Override public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) { // count the elements that are equal to the superstep number - if (value.f1.intValue() == superstep) { + if (value.f1 == superstep) { aggr.aggregate(1l); } return value; @@ -436,48 +474,32 @@ public class AggregatorsITCase extends MultipleProgramsTestBase { } @SuppressWarnings("serial") - public static final class AggregateMapDeltaWithParam extends RichMapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> { + public static final class AggregateAndSubtractOneDelta extends RichMapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> { - private LongSumAggregatorWithParameter aggr; + private LongSumAggregator aggr; private LongValue previousAggr; private int superstep; @Override public void open(Configuration conf) { - aggr = getIterationRuntimeContext().getIterationAggregator(NEGATIVE_ELEMENTS_AGGR); superstep = getIterationRuntimeContext().getSuperstepNumber(); if (superstep > 1) { previousAggr = getIterationRuntimeContext().getPreviousIterationAggregate(NEGATIVE_ELEMENTS_AGGR); - // check previous aggregator value - switch(superstep) { - case 2: { - Assert.assertEquals(6, previousAggr.getValue()); - } - case 3: { - Assert.assertEquals(5, previousAggr.getValue()); - } - case 4: { - Assert.assertEquals(3, previousAggr.getValue()); - } - case 5: { - Assert.assertEquals(0, previousAggr.getValue()); - } - default: - } - Assert.assertEquals(superstep-1, previousAggr.getValue()); + Assert.assertEquals(superstep - 1, previousAggr.getValue()); } } @Override public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) { - // count the elements that are equal to the superstep number - if (value.f1.intValue() < aggr.getValue()) { + // count the ones + if (value.f1 == 1) { aggr.aggregate(1l); } + value.f1--; return value; } }
