Repository: flink Updated Branches: refs/heads/master b201f8664 -> 65545c2ed
[FLINK-3806] [gelly] Revert use of DataSet.count() This closes #2036 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/36ad78c0 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/36ad78c0 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/36ad78c0 Branch: refs/heads/master Commit: 36ad78c0821fdae0a69371c67602dd2a7955e4a8 Parents: b201f86 Author: Greg Hogan <[email protected]> Authored: Wed May 25 11:06:01 2016 -0400 Committer: Greg Hogan <[email protected]> Committed: Thu Jun 2 09:11:19 2016 -0400 ---------------------------------------------------------------------- docs/apis/batch/libs/gelly.md | 1 - .../graph/library/HITSAlgorithmITCase.java | 28 -------- .../flink/graph/library/PageRankITCase.java | 15 ++-- .../graph/gsa/GatherSumApplyIteration.java | 31 +++++++-- .../apache/flink/graph/library/GSAPageRank.java | 52 +++----------- .../flink/graph/library/HITSAlgorithm.java | 54 ++------------- .../apache/flink/graph/library/PageRank.java | 55 ++++----------- .../graph/spargel/ScatterGatherIteration.java | 73 ++++++++++++++------ .../apache/flink/graph/utils/GraphUtils.java | 58 ++++++++++++++++ 9 files changed, 171 insertions(+), 196 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/docs/apis/batch/libs/gelly.md ---------------------------------------------------------------------- diff --git a/docs/apis/batch/libs/gelly.md b/docs/apis/batch/libs/gelly.md index 0d3e594..aadbd44 100644 --- a/docs/apis/batch/libs/gelly.md +++ b/docs/apis/batch/libs/gelly.md @@ -1967,7 +1967,6 @@ The constructors take the following parameters: * `beta`: the damping factor. * `maxIterations`: the maximum number of iterations to run. -* `numVertices`: the number of vertices in the input. If known beforehand, is it advised to provide this argument to speed up execution. ### GSA PageRank http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java index 019b851..1887725 100644 --- a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java +++ b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java @@ -56,20 +56,6 @@ public class HITSAlgorithmITCase extends MultipleProgramsTestBase{ } @Test - public void testHITSWithTenIterationsAndNumOfVertices() throws Exception { - final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); - - Graph<Long, Double, NullValue> graph = Graph.fromDataSet( - HITSData.getVertexDataSet(env), - HITSData.getEdgeDataSet(env), - env); - - List<Vertex<Long, Tuple2<DoubleValue, DoubleValue>>> result = graph.run(new HITSAlgorithm<Long, Double, NullValue>(10, 5)).collect(); - - compareWithDelta(result, 1e-7); - } - - @Test public void testHITSWithConvergeThreshold() throws Exception { final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); @@ -83,20 +69,6 @@ public class HITSAlgorithmITCase extends MultipleProgramsTestBase{ compareWithDelta(result, 1e-7); } - @Test - public void testHITSWithConvergeThresholdAndNumOfVertices() throws Exception { - final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); - - Graph<Long, Double, NullValue> graph = Graph.fromDataSet( - HITSData.getVertexDataSet(env), - HITSData.getEdgeDataSet(env), - env); - - List<Vertex<Long, Tuple2<DoubleValue, DoubleValue>>> result = graph.run(new HITSAlgorithm<Long, Double, NullValue>(1e-7, 5)).collect(); - - compareWithDelta(result, 1e-7); - } - private void compareWithDelta(List<Vertex<Long, Tuple2<DoubleValue, DoubleValue>>> result, double delta) { String resultString = ""; http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java index 034bcd5..e3e8f08 100644 --- a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java +++ b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java @@ -18,9 +18,6 @@ package org.apache.flink.graph.library; -import java.util.Arrays; -import java.util.List; - import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.graph.Graph; @@ -32,6 +29,9 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import java.util.Arrays; +import java.util.List; + @RunWith(Parameterized.class) public class PageRankITCase extends MultipleProgramsTestBase { @@ -72,9 +72,9 @@ public class PageRankITCase extends MultipleProgramsTestBase { Graph<Long, Double, Double> inputGraph = Graph.fromDataSet( PageRankData.getDefaultEdgeDataSet(env), new InitMapper(), env); - List<Vertex<Long, Double>> result = inputGraph.run(new PageRank<Long>(0.85, 5, 3)) + List<Vertex<Long, Double>> result = inputGraph.run(new PageRank<Long>(0.85, 3)) .collect(); - + compareWithDelta(result, 0.01); } @@ -85,14 +85,13 @@ public class PageRankITCase extends MultipleProgramsTestBase { Graph<Long, Double, Double> inputGraph = Graph.fromDataSet( PageRankData.getDefaultEdgeDataSet(env), new InitMapper(), env); - List<Vertex<Long, Double>> result = inputGraph.run(new GSAPageRank<Long>(0.85, 5, 3)) + List<Vertex<Long, Double>> result = inputGraph.run(new GSAPageRank<Long>(0.85, 3)) .collect(); compareWithDelta(result, 0.01); } - private void compareWithDelta(List<Vertex<Long, Double>> result, - double delta) { + private void compareWithDelta(List<Vertex<Long, Double>> result, double delta) { String resultString = ""; for (Vertex<Long, Double> v : result) { http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java index d092086..d1b12f9 100755 --- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java +++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java @@ -41,9 +41,12 @@ import org.apache.flink.graph.Edge; import org.apache.flink.graph.EdgeDirection; import org.apache.flink.graph.Graph; import org.apache.flink.graph.Vertex; +import org.apache.flink.graph.utils.GraphUtils; +import org.apache.flink.types.LongValue; import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; +import java.util.Collection; import java.util.Map; /** @@ -125,12 +128,11 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati // check whether the numVertices option is set and, if so, compute the total number of vertices // and set it within the gather, sum and apply functions + + DataSet<LongValue> numberOfVertices = null; if (this.configuration != null && this.configuration.isOptNumVertices()) { try { - long numberOfVertices = graph.numberOfVertices(); - gather.setNumberOfVertices(numberOfVertices); - sum.setNumberOfVertices(numberOfVertices); - apply.setNumberOfVertices(numberOfVertices); + numberOfVertices = GraphUtils.count(this.vertexDataSet); } catch (Exception e) { e.printStackTrace(); } @@ -203,6 +205,9 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati for (Tuple2<String, DataSet<?>> e : this.configuration.getGatherBcastVars()) { gatherMapOperator = gatherMapOperator.withBroadcastSet(e.f1, e.f0); } + if (this.configuration.isOptNumVertices()) { + gatherMapOperator = gatherMapOperator.withBroadcastSet(numberOfVertices, "number of vertices"); + } } DataSet<Tuple2<K, M>> gatheredSet = gatherMapOperator; @@ -215,6 +220,9 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati for (Tuple2<String, DataSet<?>> e : this.configuration.getSumBcastVars()) { sumReduceOperator = sumReduceOperator.withBroadcastSet(e.f1, e.f0); } + if (this.configuration.isOptNumVertices()) { + sumReduceOperator = sumReduceOperator.withBroadcastSet(numberOfVertices, "number of vertices"); + } } DataSet<Tuple2<K, M>> summedSet = sumReduceOperator; @@ -231,6 +239,9 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati for (Tuple2<String, DataSet<?>> e : this.configuration.getApplyBcastVars()) { appliedSet = appliedSet.withBroadcastSet(e.f1, e.f0); } + if (this.configuration.isOptNumVertices()) { + appliedSet = appliedSet.withBroadcastSet(numberOfVertices, "number of vertices"); + } } // let the operator know that we preserve the key field @@ -289,6 +300,10 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati @Override public void open(Configuration parameters) throws Exception { + if (getRuntimeContext().hasBroadcastVariable("number of vertices")) { + Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number of vertices"); + this.gatherFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue()); + } if (getIterationRuntimeContext().getSuperstepNumber() == 1) { this.gatherFunction.init(getIterationRuntimeContext()); } @@ -327,6 +342,10 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati @Override public void open(Configuration parameters) throws Exception { + if (getRuntimeContext().hasBroadcastVariable("number of vertices")) { + Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number of vertices"); + this.sumFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue()); + } if (getIterationRuntimeContext().getSuperstepNumber() == 1) { this.sumFunction.init(getIterationRuntimeContext()); } @@ -365,6 +384,10 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati @Override public void open(Configuration parameters) throws Exception { + if (getRuntimeContext().hasBroadcastVariable("number of vertices")) { + Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number of vertices"); + this.applyFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue()); + } if (getIterationRuntimeContext().getSuperstepNumber() == 1) { this.applyFunction.init(getIterationRuntimeContext()); } http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java index 99624ca..324f9c3 100644 --- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java +++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java @@ -25,6 +25,7 @@ import org.apache.flink.graph.Graph; import org.apache.flink.graph.GraphAlgorithm; import org.apache.flink.graph.Vertex; import org.apache.flink.graph.gsa.ApplyFunction; +import org.apache.flink.graph.gsa.GSAConfiguration; import org.apache.flink.graph.gsa.GatherFunction; import org.apache.flink.graph.gsa.Neighbor; import org.apache.flink.graph.gsa.SumFunction; @@ -32,22 +33,17 @@ import org.apache.flink.graph.gsa.SumFunction; /** * This is an implementation of a simple PageRank algorithm, using a gather-sum-apply iteration. * The user can define the damping factor and the maximum number of iterations. - * If the number of vertices of the input graph is known, it should be provided as a parameter - * to speed up computation. Otherwise, the algorithm will first execute a job to count the vertices. - * + * * The implementation assumes that each page has at least one incoming and one outgoing link. */ public class GSAPageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Vertex<K, Double>>> { private double beta; private int maxIterations; - private long numberOfVertices; /** * Creates an instance of the GSA PageRank algorithm. - * If the number of vertices of the input graph is known, - * use the {@link GSAPageRank#GSAPageRank(double, long, int)} constructor instead. - * + * * The implementation assumes that each page has at least one incoming and one outgoing link. * * @param beta the damping factor @@ -58,37 +54,19 @@ public class GSAPageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet this.maxIterations = maxIterations; } - /** - * Creates an instance of the GSA PageRank algorithm. - * If the number of vertices of the input graph is known, - * use the {@link GSAPageRank#GSAPageRank(double, int)} constructor instead. - * - * The implementation assumes that each page has at least one incoming and one outgoing link. - * - * @param beta the damping factor - * @param maxIterations the maximum number of iterations - * @param numVertices the number of vertices in the input - */ - public GSAPageRank(double beta, long numVertices, int maxIterations) { - this.beta = beta; - this.numberOfVertices = numVertices; - this.maxIterations = maxIterations; - } - @Override public DataSet<Vertex<K, Double>> run(Graph<K, Double, Double> network) throws Exception { - if (numberOfVertices == 0) { - numberOfVertices = network.numberOfVertices(); - } - DataSet<Tuple2<K, Long>> vertexOutDegrees = network.outDegrees(); Graph<K, Double, Double> networkWithWeights = network .joinWithEdgesOnSource(vertexOutDegrees, new InitWeights()); - return networkWithWeights.runGatherSumApplyIteration(new GatherRanks(numberOfVertices), new SumRanks(), - new UpdateRanks<K>(beta, numberOfVertices), maxIterations) + GSAConfiguration parameters = new GSAConfiguration(); + parameters.setOptNumVertices(true); + + return networkWithWeights.runGatherSumApplyIteration(new GatherRanks(), new SumRanks(), + new UpdateRanks<K>(beta), maxIterations, parameters) .getVertices(); } @@ -99,18 +77,12 @@ public class GSAPageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet @SuppressWarnings("serial") private static final class GatherRanks extends GatherFunction<Double, Double, Double> { - long numberOfVertices; - - public GatherRanks(long numberOfVertices) { - this.numberOfVertices = numberOfVertices; - } - @Override public Double gather(Neighbor<Double, Double> neighbor) { double neighborRank = neighbor.getNeighborValue(); if(getSuperstepNumber() == 1) { - neighborRank = 1.0 / numberOfVertices; + neighborRank = 1.0 / this.getNumberOfVertices(); } return neighborRank * neighbor.getEdgeValue(); @@ -130,16 +102,14 @@ public class GSAPageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet private static final class UpdateRanks<K> extends ApplyFunction<K, Double, Double> { private final double beta; - private final long numVertices; - public UpdateRanks(double beta, long numberOfVertices) { + public UpdateRanks(double beta) { this.beta = beta; - this.numVertices = numberOfVertices; } @Override public void apply(Double rankSum, Double currentValue) { - setResult((1-beta)/numVertices + beta * rankSum); + setResult((1-beta)/this.getNumberOfVertices() + beta * rankSum); } } http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java index 1ea367e..39e9487 100644 --- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java +++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java @@ -42,8 +42,6 @@ import org.apache.flink.util.Preconditions; * represented a page that is linked by many different hubs. * Each vertex has a value of Tuple2 type, the first field is hub score and the second field is authority score. * The implementation sets same score to every vertex and adds the reverse edge to every edge at the beginning. - * If the number of vertices of the input graph is known, it should be provided as a parameter - * to speed up computation. Otherwise, the algorithm will first execute a job to count the vertices. * <p> * * @see <a href="https://en.wikipedia.org/wiki/HITS_algorithm">HITS Algorithm</a> @@ -54,7 +52,6 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS private final static double MINIMUMTHRESHOLD = 1e-9; private int maxIterations; - private long numberOfVertices; private double convergeThreshold; /** @@ -76,26 +73,6 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS } /** - * Create an instance of HITS algorithm. - * - * @param maxIterations the maximum number of iterations - * @param numberOfVertices the number of vertices in the graph - */ - public HITSAlgorithm(int maxIterations, long numberOfVertices) { - this(maxIterations, MINIMUMTHRESHOLD, numberOfVertices); - } - - /** - * Create an instance of HITS algorithm. - * - * @param convergeThreshold convergence threshold for sum of scores to control whether the iteration should be stopped - * @param numberOfVertices the number of vertices in the graph - */ - public HITSAlgorithm(double convergeThreshold, long numberOfVertices) { - this(MAXIMUMITERATION, convergeThreshold, numberOfVertices); - } - - /** * Creates an instance of HITS algorithm. * * @param maxIterations the maximum number of iterations @@ -108,26 +85,8 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS this.convergeThreshold = convergeThreshold; } - /** - * Creates an instance of HITS algorithm. - * - * @param maxIterations the maximum number of iterations - * @param convergeThreshold convergence threshold for sum of scores to control whether the iteration should be stopped - * @param numberOfVertices the number of vertices in the graph - */ - public HITSAlgorithm(int maxIterations, double convergeThreshold, long numberOfVertices) { - this(maxIterations, convergeThreshold); - Preconditions.checkArgument(numberOfVertices > 0, "Number of vertices must be greater than zero."); - this.numberOfVertices = numberOfVertices; - } - @Override public DataSet<Vertex<K, Tuple2<DoubleValue, DoubleValue>>> run(Graph<K, VV, EV> graph) throws Exception { - - if (numberOfVertices == 0) { - numberOfVertices = graph.numberOfVertices(); - } - Graph<K, Tuple2<DoubleValue, DoubleValue>, Boolean> newGraph = graph .mapEdges(new AuthorityEdgeMapper<K, EV>()) .union(graph.reverse().mapEdges(new HubEdgeMapper<K, EV>())) @@ -135,12 +94,13 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS ScatterGatherConfiguration parameter = new ScatterGatherConfiguration(); parameter.setDirection(EdgeDirection.OUT); + parameter.setOptNumVertices(true); parameter.registerAggregator("updatedValueSum", new DoubleSumAggregator()); parameter.registerAggregator("authorityValueSum", new DoubleSumAggregator()); parameter.registerAggregator("diffValueSum", new DoubleSumAggregator()); return newGraph - .runScatterGatherIteration(new VertexUpdate<K>(maxIterations, convergeThreshold, numberOfVertices), + .runScatterGatherIteration(new VertexUpdate<K>(maxIterations, convergeThreshold), new MessageUpdate<K>(maxIterations), maxIterations, parameter) .getVertices(); } @@ -153,15 +113,13 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS public static final class VertexUpdate<K> extends VertexUpdateFunction<K, Tuple2<DoubleValue, DoubleValue>, Double> { private int maxIteration; private double convergeThreshold; - private long numberOfVertices; private DoubleSumAggregator updatedValueSumAggregator; private DoubleSumAggregator authoritySumAggregator; private DoubleSumAggregator diffSumAggregator; - public VertexUpdate(int maxIteration, double convergeThreshold, long numberOfVertices) { + public VertexUpdate(int maxIteration, double convergeThreshold) { this.maxIteration = maxIteration; this.convergeThreshold = convergeThreshold; - this.numberOfVertices = numberOfVertices; } @Override @@ -198,9 +156,9 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS //in the first iteration, the diff is the authority value of each vertex double previousAuthAverage = 1.0; - double diffValueSum = 1.0 * numberOfVertices; + double diffValueSum = 1.0 * getNumberOfVertices(); if (getSuperstepNumber() > 1) { - previousAuthAverage = ((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / numberOfVertices; + previousAuthAverage = ((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / getNumberOfVertices(); diffValueSum = ((DoubleValue) getPreviousIterationAggregate("diffValueSum")).getValue(); } authoritySumAggregator.aggregate(previousAuthAverage); @@ -218,7 +176,7 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS newHubValue.setValue(updateValue); newAuthorityValue.setValue(newAuthorityValue.getValue() / iterationValueSum); authoritySumAggregator.aggregate(newAuthorityValue.getValue()); - double previousAuthAverage = ((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / numberOfVertices; + double previousAuthAverage = ((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / getNumberOfVertices(); // count the diff value of sum of authority scores diffSumAggregator.aggregate((previousAuthAverage - newAuthorityValue.getValue())); http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java index 9890a7c..f83b05b 100644 --- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java +++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java @@ -27,27 +27,23 @@ import org.apache.flink.graph.GraphAlgorithm; import org.apache.flink.graph.Vertex; import org.apache.flink.graph.spargel.MessageIterator; import org.apache.flink.graph.spargel.MessagingFunction; +import org.apache.flink.graph.spargel.ScatterGatherConfiguration; import org.apache.flink.graph.spargel.VertexUpdateFunction; /** * This is an implementation of a simple PageRank algorithm, using a scatter-gather iteration. * The user can define the damping factor and the maximum number of iterations. - * If the number of vertices of the input graph is known, it should be provided as a parameter - * to speed up computation. Otherwise, the algorithm will first execute a job to count the vertices. - * + * * The implementation assumes that each page has at least one incoming and one outgoing link. */ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Vertex<K, Double>>> { private double beta; private int maxIterations; - private long numberOfVertices; /** * Creates an instance of the PageRank algorithm. - * If the number of vertices of the input graph is known, - * use the {@link PageRank#PageRank(double, long, int)} constructor instead. - * + * * The implementation assumes that each page has at least one incoming and one outgoing link. * * @param beta the damping factor @@ -56,40 +52,21 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve public PageRank(double beta, int maxIterations) { this.beta = beta; this.maxIterations = maxIterations; - this.numberOfVertices = 0; - } - - /** - * Creates an instance of the PageRank algorithm. - * If the number of vertices of the input graph is known, - * use the {@link PageRank#PageRank(double, int)} constructor instead. - * - * The implementation assumes that each page has at least one incoming and one outgoing link. - * - * @param beta the damping factor - * @param maxIterations the maximum number of iterations - * @param numVertices the number of vertices in the input - */ - public PageRank(double beta, long numVertices, int maxIterations) { - this.beta = beta; - this.maxIterations = maxIterations; - this.numberOfVertices = numVertices; } @Override public DataSet<Vertex<K, Double>> run(Graph<K, Double, Double> network) throws Exception { - if (numberOfVertices == 0) { - numberOfVertices = network.numberOfVertices(); - } - DataSet<Tuple2<K, Long>> vertexOutDegrees = network.outDegrees(); Graph<K, Double, Double> networkWithWeights = network .joinWithEdgesOnSource(vertexOutDegrees, new InitWeights()); - return networkWithWeights.runScatterGatherIteration(new VertexRankUpdater<K>(beta, numberOfVertices), - new RankMessenger<K>(numberOfVertices), maxIterations) + ScatterGatherConfiguration parameters = new ScatterGatherConfiguration(); + parameters.setOptNumVertices(true); + + return networkWithWeights.runScatterGatherIteration(new VertexRankUpdater<K>(beta), + new RankMessenger<K>(), maxIterations, parameters) .getVertices(); } @@ -101,11 +78,9 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve public static final class VertexRankUpdater<K> extends VertexUpdateFunction<K, Double, Double> { private final double beta; - private final long numVertices; - - public VertexRankUpdater(double beta, long numberOfVertices) { + + public VertexRankUpdater(double beta) { this.beta = beta; - this.numVertices = numberOfVertices; } @Override @@ -116,7 +91,7 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve } // apply the dampening factor / random jump - double newRank = (beta * rankSum) + (1 - beta) / numVertices; + double newRank = (beta * rankSum) + (1 - beta) / this.getNumberOfVertices(); setNewVertexValue(newRank); } } @@ -129,17 +104,11 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve @SuppressWarnings("serial") public static final class RankMessenger<K> extends MessagingFunction<K, Double, Double, Double> { - private final long numVertices; - - public RankMessenger(long numberOfVertices) { - this.numVertices = numberOfVertices; - } - @Override public void sendMessages(Vertex<K, Double> vertex) { if (getSuperstepNumber() == 1) { // initialize vertex ranks - vertex.setValue(new Double(1.0 / numVertices)); + vertex.setValue(1.0 / this.getNumberOfVertices()); } for (Edge<K, Double> edge : getEdges()) { http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java index 496e36d..165ef1e 100644 --- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java +++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java @@ -18,18 +18,15 @@ package org.apache.flink.graph.spargel; -import java.util.Iterator; -import java.util.Map; - import org.apache.flink.api.common.aggregators.Aggregator; import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.operators.DeltaIteration; import org.apache.flink.api.common.functions.RichCoGroupFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.operators.CoGroupOperator; import org.apache.flink.api.java.operators.CustomUnaryOperation; +import org.apache.flink.api.java.operators.DeltaIteration; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; @@ -40,9 +37,15 @@ import org.apache.flink.graph.Edge; import org.apache.flink.graph.EdgeDirection; import org.apache.flink.graph.Graph; import org.apache.flink.graph.Vertex; +import org.apache.flink.graph.utils.GraphUtils; +import org.apache.flink.types.LongValue; import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; + /** * This class represents iterative graph computations, programmed in a scatter-gather perspective. * It is a special case of <i>Bulk Synchronous Parallel</i> computation. @@ -151,11 +154,10 @@ public class ScatterGatherIteration<K, VV, Message, EV> // check whether the numVertices option is set and, if so, compute the total number of vertices // and set it within the messaging and update functions + DataSet<LongValue> numberOfVertices = null; if (this.configuration != null && this.configuration.isOptNumVertices()) { try { - long numberOfVertices = graph.numberOfVertices(); - messagingFunction.setNumberOfVertices(numberOfVertices); - updateFunction.setNumberOfVertices(numberOfVertices); + numberOfVertices = GraphUtils.count(this.initialVertices); } catch (Exception e) { e.printStackTrace(); } @@ -173,9 +175,9 @@ public class ScatterGatherIteration<K, VV, Message, EV> // check whether the degrees option is set and, if so, compute the in and the out degrees and // add them to the vertex value if(this.configuration != null && this.configuration.isOptDegrees()) { - return createResultVerticesWithDegrees(graph, messagingDirection, messageTypeInfo); + return createResultVerticesWithDegrees(graph, messagingDirection, messageTypeInfo, numberOfVertices); } else { - return createResultSimpleVertex(messagingDirection, messageTypeInfo); + return createResultSimpleVertex(messagingDirection, messageTypeInfo, numberOfVertices); } } @@ -246,6 +248,10 @@ public class ScatterGatherIteration<K, VV, Message, EV> @Override public void open(Configuration parameters) throws Exception { + if (getRuntimeContext().hasBroadcastVariable("number of vertices")) { + Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number of vertices"); + this.vertexUpdateFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue()); + } if (getIterationRuntimeContext().getSuperstepNumber() == 1) { this.vertexUpdateFunction.init(getIterationRuntimeContext()); } @@ -368,10 +374,13 @@ public class ScatterGatherIteration<K, VV, Message, EV> @Override public void open(Configuration parameters) throws Exception { + if (getRuntimeContext().hasBroadcastVariable("number of vertices")) { + Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number of vertices"); + this.messagingFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue()); + } if (getIterationRuntimeContext().getSuperstepNumber() == 1) { this.messagingFunction.init(getIterationRuntimeContext()); } - this.messagingFunction.preSuperstep(); } @@ -459,7 +468,8 @@ public class ScatterGatherIteration<K, VV, Message, EV> */ private CoGroupOperator<?, ?, Tuple2<K, Message>> buildMessagingFunction( DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration, - TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg) { + TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg, + DataSet<LongValue> numberOfVertices) { // build the messaging function (co group) CoGroupOperator<?, ?, Tuple2<K, Message>> messages; @@ -475,6 +485,9 @@ public class ScatterGatherIteration<K, VV, Message, EV> for (Tuple2<String, DataSet<?>> e : this.configuration.getMessagingBcastVars()) { messages = messages.withBroadcastSet(e.f1, e.f0); } + if (this.configuration.isOptNumVertices()) { + messages = messages.withBroadcastSet(numberOfVertices, "number of vertices"); + } } return messages; @@ -493,7 +506,8 @@ public class ScatterGatherIteration<K, VV, Message, EV> */ private CoGroupOperator<?, ?, Tuple2<K, Message>> buildMessagingFunctionVerticesWithDegrees( DeltaIteration<Vertex<K, Tuple3<VV, Long, Long>>, Vertex<K, Tuple3<VV, Long, Long>>> iteration, - TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg) { + TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg, + DataSet<LongValue> numberOfVertices) { // build the messaging function (co group) CoGroupOperator<?, ?, Tuple2<K, Message>> messages; @@ -510,6 +524,9 @@ public class ScatterGatherIteration<K, VV, Message, EV> for (Tuple2<String, DataSet<?>> e : this.configuration.getMessagingBcastVars()) { messages = messages.withBroadcastSet(e.f1, e.f0); } + if (this.configuration.isOptNumVertices()) { + messages = messages.withBroadcastSet(numberOfVertices, "number of vertices"); + } } return messages; @@ -546,10 +563,11 @@ public class ScatterGatherIteration<K, VV, Message, EV> * * @param messagingDirection * @param messageTypeInfo + * @param numberOfVertices * @return the operator */ private DataSet<Vertex<K, VV>> createResultSimpleVertex(EdgeDirection messagingDirection, - TypeInformation<Tuple2<K, Message>> messageTypeInfo) { + TypeInformation<Tuple2<K, Message>> messageTypeInfo, DataSet<LongValue> numberOfVertices) { DataSet<Tuple2<K, Message>> messages; @@ -561,14 +579,14 @@ public class ScatterGatherIteration<K, VV, Message, EV> switch (messagingDirection) { case IN: - messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0); + messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0, numberOfVertices); break; case OUT: - messages = buildMessagingFunction(iteration, messageTypeInfo, 0, 0); + messages = buildMessagingFunction(iteration, messageTypeInfo, 0, 0, numberOfVertices); break; case ALL: - messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0) - .union(buildMessagingFunction(iteration, messageTypeInfo, 0, 0)) ; + messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0, numberOfVertices) + .union(buildMessagingFunction(iteration, messageTypeInfo, 0, 0, numberOfVertices)) ; break; default: throw new IllegalArgumentException("Illegal edge direction"); @@ -581,6 +599,10 @@ public class ScatterGatherIteration<K, VV, Message, EV> CoGroupOperator<?, ?, Vertex<K, VV>> updates = messages.coGroup(iteration.getSolutionSet()).where(0).equalTo(0).with(updateUdf); + if (this.configuration != null && this.configuration.isOptNumVertices()) { + updates = updates.withBroadcastSet(numberOfVertices, "number of vertices"); + } + configureUpdateFunction(updates); return iteration.closeWith(updates, updates); @@ -593,11 +615,12 @@ public class ScatterGatherIteration<K, VV, Message, EV> * @param graph * @param messagingDirection * @param messageTypeInfo + * @param numberOfVertices * @return the operator */ @SuppressWarnings("serial") private DataSet<Vertex<K, VV>> createResultVerticesWithDegrees(Graph<K, VV, EV> graph, EdgeDirection messagingDirection, - TypeInformation<Tuple2<K, Message>> messageTypeInfo) { + TypeInformation<Tuple2<K, Message>> messageTypeInfo, DataSet<LongValue> numberOfVertices) { DataSet<Tuple2<K, Message>> messages; @@ -636,14 +659,14 @@ public class ScatterGatherIteration<K, VV, Message, EV> switch (messagingDirection) { case IN: - messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0); + messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0, numberOfVertices); break; case OUT: - messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0); + messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0, numberOfVertices); break; case ALL: - messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0) - .union(buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0)) ; + messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0, numberOfVertices) + .union(buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0, numberOfVertices)) ; break; default: throw new IllegalArgumentException("Illegal edge direction"); @@ -657,6 +680,10 @@ public class ScatterGatherIteration<K, VV, Message, EV> CoGroupOperator<?, ?, Vertex<K, Tuple3<VV, Long, Long>>> updates = messages.coGroup(iteration.getSolutionSet()).where(0).equalTo(0).with(updateUdf); + if (this.configuration != null && this.configuration.isOptNumVertices()) { + updates = updates.withBroadcastSet(numberOfVertices, "number of vertices"); + } + configureUpdateFunction(updates); return iteration.closeWith(updates, updates).map( http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java index 009d791..264479b 100644 --- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java +++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java @@ -19,12 +19,18 @@ package org.apache.flink.graph.utils; import org.apache.flink.api.common.JobExecutionResult; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.Utils; import org.apache.flink.graph.Edge; import org.apache.flink.graph.Graph; import org.apache.flink.graph.Vertex; +import org.apache.flink.types.LongValue; import org.apache.flink.util.AbstractID; +import static org.apache.flink.api.java.typeutils.ValueTypeInfo.LONG_VALUE_TYPE_INFO; + public class GraphUtils { /** @@ -50,4 +56,56 @@ public class GraphUtils { return checksum; } + + /** + * Count the number of elements in a DataSet. + * + * @param input DataSet of elements to be counted + * @param <T> element type + * @return count + */ + public static <T> DataSet<LongValue> count(DataSet<T> input) { + return input + .map(new MapTo<T, LongValue>(new LongValue(1))) + .returns(LONG_VALUE_TYPE_INFO) + .reduce(new AddLongValue()); + } + + /** + * Map each element to a value. + * + * @param <I> input type + * @param <O> output type + */ + public static class MapTo<I, O> + implements MapFunction<I, O> { + private final O value; + + /** + * Map each element to the given object. + * + * @param value the object to emit for each element + */ + public MapTo(O value) { + this.value = value; + } + + @Override + public O map(I o) throws Exception { + return value; + } + } + + /** + * Add {@link LongValue} elements. + */ + public static class AddLongValue + implements ReduceFunction<LongValue> { + @Override + public LongValue reduce(LongValue value1, LongValue value2) + throws Exception { + value1.setValue(value1.getValue() + value2.getValue()); + return value1; + } + } }
