[FLINK-1201] [gelly] reduceOnNeighbors without vertex value

Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d9b46c6e
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d9b46c6e
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d9b46c6e

Branch: refs/heads/master
Commit: d9b46c6eb0fec25161bf8485bc954a3c726ca1c8
Parents: 9a9dfc7
Author: vasia <vasilikikala...@gmail.com>
Authored: Sun Dec 21 16:50:52 2014 +0100
Committer: Stephan Ewen <se...@apache.org>
Committed: Wed Feb 11 10:46:13 2015 +0100

----------------------------------------------------------------------
 .../main/java/org/apache/flink/graph/Graph.java | 100 +++++++++++----
 .../apache/flink/graph/NeighborsFunction.java   |   2 +-
 .../graph/test/TestReduceOnNeighborMethods.java | 125 ++++++++++++++++++-
 3 files changed, 201 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/d9b46c6e/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/Graph.java
----------------------------------------------------------------------
diff --git 
a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/Graph.java 
b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/Graph.java
index b5cb2ac..bc91b99 100644
--- a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/Graph.java
+++ b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/Graph.java
@@ -379,7 +379,7 @@ public class Graph<K extends Comparable<K> & Serializable, 
VV extends Serializab
                                        .groupBy(0).reduceGroup(new 
ApplyGroupReduceFunction<K, EV, T>(edgesFunction));
                case ALL:
                        return edges.flatMap(new EmitOneEdgePerNode<K, VV, 
EV>()).groupBy(0)
-                                       .reduceGroup(new 
ApplyGroupReduceFunctionOnAllEdges<K, EV, T>(edgesFunction));
+                                       .reduceGroup(new 
ApplyGroupReduceFunction<K, EV, T>(edgesFunction));
                default:
                        throw new IllegalArgumentException("Illegal edge 
direction");
                }
@@ -400,27 +400,6 @@ public class Graph<K extends Comparable<K> & Serializable, 
VV extends Serializab
                }
        }
 
-       private static final class ApplyGroupReduceFunctionOnAllEdges<K extends 
Comparable<K> & Serializable, 
-               EV extends Serializable, T> implements 
GroupReduceFunction<Tuple2<K, Edge<K, EV>>, T>,
-               ResultTypeQueryable<T> {
-       
-               private EdgesFunction<K, EV, T> function;
-
-               public ApplyGroupReduceFunctionOnAllEdges(EdgesFunction<K, EV, 
T> fun) {
-                       this.function = fun;
-               }
-
-               public void reduce(final Iterable<Tuple2<K, Edge<K, EV>>> 
keysWithEdges,
-                               Collector<T> out) throws Exception {
-                       out.collect(function.iterateEdges(keysWithEdges));
-               }
-
-               @Override
-               public TypeInformation<T> getProducedType() {
-                       return 
TypeExtractor.createTypeInfo(EdgesFunction.class, function.getClass(), 2, null, 
null);
-               }
-       }
-
        private static final class ApplyGroupReduceFunction<K extends 
Comparable<K> & Serializable, 
                EV extends Serializable, T> implements 
GroupReduceFunction<Tuple2<K, Edge<K, EV>>, T>,
                ResultTypeQueryable<T> {
@@ -1023,6 +1002,83 @@ public class Graph<K extends Comparable<K> & 
Serializable, VV extends Serializab
                }
        }
 
