http://git-wip-us.apache.org/repos/asf/flink/blob/cc1a7979/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelCompilerTest.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelCompilerTest.java b/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelCompilerTest.java
new file mode 100644
index 0000000..04f0ca4
--- /dev/null
+++ b/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelCompilerTest.java
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.graph.pregel;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.operators.util.FieldList;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.graph.Edge;
+import org.apache.flink.graph.Graph;
+import org.apache.flink.graph.Vertex;
+import org.apache.flink.graph.utils.Tuple2ToVertexMap;
+import org.apache.flink.optimizer.plan.DualInputPlanNode;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.PlanNode;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.plan.WorksetIterationPlanNode;
+import org.apache.flink.optimizer.util.CompilerTestBase;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.apache.flink.types.NullValue;
+import org.junit.Test;
+
+public class PregelCompilerTest extends CompilerTestBase {
+
+       private static final long serialVersionUID = 1L;
+
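+       // Checks that a Pregel (vertex-centric) iteration compiles into a workset
+       // iteration: the solution set delta and the computation coGroup use forwarded
+       // local strategies, while the workset input is hash-partitioned and cached.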
+       @SuppressWarnings("serial")
+       @Test
+       public void testPregelCompiler() {
+               try {
+                       ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+                       env.setParallelism(DEFAULT_PARALLELISM);
+                       // compose test program
+                       {
+                               DataSet<Vertex<Long, Long>> initialVertices = env.fromElements(
+                                               new Tuple2<>(1L, 1L), new Tuple2<>(2L, 2L))
+                                               .map(new Tuple2ToVertexMap<Long, Long>());
+
+                               DataSet<Edge<Long, NullValue>> edges = env.fromElements(new Tuple2<>(1L, 2L))
+                                       .map(new MapFunction<Tuple2<Long, Long>, Edge<Long, NullValue>>() {
+
+                                               public Edge<Long, NullValue> map(Tuple2<Long, Long> edge) {
+                                                       return new Edge<>(edge.f0, edge.f1, NullValue.getInstance());
+                                               }
+                               });
+
+                               Graph<Long, Long, NullValue> graph = Graph.fromDataSet(initialVertices, edges, env);
+
+                               DataSet<Vertex<Long, Long>> result = graph.runVertexCentricIteration(
+                                               new CCCompute(), null, 100).getVertices();
+
+                               result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
+                       }
+
+                       Plan p = env.createProgramPlan("Pregel Connected Components");
+                       OptimizedPlan op = compileNoStats(p);
+
+                       // check the sink
+                       SinkPlanNode sink = op.getDataSinks().iterator().next();
+                       assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+                       assertEquals(DEFAULT_PARALLELISM, sink.getParallelism());
+
+                       // check the iteration
+                       WorksetIterationPlanNode iteration = (WorksetIterationPlanNode) sink.getInput().getSource();
+                       assertEquals(DEFAULT_PARALLELISM, iteration.getParallelism());
+
+                       // check the solution set delta
+                       PlanNode ssDelta = iteration.getSolutionSetDeltaPlanNode();
+                       assertTrue(ssDelta instanceof SingleInputPlanNode);
+
+                       SingleInputPlanNode ssFlatMap = (SingleInputPlanNode) ((SingleInputPlanNode) ssDelta).getInput().getSource();
+                       assertEquals(DEFAULT_PARALLELISM, ssFlatMap.getParallelism());
+                       assertEquals(ShipStrategyType.FORWARD, ssFlatMap.getInput().getShipStrategy());
+
+                       // check the computation coGroup
+                       DualInputPlanNode computationCoGroup = (DualInputPlanNode) ssFlatMap.getInput().getSource();
+                       assertEquals(DEFAULT_PARALLELISM, computationCoGroup.getParallelism());
+                       assertEquals(ShipStrategyType.FORWARD, computationCoGroup.getInput1().getShipStrategy());
+                       assertEquals(ShipStrategyType.PARTITION_HASH, computationCoGroup.getInput2().getShipStrategy());
+                       assertTrue(computationCoGroup.getInput2().getTempMode().isCached());
+
+                       assertEquals(new FieldList(0), computationCoGroup.getInput2().getShipStrategyKeys());
+
+                       // check that the initial partitioning is pushed out of the loop
+                       assertEquals(ShipStrategyType.PARTITION_HASH, iteration.getInput1().getShipStrategy());
+                       assertEquals(new FieldList(0), iteration.getInput1().getShipStrategyKeys());
+               }
+               catch (Exception e) {
+                       System.err.println(e.getMessage());
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+       
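+       // Same plan checks as above, with a broadcast set attached through the
+       // VertexCentricConfiguration; broadcasting the variable should not change
+       // the shape of the optimized plan.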
+       @SuppressWarnings("serial")
+       @Test
+       public void testPregelCompilerWithBroadcastVariable() {
+               try {
+                       final String BC_VAR_NAME = "borat variable";
+
+                       ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+                       env.setParallelism(DEFAULT_PARALLELISM);
+                       // compose test program
+                       {
+                               DataSet<Long> bcVar = env.fromElements(1L);
+
+                               DataSet<Vertex<Long, Long>> initialVertices = env.fromElements(
+                                               new Tuple2<>(1L, 1L), new Tuple2<>(2L, 2L))
+                                               .map(new Tuple2ToVertexMap<Long, Long>());
+
+                               DataSet<Edge<Long, NullValue>> edges = env.fromElements(new Tuple2<>(1L, 2L))
+                                       .map(new MapFunction<Tuple2<Long, Long>, Edge<Long, NullValue>>() {
+
+                                               public Edge<Long, NullValue> map(Tuple2<Long, Long> edge) {
+                                                       return new Edge<>(edge.f0, edge.f1, NullValue.getInstance());
+                                               }
+                               });
+
+                               Graph<Long, Long, NullValue> graph = Graph.fromDataSet(initialVertices, edges, env);
+
+                               VertexCentricConfiguration parameters = new VertexCentricConfiguration();
+                               parameters.addBroadcastSet(BC_VAR_NAME, bcVar);
+
+                               DataSet<Vertex<Long, Long>> result = graph.runVertexCentricIteration(
+                                               new CCCompute(), null, 100, parameters)
+                                               .getVertices();
+
+                               result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
+                       }
+
+                       Plan p = env.createProgramPlan("Pregel Connected Components");
+                       OptimizedPlan op = compileNoStats(p);
+
+                       // check the sink
+                       SinkPlanNode sink = op.getDataSinks().iterator().next();
+                       assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+                       assertEquals(DEFAULT_PARALLELISM, sink.getParallelism());
+
+                       // check the iteration
+                       WorksetIterationPlanNode iteration = (WorksetIterationPlanNode) sink.getInput().getSource();
+                       assertEquals(DEFAULT_PARALLELISM, iteration.getParallelism());
+
+                       // check the solution set delta
+                       PlanNode ssDelta = iteration.getSolutionSetDeltaPlanNode();
+                       assertTrue(ssDelta instanceof SingleInputPlanNode);
+
+                       SingleInputPlanNode ssFlatMap = (SingleInputPlanNode) ((SingleInputPlanNode) ssDelta).getInput().getSource();
+                       assertEquals(DEFAULT_PARALLELISM, ssFlatMap.getParallelism());
+                       assertEquals(ShipStrategyType.FORWARD, ssFlatMap.getInput().getShipStrategy());
+
+                       // check the computation coGroup
+                       DualInputPlanNode computationCoGroup = (DualInputPlanNode) ssFlatMap.getInput().getSource();
+                       assertEquals(DEFAULT_PARALLELISM, computationCoGroup.getParallelism());
+                       assertEquals(ShipStrategyType.FORWARD, computationCoGroup.getInput1().getShipStrategy());
+                       assertEquals(ShipStrategyType.PARTITION_HASH, computationCoGroup.getInput2().getShipStrategy());
+                       assertTrue(computationCoGroup.getInput2().getTempMode().isCached());
+
+                       assertEquals(new FieldList(0), computationCoGroup.getInput2().getShipStrategyKeys());
+
+                       // check that the initial partitioning is pushed out of the loop
+                       assertEquals(ShipStrategyType.PARTITION_HASH, iteration.getInput1().getShipStrategy());
+                       assertEquals(new FieldList(0), iteration.getInput1().getShipStrategyKeys());
+               }
+               catch (Exception e) {
+                       System.err.println(e.getMessage());
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
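+       // In addition to the checks above, verifies that the message combiner is
+       // planned on the workset input of the iteration (iteration.getInput2()).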
+       @SuppressWarnings("serial")
+       @Test
+       public void testPregelWithCombiner() {
+               try {
+                       ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+                       env.setParallelism(DEFAULT_PARALLELISM);
+                       // compose test program
+                       {
+                               DataSet<Vertex<Long, Long>> initialVertices = env.fromElements(
+                                               new Tuple2<>(1L, 1L), new Tuple2<>(2L, 2L))
+                                               .map(new Tuple2ToVertexMap<Long, Long>());
+
+                               DataSet<Edge<Long, NullValue>> edges = env.fromElements(new Tuple2<>(1L, 2L))
+                                       .map(new MapFunction<Tuple2<Long, Long>, Edge<Long, NullValue>>() {
+
+                                               public Edge<Long, NullValue> map(Tuple2<Long, Long> edge) {
+                                                       return new Edge<>(edge.f0, edge.f1, NullValue.getInstance());
+                                               }
+                               });
+
+                               Graph<Long, Long, NullValue> graph = Graph.fromDataSet(initialVertices, edges, env);
+
+                               DataSet<Vertex<Long, Long>> result = graph.runVertexCentricIteration(
+                                               new CCCompute(), new CCCombiner(), 100).getVertices();
+
+                               result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
+                       }
+
+                       Plan p = env.createProgramPlan("Pregel Connected Components");
+                       OptimizedPlan op = compileNoStats(p);
+
+                       // check the sink
+                       SinkPlanNode sink = op.getDataSinks().iterator().next();
+                       assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+                       assertEquals(DEFAULT_PARALLELISM, sink.getParallelism());
+
+                       // check the iteration
+                       WorksetIterationPlanNode iteration = (WorksetIterationPlanNode) sink.getInput().getSource();
+                       assertEquals(DEFAULT_PARALLELISM, iteration.getParallelism());
+
+                       // check the combiner
+                       SingleInputPlanNode combiner = (SingleInputPlanNode) iteration.getInput2().getSource();
+                       assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+
+                       // check the solution set delta
+                       PlanNode ssDelta = iteration.getSolutionSetDeltaPlanNode();
+                       assertTrue(ssDelta instanceof SingleInputPlanNode);
+
+                       SingleInputPlanNode ssFlatMap = (SingleInputPlanNode) ((SingleInputPlanNode) ssDelta).getInput().getSource();
+                       assertEquals(DEFAULT_PARALLELISM, ssFlatMap.getParallelism());
+                       assertEquals(ShipStrategyType.FORWARD, ssFlatMap.getInput().getShipStrategy());
+
+                       // check the computation coGroup
+                       DualInputPlanNode computationCoGroup = (DualInputPlanNode) ssFlatMap.getInput().getSource();
+                       assertEquals(DEFAULT_PARALLELISM, computationCoGroup.getParallelism());
+                       assertEquals(ShipStrategyType.FORWARD, computationCoGroup.getInput1().getShipStrategy());
+                       assertEquals(ShipStrategyType.PARTITION_HASH, computationCoGroup.getInput2().getShipStrategy());
+                       assertTrue(computationCoGroup.getInput2().getTempMode().isCached());
+
+                       assertEquals(new FieldList(0), computationCoGroup.getInput2().getShipStrategyKeys());
+
+                       // check that the initial partitioning is pushed out of the loop
+                       assertEquals(ShipStrategyType.PARTITION_HASH, iteration.getInput1().getShipStrategy());
+                       assertEquals(new FieldList(0), iteration.getInput1().getShipStrategyKeys());
+               }
+               catch (Exception e) {
+                       System.err.println(e.getMessage());
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
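+       // Connected-components compute function: every vertex keeps the smallest
+       // component ID it has seen and forwards updates along its out-edges.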
+       @SuppressWarnings("serial")
+       private static final class CCCompute extends ComputeFunction<Long, Long, NullValue, Long> {
+
+               @Override
+               public void compute(Vertex<Long, Long> vertex, MessageIterator<Long> messages) throws Exception {
+                       long currentComponent = vertex.getValue();
+
+                       for (Long msg : messages) {
+                               currentComponent = Math.min(currentComponent, msg);
+                       }
+
+                       if ((getSuperstepNumber() == 1) || (currentComponent < vertex.getValue())) {
+                               setNewVertexValue(currentComponent);
+                               for (Edge<Long, NullValue> edge : getEdges()) {
+                                       sendMessageTo(edge.getTarget(), currentComponent);
+                               }
+                       }
+               }
+       }
+
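+       // Pre-aggregates the messages sent to each vertex by keeping only the
+       // minimum candidate component ID, reducing the message volume.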
+       @SuppressWarnings("serial")
+       public static final class CCCombiner extends MessageCombiner<Long, Long> {
+
+               public void combineMessages(MessageIterator<Long> messages) {
+
+                       long minMessage = Long.MAX_VALUE;
+                       for (Long msg: messages) {
+                               minMessage = Math.min(minMessage, msg);
+                       }
+                       sendCombinedMessage(minMessage);
+               }
+       }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cc1a7979/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelTranslationTest.java b/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelTranslationTest.java
new file mode 100644
index 0000000..8f9552f
--- /dev/null
+++ b/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelTranslationTest.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.graph.pregel;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.flink.api.common.aggregators.LongSumAggregator;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.api.java.operators.DeltaIteration;
+import org.apache.flink.api.java.operators.DeltaIterationResultSet;
+import org.apache.flink.api.java.operators.SingleInputUdfOperator;
+import org.apache.flink.api.java.operators.TwoInputUdfOperator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.graph.Graph;
+import org.apache.flink.graph.Vertex;
+import org.apache.flink.types.NullValue;
+import org.junit.Test;
+
+@SuppressWarnings("serial")
+public class PregelTranslationTest {
+
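+       // Checks that the configuration options (iteration name, parallelism,
+       // aggregator, and broadcast set) are forwarded to the generated delta
+       // iteration and its computation operator.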
+       @Test
+       public void testTranslationPlainEdges() {
+               try {
+                       final String ITERATION_NAME = "Test Name";
+                       final String AGGREGATOR_NAME = "AggregatorName";
+                       final String BC_SET_NAME = "borat messages";
+                       final int NUM_ITERATIONS = 13;
+                       final int ITERATION_parallelism = 77;
+
+                       ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+                       DataSet<Long> bcVar = env.fromElements(1L);
+
+                       DataSet<Vertex<String, Double>> result;
+
+                       // ------------ construct the test program ------------------
+                       {
+                               DataSet<Tuple2<String, Double>> initialVertices = env.fromElements(new Tuple2<>("abc", 3.44));
+
+                               DataSet<Tuple2<String, String>> edges = env.fromElements(new Tuple2<>("a", "c"));
+
+                               Graph<String, Double, NullValue> graph = Graph.fromTupleDataSet(initialVertices,
+                                               edges.map(new MapFunction<Tuple2<String, String>, Tuple3<String, String, NullValue>>() {
+
+                                                       public Tuple3<String, String, NullValue> map(Tuple2<String, String> edge) {
+                                                               return new Tuple3<>(edge.f0, edge.f1, NullValue.getInstance());
+                                                       }
+                                               }), env);
+
+                               VertexCentricConfiguration parameters = new VertexCentricConfiguration();
+
+                               parameters.addBroadcastSet(BC_SET_NAME, bcVar);
+                               parameters.setName(ITERATION_NAME);
+                               parameters.setParallelism(ITERATION_parallelism);
+                               parameters.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator());
+
+                               result = graph.runVertexCentricIteration(new MyCompute(), null,
+                                               NUM_ITERATIONS, parameters).getVertices();
+
+                               result.output(new DiscardingOutputFormat<Vertex<String, Double>>());
+                       }
+
+                       // ------------- validate the java program ----------------
+
+                       assertTrue(result instanceof DeltaIterationResultSet);
+
+                       DeltaIterationResultSet<?, ?> resultSet = (DeltaIterationResultSet<?, ?>) result;
+                       DeltaIteration<?, ?> iteration = resultSet.getIterationHead();
+
+                       // check the basic iteration properties
+                       assertEquals(NUM_ITERATIONS, resultSet.getMaxIterations());
+                       assertArrayEquals(new int[] {0}, resultSet.getKeyPositions());
+                       assertEquals(ITERATION_parallelism, iteration.getParallelism());
+                       assertEquals(ITERATION_NAME, iteration.getName());
+
+                       assertEquals(AGGREGATOR_NAME, iteration.getAggregators().getAllRegisteredAggregators().iterator().next().getName());
+
+                       TwoInputUdfOperator<?, ?, ?, ?> computationCoGroup =
+                                       (TwoInputUdfOperator<?, ?, ?, ?>) ((SingleInputUdfOperator<?, ?, ?>) resultSet.getNextWorkset()).getInput();
+
+                       // validate that the broadcast sets are forwarded
+                       assertEquals(bcVar, computationCoGroup.getBroadcastSets().get(BC_SET_NAME));
+               }
+               catch (Exception e) {
+                       System.err.println(e.getMessage());
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+       
+       // --------------------------------------------------------------------------------------------
+
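+       // A no-op compute function; this test only inspects the generated plan.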
+       private static final class MyCompute extends ComputeFunction<String, Double, NullValue, Double> {
+
+               @Override
+               public void compute(Vertex<String, Double> vertex, MessageIterator<Double> messages) throws Exception {}
+       }
+}
