Repository: flink
Updated Branches:
  refs/heads/master b201f8664 -> 65545c2ed


[FLINK-3806] [gelly] Revert use of DataSet.count()

This closes #2036


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/36ad78c0
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/36ad78c0
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/36ad78c0

Branch: refs/heads/master
Commit: 36ad78c0821fdae0a69371c67602dd2a7955e4a8
Parents: b201f86
Author: Greg Hogan <[email protected]>
Authored: Wed May 25 11:06:01 2016 -0400
Committer: Greg Hogan <[email protected]>
Committed: Thu Jun 2 09:11:19 2016 -0400

----------------------------------------------------------------------
 docs/apis/batch/libs/gelly.md                   |  1 -
 .../graph/library/HITSAlgorithmITCase.java      | 28 --------
 .../flink/graph/library/PageRankITCase.java     | 15 ++--
 .../graph/gsa/GatherSumApplyIteration.java      | 31 +++++++--
 .../apache/flink/graph/library/GSAPageRank.java | 52 +++-----------
 .../flink/graph/library/HITSAlgorithm.java      | 54 ++-------------
 .../apache/flink/graph/library/PageRank.java    | 55 ++++-----------
 .../graph/spargel/ScatterGatherIteration.java   | 73 ++++++++++++++------
 .../apache/flink/graph/utils/GraphUtils.java    | 58 ++++++++++++++++
 9 files changed, 171 insertions(+), 196 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/docs/apis/batch/libs/gelly.md
----------------------------------------------------------------------
diff --git a/docs/apis/batch/libs/gelly.md b/docs/apis/batch/libs/gelly.md
index 0d3e594..aadbd44 100644
--- a/docs/apis/batch/libs/gelly.md
+++ b/docs/apis/batch/libs/gelly.md
@@ -1967,7 +1967,6 @@ The constructors take the following parameters:
 
 * `beta`: the damping factor.
 * `maxIterations`: the maximum number of iterations to run.
-* `numVertices`: the number of vertices in the input. If known beforehand, is 
it advised to provide this argument to speed up execution.
 
 ### GSA PageRank
 

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java
 
b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java
index 019b851..1887725 100644
--- 
a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java
+++ 
b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java
@@ -56,20 +56,6 @@ public class HITSAlgorithmITCase extends 
MultipleProgramsTestBase{
        }
 
        @Test