+       /**
+        * Compute an aggregate over the neighbors (edges and vertices) of each vertex.
+        * The function applied on the neighbors only has access to the vertex id
+        * (not the vertex value): each element it receives is a
+        * (vertex id, incident edge, neighbor vertex) triple.
+        * Because the neighborhoods are built with a join on the edge set followed by
+        * a group-by on the vertex id, a vertex without any neighbor in the given
+        * direction produces no output element.
+        * @param neighborsFunction the function to apply to the neighborhood
+        * @param direction the edge direction (in-, out-, all-)
+        * @param <T> the output type
+        * @return a dataset containing one T per vertex group, as returned by
+        *         {@link NeighborsFunction#iterateNeighbors}
+        * @throws IllegalArgumentException if {@code direction} is not IN, OUT or ALL
+        */
+       public <T> DataSet<T> reduceOnNeighbors(NeighborsFunction<K, VV, EV, T> 
neighborsFunction,
+                       EdgeDirection direction) throws 
IllegalArgumentException {
+               switch (direction) {
+               case IN:
+                       // create <edge-sourceVertex> pairs: join the edge source (field 0)
+                       // with the vertex id and key each pair by the edge target (field 1)
+                       DataSet<Tuple3<K, Edge<K, EV>, Vertex<K, VV>>> 
edgesWithSources = edges.join(this.vertices)
+                               .where(0).equalTo(0).with(new 
ProjectVertexIdJoin<K, VV, EV>(1));
+                       return edgesWithSources.groupBy(0).reduceGroup(
+                                       new ApplyNeighborGroupReduceFunction<K, 
VV, EV, T>(neighborsFunction));
+               case OUT:
+                       // create <edge-targetVertex> pairs: join the edge target (field 1)
+                       // with the vertex id and key each pair by the edge source (field 0)
+                       DataSet<Tuple3<K, Edge<K, EV>, Vertex<K, VV>>> 
edgesWithTargets = edges.join(this.vertices)
+                       .where(1).equalTo(0).with(new ProjectVertexIdJoin<K, 
VV, EV>(0));
+               return edgesWithTargets.groupBy(0).reduceGroup(
+                               new ApplyNeighborGroupReduceFunction<K, VV, EV, 
T>(neighborsFunction));
+               case ALL:
+                       // create <edge-sourceOrTargetVertex> pairs; the join uses field 1 of
+                       // the tuples emitted by EmitOneEdgeWithNeighborPerNode, presumably the
+                       // neighbor id (that class is defined outside this hunk - confirm)
+                       DataSet<Tuple3<K, Edge<K, EV>, Vertex<K, VV>>> 
edgesWithNeighbors = edges.flatMap(
+                                       new EmitOneEdgeWithNeighborPerNode<K, 
VV, EV>()).join(this.vertices)
+                                       .where(1).equalTo(0).with(new 
ProjectEdgeWithNeighbor<K, VV, EV>());
+
+                       return edgesWithNeighbors.groupBy(0).reduceGroup(
+                                       new ApplyNeighborGroupReduceFunction<K, 
VV, EV, T>(neighborsFunction));
+               default:
+                       throw new IllegalArgumentException("Illegal edge 
direction");
+               }
+       }
+
+       /**
+        * Group-reduce adapter that hands each group of
+        * (vertex id, edge, neighbor vertex) triples to a {@link NeighborsFunction}
+        * and emits the single value that function returns for the group.
+        */
+       private static final class ApplyNeighborGroupReduceFunction<K extends 
Comparable<K> & Serializable, 
+               VV extends Serializable, EV extends Serializable, T> implements 
GroupReduceFunction<
+               Tuple3<K, Edge<K, EV>, Vertex<K, VV>>, T>,      
ResultTypeQueryable<T> {
+       
+               private NeighborsFunction<K, VV, EV, T> function;
+       
+               public ApplyNeighborGroupReduceFunction(NeighborsFunction<K, 
VV, EV, T> fun) {
+                       this.function = fun;
+               }
+       
+               public void reduce(Iterable<Tuple3<K, Edge<K, EV>, Vertex<K, 
VV>>> edges,
+                               Collector<T> out) throws Exception {
+                       // exactly one output record per group (i.e. per vertex key)
+                       out.collect(function.iterateNeighbors(edges));
+                       
+               }
+
+               // Type argument index 3 of NeighborsFunction<K, VV, EV, O> is the
+               // output type O, which is extracted from the user function's class.
+               @Override
+               public TypeInformation<T> getProducedType() {
+                       return 
TypeExtractor.createTypeInfo(NeighborsFunction.class, function.getClass(), 3, 
null, null);
+               }       
+       }
+
+       /**
+        * Join function that turns an (edge, vertex) match into a
+        * (key, edge, vertex) triple, where the key is the edge field at
+        * {@code fieldPosition}. reduceOnNeighbors uses position 1 (edge target)
+        * for IN and position 0 (edge source) for OUT, so the key is the vertex
+        * whose neighborhood is being reduced and the vertex is its neighbor.
+        */
+       private static final class ProjectVertexIdJoin<K extends Comparable<K> 
& Serializable, 
+               VV extends Serializable, EV extends Serializable> implements 
FlatJoinFunction<Edge<K, EV>, 
+               Vertex<K, VV>, Tuple3<K, Edge<K, EV>, Vertex<K, VV>>> {
+
+               // index of the edge field (0 = source, 1 = target) used as the key
+               private int fieldPosition;
+
+               public ProjectVertexIdJoin(int position) {
+                       this.fieldPosition = position;
+               }
+               // the projected edge field is cast to the key type K
+               @SuppressWarnings("unchecked")
+               public void join(Edge<K, EV> edge, Vertex<K, VV> otherVertex,
+                               Collector<Tuple3<K, Edge<K, EV>, Vertex<K, 
VV>>> out) {
+                       out.collect(new Tuple3<K, Edge<K, EV>, Vertex<K, VV>>(
+                                       (K)edge.getField(fieldPosition), edge, 
otherVertex));
+               }
+       }
+
        private static final class ProjectEdgeWithNeighbor<K extends 
Comparable<K> & Serializable, 
                VV extends Serializable, EV extends Serializable> implements 
                FlatJoinFunction<Tuple3<K, K, Edge<K, EV>>, Vertex<K, VV>, 
Tuple3<K, Edge<K, EV>, Vertex<K, VV>>> {

http://git-wip-us.apache.org/repos/asf/flink/blob/d9b46c6e/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/NeighborsFunction.java
----------------------------------------------------------------------
diff --git 
a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/NeighborsFunction.java
 
b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/NeighborsFunction.java
index 63bc527..124aea0 100644
--- 
a/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/NeighborsFunction.java
+++ 
b/flink-staging/flink-gelly/src/main/java/org/apache/flink/graph/NeighborsFunction.java
@@ -8,5 +8,5 @@ import org.apache.flink.api.java.tuple.Tuple3;
 public interface NeighborsFunction<K extends Comparable<K> & Serializable, VV 
extends Serializable, 
        EV extends Serializable, O> extends Function, Serializable {
 
-       O iterateEdges(Iterable<Tuple3<K, Edge<K, EV>, Vertex<K, VV>>> 
neighbors) throws Exception;
+       /**
+        * Computes a single result from the neighborhood of one vertex.
+        * When invoked from Graph#reduceOnNeighbors, each element is a
+        * (vertex id, edge, neighbor vertex) triple, and all elements passed in
+        * one call share the same vertex id.
+        * @param neighbors the neighborhood of one vertex
+        * @return the aggregate for that vertex
+        * @throws Exception any error raised by the implementation
+        */
+       O iterateNeighbors(Iterable<Tuple3<K, Edge<K, EV>, Vertex<K, VV>>> 
neighbors) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/d9b46c6e/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/TestReduceOnNeighborMethods.java
----------------------------------------------------------------------
diff --git 
a/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/TestReduceOnNeighborMethods.java
 
b/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/TestReduceOnNeighborMethods.java
index 7a1413a..be0a867 100644
--- 
a/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/TestReduceOnNeighborMethods.java
+++ 
b/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/test/TestReduceOnNeighborMethods.java
@@ -3,11 +3,13 @@ package flink.graphs;
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.util.Collection;
+import java.util.Iterator;
 import java.util.LinkedList;
 
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.test.util.JavaProgramTestBase;
 import org.junit.runner.RunWith;
@@ -17,7 +19,7 @@ import org.junit.runners.Parameterized.Parameters;
 @RunWith(Parameterized.class)
 public class TestReduceOnNeighborMethods extends JavaProgramTestBase {
 
-       private static int NUM_PROGRAMS = 3;
+       private static int NUM_PROGRAMS = 6;
        
        private int curProgId = config.getInteger("ProgramId", -1);
        private String resultPath;
@@ -95,7 +97,7 @@ public class TestReduceOnNeighborMethods extends 
JavaProgramTestBase {
                        case 2: {
                                /*
                                 * Get the sum of in-neighbor values
-                                * time edge weights for each vertex
+                                * times the edge weights for each vertex
                         */
                                final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
                                Graph<Long, Long, Long> graph = 
Graph.create(TestGraphUtils.getLongLongVertexData(env), 
@@ -125,6 +127,7 @@ public class TestReduceOnNeighborMethods extends 
JavaProgramTestBase {
                        case 3: {
                                /*
                                 * Get the sum of all neighbor values
+                                * including own vertex value
                                 * for each vertex
                         */
                                final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
@@ -140,7 +143,123 @@ public class TestReduceOnNeighborMethods extends 
JavaProgramTestBase {
                                                                for 
(Tuple2<Edge<Long, Long>, Vertex<Long, Long>> neighbor : neighbors) {
                                                                        sum += 
neighbor.f1.getValue();
                                                                }
-                                                               return new 
Tuple2<Long, Long>(vertex.getId(), sum);
+                                                               return new 
Tuple2<Long, Long>(vertex.getId(), sum + vertex.getValue());
+                                                       }
+                                               }, EdgeDirection.ALL);
+
+                               
verticesWithSumOfOutNeighborValues.writeAsCsv(resultPath);
+                               env.execute();
+                               return "1,11\n" +
+                                               "2,6\n" + 
+                                               "3,15\n" +
+                                               "4,12\n" + 
+                                               "5,13\n";
+                       }
+                       case 4: {
+                               /*
+                                * Get the sum of out-neighbor values
+                                * for each vertex
+                        */
+                               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+                               Graph<Long, Long, Long> graph = 
Graph.create(TestGraphUtils.getLongLongVertexData(env), 
+                                               
TestGraphUtils.getLongLongEdgeData(env), env);
+
+                               DataSet<Tuple2<Long, Long>> 
verticesWithSumOfOutNeighborValues = 
+                                               graph.reduceOnNeighbors(new 
NeighborsFunction<Long, Long, Long, 
+                                                               Tuple2<Long, 
Long>>() {
+                                                       public Tuple2<Long, 
Long> iterateNeighbors(
+                                                                       
Iterable<Tuple3<Long, Edge<Long, Long>, Vertex<Long, Long>>> neighbors) {
+                                                               long sum = 0;
+                                                               Tuple3<Long, 
Edge<Long, Long>, Vertex<Long, Long>> first = 
+                                                                               
new Tuple3<Long, Edge<Long, Long>, Vertex<Long, Long>>();
+                                                               
Iterator<Tuple3<Long, Edge<Long, Long>, Vertex<Long, Long>>> neighborsIterator 
= 
+                                                                               
neighbors.iterator();
+                                                               if 
(neighborsIterator.hasNext()) {
+                                                                       first = 
neighborsIterator.next();
+                                                                       sum = 
first.f2.getValue();
+                                                               }
+                                                               
while(neighborsIterator.hasNext()) {
+                                                                       sum += 
neighborsIterator.next().f2.getValue();
+                                                               }
+                                                               return new 
Tuple2<Long, Long>(first.f0, sum);
+                                                       }
+                                               }, EdgeDirection.OUT);
+
+                               
verticesWithSumOfOutNeighborValues.writeAsCsv(resultPath);
+                               env.execute();
+                               return "1,5\n" +
+                                               "2,3\n" + 
+                                               "3,9\n" +
+                                               "4,5\n" + 
+                                               "5,1\n";
+                       }
+                       case 5: {
+                               /*
+                                * Get the sum of in-neighbor values
+                                * times the edge weights for each vertex
+                        */
+                               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+                               Graph<Long, Long, Long> graph = 
Graph.create(TestGraphUtils.getLongLongVertexData(env), 
+                                               
TestGraphUtils.getLongLongEdgeData(env), env);
+
+                               DataSet<Tuple2<Long, Long>> verticesWithSum = 
+                                               graph.reduceOnNeighbors(new 
NeighborsFunction<Long, Long, Long, 
+                                                               Tuple2<Long, 
Long>>() {
+                                                       public Tuple2<Long, 
Long> iterateNeighbors(
+                                                                       
Iterable<Tuple3<Long, Edge<Long, Long>, Vertex<Long, Long>>> neighbors) {
+                                                               long sum = 0;
+                                                               Tuple3<Long, 
Edge<Long, Long>, Vertex<Long, Long>> first = 
+                                                                               
new Tuple3<Long, Edge<Long, Long>, Vertex<Long, Long>>();
+                                                               
Iterator<Tuple3<Long, Edge<Long, Long>, Vertex<Long, Long>>> neighborsIterator 
= 
+                                                                               
neighbors.iterator();
+                                                               if 
(neighborsIterator.hasNext()) {
+                                                                       first = 
neighborsIterator.next();
+                                                                       sum = 
first.f2.getValue() * first.f1.getValue();
+                                                               }
+                                                               
while(neighborsIterator.hasNext()) {
+                                                                       
Tuple3<Long, Edge<Long, Long>, Vertex<Long, Long>> next = 
neighborsIterator.next();
+                                                                       sum += 
next.f2.getValue() * next.f1.getValue();
+                                                               }
+                                                               return new 
Tuple2<Long, Long>(first.f0, sum);
+                                                       }
+                                               }, EdgeDirection.IN);
+
+
+                               verticesWithSum.writeAsCsv(resultPath);
+                               env.execute();
+                               return "1,255\n" +
+                                               "2,12\n" + 
+                                               "3,59\n" +
+                                               "4,102\n" + 
+                                               "5,285\n";
+                       }
+                       case 6: {
+                               /*
+                                * Get the sum of all neighbor values
+                                * for each vertex
+                        */
+                               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+                               Graph<Long, Long, Long> graph = 
Graph.create(TestGraphUtils.getLongLongVertexData(env), 
+                                               
TestGraphUtils.getLongLongEdgeData(env), env);
+
+                               DataSet<Tuple2<Long, Long>> 
verticesWithSumOfOutNeighborValues = 
+                                               graph.reduceOnNeighbors(new 
NeighborsFunction<Long, Long, Long, 
+                                                               Tuple2<Long, 
Long>>() {
+                                                       public Tuple2<Long, 
Long> iterateNeighbors(
+                                                                       
Iterable<Tuple3<Long, Edge<Long, Long>, Vertex<Long, Long>>> neighbors) {
+                                                               long sum = 0;
+                                                               Tuple3<Long, 
Edge<Long, Long>, Vertex<Long, Long>> first = 
+                                                                               
new Tuple3<Long, Edge<Long, Long>, Vertex<Long, Long>>();
+                                                               
Iterator<Tuple3<Long, Edge<Long, Long>, Vertex<Long, Long>>> neighborsIterator 
= 
+                                                                               
neighbors.iterator();
+                                                               if 
(neighborsIterator.hasNext()) {
+                                                                       first = 
neighborsIterator.next();
+                                                                       sum = 
first.f2.getValue();
+                                                               }
+                                                               
while(neighborsIterator.hasNext()) {
+                                                                       sum += 
neighborsIterator.next().f2.getValue();
+                                                               }
+                                                               return new 
Tuple2<Long, Long>(first.f0, sum);
                                                        }
                                                }, EdgeDirection.ALL);
 

Reply via email to