-       public void testHITSWithTenIterationsAndNumOfVertices() throws 
Exception {
-               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
-
-               Graph<Long, Double, NullValue> graph = Graph.fromDataSet(
-                               HITSData.getVertexDataSet(env),
-                               HITSData.getEdgeDataSet(env),
-                               env);
-
-               List<Vertex<Long, Tuple2<DoubleValue, DoubleValue>>> result = 
graph.run(new HITSAlgorithm<Long, Double, NullValue>(10, 5)).collect();
-               
-               compareWithDelta(result, 1e-7);
-       }
-
-       @Test
        public void testHITSWithConvergeThreshold() throws Exception {
                final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
 
@@ -83,20 +69,6 @@ public class HITSAlgorithmITCase extends 
MultipleProgramsTestBase{
                compareWithDelta(result, 1e-7);
        }
 
-       @Test
-       public void testHITSWithConvergeThresholdAndNumOfVertices() throws 
Exception {
-               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
-
-               Graph<Long, Double, NullValue> graph = Graph.fromDataSet(
-                               HITSData.getVertexDataSet(env),
-                               HITSData.getEdgeDataSet(env),
-                               env);
-
-               List<Vertex<Long, Tuple2<DoubleValue, DoubleValue>>> result = 
graph.run(new HITSAlgorithm<Long, Double, NullValue>(1e-7, 5)).collect();
-
-               compareWithDelta(result, 1e-7);
-       }
-
        private void compareWithDelta(List<Vertex<Long, Tuple2<DoubleValue, 
DoubleValue>>> result, double delta) {
 
                String resultString = "";

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java
 
b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java
index 034bcd5..e3e8f08 100644
--- 
a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java
+++ 
b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java
@@ -18,9 +18,6 @@
 
 package org.apache.flink.graph.library;
 
-import java.util.Arrays;
-import java.util.List;
-
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.graph.Graph;
@@ -32,6 +29,9 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import java.util.Arrays;
+import java.util.List;
+
 @RunWith(Parameterized.class)
 public class PageRankITCase extends MultipleProgramsTestBase {
 
@@ -72,9 +72,9 @@ public class PageRankITCase extends MultipleProgramsTestBase {
                Graph<Long, Double, Double> inputGraph = Graph.fromDataSet(
                                PageRankData.getDefaultEdgeDataSet(env), new 
InitMapper(), env);
 
-        List<Vertex<Long, Double>> result = inputGraph.run(new 
PageRank<Long>(0.85, 5, 3))
+        List<Vertex<Long, Double>> result = inputGraph.run(new 
PageRank<Long>(0.85, 3))
                        .collect();
-        
+
         compareWithDelta(result, 0.01);
        }
 
@@ -85,14 +85,13 @@ public class PageRankITCase extends 
MultipleProgramsTestBase {
                Graph<Long, Double, Double> inputGraph = Graph.fromDataSet(
                                PageRankData.getDefaultEdgeDataSet(env), new 
InitMapper(), env);
 
-        List<Vertex<Long, Double>> result = inputGraph.run(new 
GSAPageRank<Long>(0.85, 5, 3))
+        List<Vertex<Long, Double>> result = inputGraph.run(new 
GSAPageRank<Long>(0.85, 3))
                        .collect();
         
         compareWithDelta(result, 0.01);
        }
 
-       private void compareWithDelta(List<Vertex<Long, Double>> result,
-                                                                               
                                                double delta) {
+       private void compareWithDelta(List<Vertex<Long, Double>> result, double 
delta) {
 
                String resultString = "";
         for (Vertex<Long, Double> v : result) {

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java
 
b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java
index d092086..d1b12f9 100755
--- 
a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java
+++ 
b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java
@@ -41,9 +41,12 @@ import org.apache.flink.graph.Edge;
 import org.apache.flink.graph.EdgeDirection;
 import org.apache.flink.graph.Graph;
 import org.apache.flink.graph.Vertex;
+import org.apache.flink.graph.utils.GraphUtils;
+import org.apache.flink.types.LongValue;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.Preconditions;
 
+import java.util.Collection;
 import java.util.Map;
 
 /**
@@ -125,12 +128,11 @@ public class GatherSumApplyIteration<K, VV, EV, M> 
implements CustomUnaryOperati
 
                // check whether the numVertices option is set and, if so, 
compute the total number of vertices
                // and set it within the gather, sum and apply functions
+
+               DataSet<LongValue> numberOfVertices = null;
                if (this.configuration != null && 
this.configuration.isOptNumVertices()) {
                        try {
-                               long numberOfVertices = 
graph.numberOfVertices();
-                               gather.setNumberOfVertices(numberOfVertices);
-                               sum.setNumberOfVertices(numberOfVertices);
-                               apply.setNumberOfVertices(numberOfVertices);
+                               numberOfVertices = 
GraphUtils.count(this.vertexDataSet);
                        } catch (Exception e) {
                                e.printStackTrace();
                        }
@@ -203,6 +205,9 @@ public class GatherSumApplyIteration<K, VV, EV, M> 
implements CustomUnaryOperati
                        for (Tuple2<String, DataSet<?>> e : 
this.configuration.getGatherBcastVars()) {
                                gatherMapOperator = 
gatherMapOperator.withBroadcastSet(e.f1, e.f0);
                        }
+                       if (this.configuration.isOptNumVertices()) {
+                               gatherMapOperator = 
gatherMapOperator.withBroadcastSet(numberOfVertices, "number of vertices");
+                       }
                }
                DataSet<Tuple2<K, M>> gatheredSet = gatherMapOperator;
 
@@ -215,6 +220,9 @@ public class GatherSumApplyIteration<K, VV, EV, M> 
implements CustomUnaryOperati
                        for (Tuple2<String, DataSet<?>> e : 
this.configuration.getSumBcastVars()) {
                                sumReduceOperator = 
sumReduceOperator.withBroadcastSet(e.f1, e.f0);
                        }
+                       if (this.configuration.isOptNumVertices()) {
+                               sumReduceOperator = 
sumReduceOperator.withBroadcastSet(numberOfVertices, "number of vertices");
+                       }
                }
                DataSet<Tuple2<K, M>> summedSet = sumReduceOperator;
 
@@ -231,6 +239,9 @@ public class GatherSumApplyIteration<K, VV, EV, M> 
implements CustomUnaryOperati
                        for (Tuple2<String, DataSet<?>> e : 
this.configuration.getApplyBcastVars()) {
                                appliedSet = appliedSet.withBroadcastSet(e.f1, 
e.f0);
                        }
+                       if (this.configuration.isOptNumVertices()) {
+                               appliedSet = 
appliedSet.withBroadcastSet(numberOfVertices, "number of vertices");
+                       }
                }
 
                // let the operator know that we preserve the key field
@@ -289,6 +300,10 @@ public class GatherSumApplyIteration<K, VV, EV, M> 
implements CustomUnaryOperati
 
                @Override
                public void open(Configuration parameters) throws Exception {
+                       if (getRuntimeContext().hasBroadcastVariable("number of 
vertices")) {
+                               Collection<LongValue> numberOfVertices = 
getRuntimeContext().getBroadcastVariable("number of vertices");
+                               
this.gatherFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
+                       }
                        if (getIterationRuntimeContext().getSuperstepNumber() 
== 1) {
                                
this.gatherFunction.init(getIterationRuntimeContext());
                        }
@@ -327,6 +342,10 @@ public class GatherSumApplyIteration<K, VV, EV, M> 
implements CustomUnaryOperati
 
                @Override
                public void open(Configuration parameters) throws Exception {
+                       if (getRuntimeContext().hasBroadcastVariable("number of 
vertices")) {
+                               Collection<LongValue> numberOfVertices = 
getRuntimeContext().getBroadcastVariable("number of vertices");
+                               
this.sumFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
+                       }
                        if (getIterationRuntimeContext().getSuperstepNumber() 
== 1) {
                                
this.sumFunction.init(getIterationRuntimeContext());
                        }
@@ -365,6 +384,10 @@ public class GatherSumApplyIteration<K, VV, EV, M> 
implements CustomUnaryOperati
 
                @Override
                public void open(Configuration parameters) throws Exception {
+                       if (getRuntimeContext().hasBroadcastVariable("number of 
vertices")) {
+                               Collection<LongValue> numberOfVertices = 
getRuntimeContext().getBroadcastVariable("number of vertices");
+                               
this.applyFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
+                       }
                        if (getIterationRuntimeContext().getSuperstepNumber() 
== 1) {
                                
this.applyFunction.init(getIterationRuntimeContext());
                        }

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java
 
b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java
index 99624ca..324f9c3 100644
--- 
a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java
+++ 
b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java
@@ -25,6 +25,7 @@ import org.apache.flink.graph.Graph;
 import org.apache.flink.graph.GraphAlgorithm;
 import org.apache.flink.graph.Vertex;
 import org.apache.flink.graph.gsa.ApplyFunction;
+import org.apache.flink.graph.gsa.GSAConfiguration;
 import org.apache.flink.graph.gsa.GatherFunction;
 import org.apache.flink.graph.gsa.Neighbor;
 import org.apache.flink.graph.gsa.SumFunction;
@@ -32,22 +33,17 @@ import org.apache.flink.graph.gsa.SumFunction;
 /**
  * This is an implementation of a simple PageRank algorithm, using a 
gather-sum-apply iteration.
  * The user can define the damping factor and the maximum number of iterations.
- * If the number of vertices of the input graph is known, it should be 
provided as a parameter
- * to speed up computation. Otherwise, the algorithm will first execute a job 
to count the vertices.
- * 
+ *
  * The implementation assumes that each page has at least one incoming and one 
outgoing link.
  */
 public class GSAPageRank<K> implements GraphAlgorithm<K, Double, Double, 
DataSet<Vertex<K, Double>>> {
 
        private double beta;
        private int maxIterations;
-       private long numberOfVertices;
 
        /**
         * Creates an instance of the GSA PageRank algorithm.
-        * If the number of vertices of the input graph is known,
-        * use the {@link GSAPageRank#GSAPageRank(double, long, int)} 
constructor instead.
-        * 
+        *
         * The implementation assumes that each page has at least one incoming 
and one outgoing link.
         * 
         * @param beta the damping factor
@@ -58,37 +54,19 @@ public class GSAPageRank<K> implements GraphAlgorithm<K, 
Double, Double, DataSet
                this.maxIterations = maxIterations;
        }
 
-       /**
-        * Creates an instance of the GSA PageRank algorithm.
-        * If the number of vertices of the input graph is known,
-        * use the {@link GSAPageRank#GSAPageRank(double, int)} constructor 
instead.
-        * 
-        * The implementation assumes that each page has at least one incoming 
and one outgoing link.
-        * 
-        * @param beta the damping factor
-        * @param maxIterations the maximum number of iterations
-        * @param numVertices the number of vertices in the input
-        */
-       public GSAPageRank(double beta, long numVertices, int maxIterations) {
-               this.beta = beta;
-               this.numberOfVertices = numVertices;
-               this.maxIterations = maxIterations;
-       }
-
        @Override
        public DataSet<Vertex<K, Double>> run(Graph<K, Double, Double> network) 
throws Exception {
 
-               if (numberOfVertices == 0) {
-                       numberOfVertices = network.numberOfVertices();
-               }
-
                DataSet<Tuple2<K, Long>> vertexOutDegrees = 
network.outDegrees();
 
                Graph<K, Double, Double> networkWithWeights = network
                                .joinWithEdgesOnSource(vertexOutDegrees, new 
InitWeights());
 
-               return networkWithWeights.runGatherSumApplyIteration(new 
GatherRanks(numberOfVertices), new SumRanks(),
-                               new UpdateRanks<K>(beta, numberOfVertices), 
maxIterations)
+               GSAConfiguration parameters = new GSAConfiguration();
+               parameters.setOptNumVertices(true);
+
+               return networkWithWeights.runGatherSumApplyIteration(new 
GatherRanks(), new SumRanks(),
+                               new UpdateRanks<K>(beta), maxIterations, 
parameters)
                                .getVertices();
        }
 
@@ -99,18 +77,12 @@ public class GSAPageRank<K> implements GraphAlgorithm<K, 
Double, Double, DataSet
        @SuppressWarnings("serial")
        private static final class GatherRanks extends GatherFunction<Double, 
Double, Double> {
 
-               long numberOfVertices;
-
-               public GatherRanks(long numberOfVertices) {
-                       this.numberOfVertices = numberOfVertices;
-               }
-
                @Override
                public Double gather(Neighbor<Double, Double> neighbor) {
                        double neighborRank = neighbor.getNeighborValue();
 
                        if(getSuperstepNumber() == 1) {
-                               neighborRank = 1.0 / numberOfVertices;
+                               neighborRank = 1.0 / this.getNumberOfVertices();
                        }
 
                        return neighborRank * neighbor.getEdgeValue();
@@ -130,16 +102,14 @@ public class GSAPageRank<K> implements GraphAlgorithm<K, 
Double, Double, DataSet
        private static final class UpdateRanks<K> extends ApplyFunction<K, 
Double, Double> {
 
                private final double beta;
-               private final long numVertices;
 
-               public UpdateRanks(double beta, long numberOfVertices) {
+               public UpdateRanks(double beta) {
                        this.beta = beta;
-                       this.numVertices = numberOfVertices;
                }
 
                @Override
                public void apply(Double rankSum, Double currentValue) {
-                       setResult((1-beta)/numVertices + beta * rankSum);
+                       setResult((1-beta)/this.getNumberOfVertices() + beta * 
rankSum);
                }
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java
 
b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java
index 1ea367e..39e9487 100644
--- 
a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java
+++ 
b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java
@@ -42,8 +42,6 @@ import org.apache.flink.util.Preconditions;
  * represented a page that is linked by many different hubs.
  * Each vertex has a value of Tuple2 type, the first field is hub score and 
the second field is authority score.
  * The implementation sets same score to every vertex and adds the reverse 
edge to every edge at the beginning. 
- * If the number of vertices of the input graph is known, it should be 
provided as a parameter
- * to speed up computation. Otherwise, the algorithm will first execute a job 
to count the vertices.
  * <p>
  *
  * @see <a href="https://en.wikipedia.org/wiki/HITS_algorithm">HITS 
Algorithm</a>
@@ -54,7 +52,6 @@ public class HITSAlgorithm<K, VV, EV> implements 
GraphAlgorithm<K, VV, EV, DataS
        private final static double MINIMUMTHRESHOLD = 1e-9;
 
        private int maxIterations;
-       private long numberOfVertices;
        private double convergeThreshold;
 
        /**
@@ -76,26 +73,6 @@ public class HITSAlgorithm<K, VV, EV> implements 
GraphAlgorithm<K, VV, EV, DataS
        }
 
        /**
-        * Create an instance of HITS algorithm.
-        * 
-        * @param maxIterations    the maximum number of iterations
-        * @param numberOfVertices the number of vertices in the graph
-        */
-       public HITSAlgorithm(int maxIterations, long numberOfVertices) {
-               this(maxIterations, MINIMUMTHRESHOLD, numberOfVertices);
-       }
-
-       /**
-        * Create an instance of HITS algorithm.
-        * 
-        * @param convergeThreshold convergence threshold for sum of scores to 
control whether the iteration should be stopped
-        * @param numberOfVertices  the number of vertices in the graph
-        */
-       public HITSAlgorithm(double convergeThreshold, long numberOfVertices) {
-               this(MAXIMUMITERATION, convergeThreshold, numberOfVertices);
-       }
-
-       /**
         * Creates an instance of HITS algorithm.
         *
         * @param maxIterations     the maximum number of iterations
@@ -108,26 +85,8 @@ public class HITSAlgorithm<K, VV, EV> implements 
GraphAlgorithm<K, VV, EV, DataS
                this.convergeThreshold = convergeThreshold;
        }
 
-       /**
-        * Creates an instance of HITS algorithm.
-        *
-        * @param maxIterations     the maximum number of iterations
-        * @param convergeThreshold convergence threshold for sum of scores to 
control whether the iteration should be stopped
-        * @param numberOfVertices  the number of vertices in the graph
-        */
-       public HITSAlgorithm(int maxIterations, double convergeThreshold, long 
numberOfVertices) {
-               this(maxIterations, convergeThreshold);
-               Preconditions.checkArgument(numberOfVertices > 0, "Number of 
vertices must be greater than zero.");
-               this.numberOfVertices = numberOfVertices;
-       }
-
        @Override
        public DataSet<Vertex<K, Tuple2<DoubleValue, DoubleValue>>> 
run(Graph<K, VV, EV> graph) throws Exception {
-
-               if (numberOfVertices == 0) {
-                       numberOfVertices = graph.numberOfVertices();
-               }
-
                Graph<K, Tuple2<DoubleValue, DoubleValue>, Boolean> newGraph = 
graph
                                .mapEdges(new AuthorityEdgeMapper<K, EV>())
                                .union(graph.reverse().mapEdges(new 
HubEdgeMapper<K, EV>()))
@@ -135,12 +94,13 @@ public class HITSAlgorithm<K, VV, EV> implements 
GraphAlgorithm<K, VV, EV, DataS
 
                ScatterGatherConfiguration parameter = new 
ScatterGatherConfiguration();
                parameter.setDirection(EdgeDirection.OUT);
+               parameter.setOptNumVertices(true);
                parameter.registerAggregator("updatedValueSum", new 
DoubleSumAggregator());
                parameter.registerAggregator("authorityValueSum", new 
DoubleSumAggregator());
                parameter.registerAggregator("diffValueSum", new 
DoubleSumAggregator());
 
                return newGraph
-                               .runScatterGatherIteration(new 
VertexUpdate<K>(maxIterations, convergeThreshold, numberOfVertices),
+                               .runScatterGatherIteration(new 
VertexUpdate<K>(maxIterations, convergeThreshold),
                                                new 
MessageUpdate<K>(maxIterations), maxIterations, parameter)
                                .getVertices();
        }
@@ -153,15 +113,13 @@ public class HITSAlgorithm<K, VV, EV> implements 
GraphAlgorithm<K, VV, EV, DataS
        public static final class VertexUpdate<K> extends 
VertexUpdateFunction<K, Tuple2<DoubleValue, DoubleValue>, Double> {
                private int maxIteration;
                private double convergeThreshold;
-               private long numberOfVertices;
                private DoubleSumAggregator updatedValueSumAggregator;
                private DoubleSumAggregator authoritySumAggregator;
                private DoubleSumAggregator diffSumAggregator;
 
-               public VertexUpdate(int maxIteration, double convergeThreshold, 
long numberOfVertices) {
+               public VertexUpdate(int maxIteration, double convergeThreshold) 
{
                        this.maxIteration = maxIteration;
                        this.convergeThreshold = convergeThreshold;
-                       this.numberOfVertices = numberOfVertices;
                }
 
                @Override
@@ -198,9 +156,9 @@ public class HITSAlgorithm<K, VV, EV> implements 
GraphAlgorithm<K, VV, EV, DataS
 
                                        //in the first iteration, the diff is 
the authority value of each vertex
                                        double previousAuthAverage = 1.0;
-                                       double diffValueSum = 1.0 * 
numberOfVertices;
+                                       double diffValueSum = 1.0 * 
getNumberOfVertices();
                                        if (getSuperstepNumber() > 1) {
-                                               previousAuthAverage = 
((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / 
numberOfVertices;
+                                               previousAuthAverage = 
((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / 
getNumberOfVertices();
                                                diffValueSum = ((DoubleValue) 
getPreviousIterationAggregate("diffValueSum")).getValue();
                                        }
                                        
authoritySumAggregator.aggregate(previousAuthAverage);
@@ -218,7 +176,7 @@ public class HITSAlgorithm<K, VV, EV> implements 
GraphAlgorithm<K, VV, EV, DataS
                                        newHubValue.setValue(updateValue);
                                        
newAuthorityValue.setValue(newAuthorityValue.getValue() / iterationValueSum);
                                        
authoritySumAggregator.aggregate(newAuthorityValue.getValue());
-                                       double previousAuthAverage = 
((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / 
numberOfVertices;
+                                       double previousAuthAverage = 
((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / 
getNumberOfVertices();
 
                                        // count the diff value of sum of 
authority scores
                                        
diffSumAggregator.aggregate((previousAuthAverage - 
newAuthorityValue.getValue()));

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java
 
b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java
index 9890a7c..f83b05b 100644
--- 
a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java
+++ 
b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java
@@ -27,27 +27,23 @@ import org.apache.flink.graph.GraphAlgorithm;
 import org.apache.flink.graph.Vertex;
 import org.apache.flink.graph.spargel.MessageIterator;
 import org.apache.flink.graph.spargel.MessagingFunction;
+import org.apache.flink.graph.spargel.ScatterGatherConfiguration;
 import org.apache.flink.graph.spargel.VertexUpdateFunction;
 
 /**
  * This is an implementation of a simple PageRank algorithm, using a 
scatter-gather iteration.
  * The user can define the damping factor and the maximum number of iterations.
- * If the number of vertices of the input graph is known, it should be 
provided as a parameter
- * to speed up computation. Otherwise, the algorithm will first execute a job 
to count the vertices.
- * 
+ *
  * The implementation assumes that each page has at least one incoming and one 
outgoing link.
  */
 public class PageRank<K> implements GraphAlgorithm<K, Double, Double, 
DataSet<Vertex<K, Double>>> {
 
        private double beta;
        private int maxIterations;
-       private long numberOfVertices;
 
        /**
         * Creates an instance of the PageRank algorithm.
-        * If the number of vertices of the input graph is known,
-        * use the {@link PageRank#PageRank(double, long, int)} constructor 
instead.
-        * 
+        *
         * The implementation assumes that each page has at least one incoming 
and one outgoing link.
         * 
         * @param beta the damping factor
@@ -56,40 +52,21 @@ public class PageRank<K> implements GraphAlgorithm<K, 
Double, Double, DataSet<Ve
        public PageRank(double beta, int maxIterations) {
                this.beta = beta;
                this.maxIterations = maxIterations;
-               this.numberOfVertices = 0;
-       }
-
-       /**
-        * Creates an instance of the PageRank algorithm.
-        * If the number of vertices of the input graph is known,
-        * use the {@link PageRank#PageRank(double, int)} constructor instead.
-        * 
-        * The implementation assumes that each page has at least one incoming 
and one outgoing link.
-        * 
-        * @param beta the damping factor
-        * @param maxIterations the maximum number of iterations
-        * @param numVertices the number of vertices in the input
-        */
-       public PageRank(double beta, long numVertices, int maxIterations) {
-               this.beta = beta;
-               this.maxIterations = maxIterations;
-               this.numberOfVertices = numVertices;
        }
 
        @Override
        public DataSet<Vertex<K, Double>> run(Graph<K, Double, Double> network) 
throws Exception {
 
-               if (numberOfVertices == 0) {
-                       numberOfVertices = network.numberOfVertices();
-               }
-
                DataSet<Tuple2<K, Long>> vertexOutDegrees = 
network.outDegrees();
 
                Graph<K, Double, Double> networkWithWeights = network
                                .joinWithEdgesOnSource(vertexOutDegrees, new 
InitWeights());
 
-               return networkWithWeights.runScatterGatherIteration(new 
VertexRankUpdater<K>(beta, numberOfVertices),
-                               new RankMessenger<K>(numberOfVertices), 
maxIterations)
+               ScatterGatherConfiguration parameters = new 
ScatterGatherConfiguration();
+               parameters.setOptNumVertices(true);
+
+               return networkWithWeights.runScatterGatherIteration(new 
VertexRankUpdater<K>(beta),
+                               new RankMessenger<K>(), maxIterations, 
parameters)
                                .getVertices();
        }
 
@@ -101,11 +78,9 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve
        public static final class VertexRankUpdater<K> extends VertexUpdateFunction<K, Double, Double> {
 
                private final double beta;
-               private final long numVertices;
-               
-               public VertexRankUpdater(double beta, long numberOfVertices) {
+
+               public VertexRankUpdater(double beta) {
                        this.beta = beta;
-                       this.numVertices = numberOfVertices;
                }
 
                @Override
@@ -116,7 +91,7 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve
                        }
 
                        // apply the dampening factor / random jump
-                       double newRank = (beta * rankSum) + (1 - beta) / numVertices;
+                       double newRank = (beta * rankSum) + (1 - beta) / this.getNumberOfVertices();
                        setNewVertexValue(newRank);
                }
        }
@@ -129,17 +104,11 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve
        @SuppressWarnings("serial")
        public static final class RankMessenger<K> extends MessagingFunction<K, Double, Double, Double> {
 
-               private final long numVertices;
-
-               public RankMessenger(long numberOfVertices) {
-                       this.numVertices = numberOfVertices;
-               }
-
                @Override
                public void sendMessages(Vertex<K, Double> vertex) {
                        if (getSuperstepNumber() == 1) {
                                // initialize vertex ranks
-                               vertex.setValue(new Double(1.0 / numVertices));
+                               vertex.setValue(1.0 / this.getNumberOfVertices());
                        }
 
                        for (Edge<K, Double> edge : getEdges()) {

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java
index 496e36d..165ef1e 100644
--- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java
+++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java
@@ -18,18 +18,15 @@
 
 package org.apache.flink.graph.spargel;
 
-import java.util.Iterator;
-import java.util.Map;
-
 import org.apache.flink.api.common.aggregators.Aggregator;
 import org.apache.flink.api.common.functions.FlatJoinFunction;
 import org.apache.flink.api.common.functions.MapFunction;
-import org.apache.flink.api.java.DataSet;
-import org.apache.flink.api.java.operators.DeltaIteration;
 import org.apache.flink.api.common.functions.RichCoGroupFunction;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.operators.CoGroupOperator;
 import org.apache.flink.api.java.operators.CustomUnaryOperation;
+import org.apache.flink.api.java.operators.DeltaIteration;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
@@ -40,9 +37,15 @@ import org.apache.flink.graph.Edge;
 import org.apache.flink.graph.EdgeDirection;
 import org.apache.flink.graph.Graph;
 import org.apache.flink.graph.Vertex;
+import org.apache.flink.graph.utils.GraphUtils;
+import org.apache.flink.types.LongValue;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.Preconditions;
 
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+
 /**
  * This class represents iterative graph computations, programmed in a scatter-gather perspective.
  * It is a special case of <i>Bulk Synchronous Parallel</i> computation.
@@ -151,11 +154,10 @@ public class ScatterGatherIteration<K, VV, Message, EV>
                // check whether the numVertices option is set and, if so, compute the total number of vertices
                // and set it within the messaging and update functions
 
+               DataSet<LongValue> numberOfVertices = null;
                if (this.configuration != null && this.configuration.isOptNumVertices()) {
                        try {
-                               long numberOfVertices = graph.numberOfVertices();
-                               messagingFunction.setNumberOfVertices(numberOfVertices);
-                               updateFunction.setNumberOfVertices(numberOfVertices);
+                               numberOfVertices = GraphUtils.count(this.initialVertices);
                        } catch (Exception e) {
                                e.printStackTrace();
                        }
@@ -173,9 +175,9 @@ public class ScatterGatherIteration<K, VV, Message, EV>
                // check whether the degrees option is set and, if so, compute the in and the out degrees and
                // add them to the vertex value
                if(this.configuration != null && this.configuration.isOptDegrees()) {
-                       return createResultVerticesWithDegrees(graph, messagingDirection, messageTypeInfo);
+                       return createResultVerticesWithDegrees(graph, messagingDirection, messageTypeInfo, numberOfVertices);
                } else {
-                       return createResultSimpleVertex(messagingDirection, messageTypeInfo);
+                       return createResultSimpleVertex(messagingDirection, messageTypeInfo, numberOfVertices);
                }
        }
 
@@ -246,6 +248,10 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 
                @Override
                public void open(Configuration parameters) throws Exception {
+                       if (getRuntimeContext().hasBroadcastVariable("number of vertices")) {
+                               Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number of vertices");
+                               this.vertexUpdateFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
+                       }
                        if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                                this.vertexUpdateFunction.init(getIterationRuntimeContext());
                        }
@@ -368,10 +374,13 @@ public class ScatterGatherIteration<K, VV, Message, EV>
                
                @Override
                public void open(Configuration parameters) throws Exception {
+                       if (getRuntimeContext().hasBroadcastVariable("number of vertices")) {
+                               Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number of vertices");
+                               this.messagingFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
+                       }
                        if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                                this.messagingFunction.init(getIterationRuntimeContext());
                        }
-                       
                        this.messagingFunction.preSuperstep();
                }
                
@@ -459,7 +468,8 @@ public class ScatterGatherIteration<K, VV, Message, EV>
         */
        private CoGroupOperator<?, ?, Tuple2<K, Message>> buildMessagingFunction(
                        DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration,
-                       TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg) {
+                       TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg,
+                       DataSet<LongValue> numberOfVertices) {
 
                // build the messaging function (co group)
                CoGroupOperator<?, ?, Tuple2<K, Message>> messages;
@@ -475,6 +485,9 @@ public class ScatterGatherIteration<K, VV, Message, EV>
                        for (Tuple2<String, DataSet<?>> e : this.configuration.getMessagingBcastVars()) {
                                messages = messages.withBroadcastSet(e.f1, e.f0);
                        }
+                       if (this.configuration.isOptNumVertices()) {
+                               messages = messages.withBroadcastSet(numberOfVertices, "number of vertices");
+                       }
                }
 
                return messages;
@@ -493,7 +506,8 @@ public class ScatterGatherIteration<K, VV, Message, EV>
         */
        private CoGroupOperator<?, ?, Tuple2<K, Message>> buildMessagingFunctionVerticesWithDegrees(
                        DeltaIteration<Vertex<K, Tuple3<VV, Long, Long>>, Vertex<K, Tuple3<VV, Long, Long>>> iteration,
-                       TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg) {
+                       TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg,
+                       DataSet<LongValue> numberOfVertices) {
 
                // build the messaging function (co group)
                CoGroupOperator<?, ?, Tuple2<K, Message>> messages;
@@ -510,6 +524,9 @@ public class ScatterGatherIteration<K, VV, Message, EV>
                        for (Tuple2<String, DataSet<?>> e : this.configuration.getMessagingBcastVars()) {
                                messages = messages.withBroadcastSet(e.f1, e.f0);
                        }
+                       if (this.configuration.isOptNumVertices()) {
+                               messages = messages.withBroadcastSet(numberOfVertices, "number of vertices");
+                       }
                }
 
                return messages;
@@ -546,10 +563,11 @@ public class ScatterGatherIteration<K, VV, Message, EV>
         *
         * @param messagingDirection
         * @param messageTypeInfo
+        * @param numberOfVertices
         * @return the operator
         */
        private DataSet<Vertex<K, VV>> createResultSimpleVertex(EdgeDirection messagingDirection,
-               TypeInformation<Tuple2<K, Message>> messageTypeInfo) {
+               TypeInformation<Tuple2<K, Message>> messageTypeInfo, DataSet<LongValue> numberOfVertices) {
 
                DataSet<Tuple2<K, Message>> messages;
 
@@ -561,14 +579,14 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 
                switch (messagingDirection) {
                        case IN:
-                               messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0);
+                               messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0, numberOfVertices);
                                break;
                        case OUT:
-                               messages = buildMessagingFunction(iteration, messageTypeInfo, 0, 0);
+                               messages = buildMessagingFunction(iteration, messageTypeInfo, 0, 0, numberOfVertices);
                                break;
                        case ALL:
-                               messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0)
-                                               .union(buildMessagingFunction(iteration, messageTypeInfo, 0, 0)) ;
+                               messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0, numberOfVertices)
+                                               .union(buildMessagingFunction(iteration, messageTypeInfo, 0, 0, numberOfVertices)) ;
                                break;
                        default:
                                throw new IllegalArgumentException("Illegal edge direction");
@@ -581,6 +599,10 @@ public class ScatterGatherIteration<K, VV, Message, EV>
                CoGroupOperator<?, ?, Vertex<K, VV>> updates =
                                messages.coGroup(iteration.getSolutionSet()).where(0).equalTo(0).with(updateUdf);
 
+               if (this.configuration != null && this.configuration.isOptNumVertices()) {
+                       updates = updates.withBroadcastSet(numberOfVertices, "number of vertices");
+               }
+
                configureUpdateFunction(updates);
 
                return iteration.closeWith(updates, updates);
@@ -593,11 +615,12 @@ public class ScatterGatherIteration<K, VV, Message, EV>
         * @param graph
         * @param messagingDirection
         * @param messageTypeInfo
+        * @param numberOfVertices
         * @return the operator
         */
        @SuppressWarnings("serial")
        private DataSet<Vertex<K, VV>> createResultVerticesWithDegrees(Graph<K, VV, EV> graph, EdgeDirection messagingDirection,
-                       TypeInformation<Tuple2<K, Message>> messageTypeInfo) {
+                       TypeInformation<Tuple2<K, Message>> messageTypeInfo, DataSet<LongValue> numberOfVertices) {
 
                DataSet<Tuple2<K, Message>> messages;
 
@@ -636,14 +659,14 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 
                switch (messagingDirection) {
                        case IN:
-                               messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0);
+                               messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0, numberOfVertices);
                                break;
                        case OUT:
-                               messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0);
+                               messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0, numberOfVertices);
                                break;
                        case ALL:
-                               messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0)
-                                               .union(buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0)) ;
+                               messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0, numberOfVertices)
+                                               .union(buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0, numberOfVertices)) ;
                                break;
                        default:
                                throw new IllegalArgumentException("Illegal edge direction");
@@ -657,6 +680,10 @@ public class ScatterGatherIteration<K, VV, Message, EV>
                CoGroupOperator<?, ?, Vertex<K, Tuple3<VV, Long, Long>>> updates =
                                messages.coGroup(iteration.getSolutionSet()).where(0).equalTo(0).with(updateUdf);
 
+               if (this.configuration != null && this.configuration.isOptNumVertices()) {
+                       updates = updates.withBroadcastSet(numberOfVertices, "number of vertices");
+               }
+
                configureUpdateFunction(updates);
 
                return iteration.closeWith(updates, updates).map(

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java
index 009d791..264479b 100644
--- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java
+++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java
@@ -19,12 +19,18 @@
 package org.apache.flink.graph.utils;
 
 import org.apache.flink.api.common.JobExecutionResult;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.Utils;
 import org.apache.flink.graph.Edge;
 import org.apache.flink.graph.Graph;
 import org.apache.flink.graph.Vertex;
+import org.apache.flink.types.LongValue;
 import org.apache.flink.util.AbstractID;
 
+import static org.apache.flink.api.java.typeutils.ValueTypeInfo.LONG_VALUE_TYPE_INFO;
+
 public class GraphUtils {
 
        /**
@@ -50,4 +56,56 @@ public class GraphUtils {
 
                return checksum;
        }
+
+       /**
+        * Count the number of elements in a DataSet.
+        *
+        * @param input DataSet of elements to be counted
+        * @param <T> element type
+        * @return count
+        */
+       public static <T> DataSet<LongValue> count(DataSet<T> input) {
+               return input
+                       .map(new MapTo<T, LongValue>(new LongValue(1)))
+                               .returns(LONG_VALUE_TYPE_INFO)
+                       .reduce(new AddLongValue());
+       }
+
+       /**
+        * Map each element to a value.
+        *
+        * @param <I> input type
+        * @param <O> output type
+        */
+       public static class MapTo<I, O>
+       implements MapFunction<I, O> {
+               private final O value;
+
+               /**
+                * Map each element to the given object.
+                *
+                * @param value the object to emit for each element
+                */
+               public MapTo(O value) {
+                       this.value = value;
+               }
+
+               @Override
+               public O map(I o) throws Exception {
+                       return value;
+               }
+       }
+
+       /**
+        * Add {@link LongValue} elements.
+        */
+       public static class AddLongValue
+       implements ReduceFunction<LongValue> {
+               @Override
+               public LongValue reduce(LongValue value1, LongValue value2)
+                               throws Exception {
+                       value1.setValue(value1.getValue() + value2.getValue());
+                       return value1;
+               }
+       }
 }

Reply via email to