http://git-wip-us.apache.org/repos/asf/flink/blob/cc1a7979/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelCompilerTest.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelCompilerTest.java b/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelCompilerTest.java
new file mode 100644
index 0000000..04f0ca4
--- /dev/null
+++ b/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelCompilerTest.java
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.graph.pregel;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.operators.util.FieldList;
+import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.optimizer.util.CompilerTestBase;
+import org.junit.Test;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.graph.Edge;
+import org.apache.flink.graph.Graph;
+import org.apache.flink.graph.Vertex;
+import org.apache.flink.optimizer.plan.DualInputPlanNode;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.PlanNode;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.plan.WorksetIterationPlanNode;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.apache.flink.types.NullValue;
+import org.apache.flink.graph.utils.Tuple2ToVertexMap;
+
+public class PregelCompilerTest extends CompilerTestBase {
+
+    private static final long serialVersionUID = 1L;
+
+    @SuppressWarnings("serial")
+    @Test
+    public void testPregelCompiler() {
+        try {
+            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+            env.setParallelism(DEFAULT_PARALLELISM);
+            // compose test program
+            {
+
+                DataSet<Vertex<Long, Long>> initialVertices = env.fromElements(
+                    new Tuple2<>(1L, 1L), new Tuple2<>(2L, 2L))
+                    .map(new Tuple2ToVertexMap<Long, Long>());
+
+                DataSet<Edge<Long, NullValue>> edges = env.fromElements(new Tuple2<>(1L, 2L))
+                    .map(new MapFunction<Tuple2<Long, Long>, Edge<Long, NullValue>>() {
+
+                        public Edge<Long, NullValue> map(Tuple2<Long, Long> edge) {
+                            return new Edge<>(edge.f0, edge.f1, NullValue.getInstance());
+                        }
+                    });
+
+                Graph<Long, Long, NullValue> graph = Graph.fromDataSet(initialVertices,
+                    edges, env);
+
+                DataSet<Vertex<Long, Long>> result = graph.runVertexCentricIteration(
+                    new CCCompute(), null, 100).getVertices();
+
+                result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
+            }
+
+            Plan p = env.createProgramPlan("Pregel Connected Components");
+            OptimizedPlan op = compileNoStats(p);
+
+            // check the sink
+            SinkPlanNode sink = op.getDataSinks().iterator().next();
+            assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+            assertEquals(DEFAULT_PARALLELISM, sink.getParallelism());
+
+            // check the iteration
+            WorksetIterationPlanNode iteration = (WorksetIterationPlanNode) sink.getInput().getSource();
+            assertEquals(DEFAULT_PARALLELISM, iteration.getParallelism());
+
+            // check the solution set delta
+            PlanNode ssDelta = iteration.getSolutionSetDeltaPlanNode();
+            assertTrue(ssDelta instanceof SingleInputPlanNode);
+
+            SingleInputPlanNode ssFlatMap = (SingleInputPlanNode) ((SingleInputPlanNode) (ssDelta)).getInput().getSource();
+            assertEquals(DEFAULT_PARALLELISM, ssFlatMap.getParallelism());
+            assertEquals(ShipStrategyType.FORWARD, ssFlatMap.getInput().getShipStrategy());
+
+            // check the computation coGroup
+            DualInputPlanNode computationCoGroup = (DualInputPlanNode) (ssFlatMap.getInput().getSource());
+            assertEquals(DEFAULT_PARALLELISM, computationCoGroup.getParallelism());
+            assertEquals(ShipStrategyType.FORWARD, computationCoGroup.getInput1().getShipStrategy());
+            assertEquals(ShipStrategyType.PARTITION_HASH, computationCoGroup.getInput2().getShipStrategy());
+            assertTrue(computationCoGroup.getInput2().getTempMode().isCached());
+
+            assertEquals(new FieldList(0), computationCoGroup.getInput2().getShipStrategyKeys());
+
+            // check that the initial partitioning is pushed out of the loop
+            assertEquals(ShipStrategyType.PARTITION_HASH, iteration.getInput1().getShipStrategy());
+            assertEquals(new FieldList(0), iteration.getInput1().getShipStrategyKeys());
+
+        }
+        catch (Exception e) {
+            System.err.println(e.getMessage());
+            e.printStackTrace();
+            fail(e.getMessage());
+        }
+    }
+
+    @SuppressWarnings("serial")
+    @Test
+    public void testPregelCompilerWithBroadcastVariable() {
+        try {
+            final String BC_VAR_NAME = "borat variable";
+
+
+            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+            env.setParallelism(DEFAULT_PARALLELISM);
+            // compose test program
+            {
+                DataSet<Long> bcVar = env.fromElements(1L);
+
+                DataSet<Vertex<Long, Long>> initialVertices = env.fromElements(
+                    new Tuple2<>(1L, 1L), new Tuple2<>(2L, 2L))
+                    .map(new Tuple2ToVertexMap<Long, Long>());
+
+                DataSet<Edge<Long, NullValue>> edges = env.fromElements(new Tuple2<>(1L, 2L))
+                    .map(new MapFunction<Tuple2<Long, Long>, Edge<Long, NullValue>>() {
+
+                        public Edge<Long, NullValue> map(Tuple2<Long, Long> edge) {
+                            return new Edge<>(edge.f0, edge.f1, NullValue.getInstance());
+                        }
+                    });
+
+                Graph<Long, Long, NullValue> graph = Graph.fromDataSet(initialVertices, edges, env);
+
+                VertexCentricConfiguration parameters = new VertexCentricConfiguration();
+                parameters.addBroadcastSet(BC_VAR_NAME, bcVar);
+
+                DataSet<Vertex<Long, Long>> result = graph.runVertexCentricIteration(
+                    new CCCompute(), null, 100, parameters)
+                    .getVertices();
+
+                result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
+
+            }
+
+            Plan p = env.createProgramPlan("Pregel Connected Components");
+            OptimizedPlan op = compileNoStats(p);
+
+            // check the sink
+            SinkPlanNode sink = op.getDataSinks().iterator().next();
+            assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+            assertEquals(DEFAULT_PARALLELISM, sink.getParallelism());
+
+            // check the iteration
+            WorksetIterationPlanNode iteration = (WorksetIterationPlanNode) sink.getInput().getSource();
+            assertEquals(DEFAULT_PARALLELISM, iteration.getParallelism());
+
+            // check the solution set delta
+            PlanNode ssDelta = iteration.getSolutionSetDeltaPlanNode();
+            assertTrue(ssDelta instanceof SingleInputPlanNode);
+
+            SingleInputPlanNode ssFlatMap = (SingleInputPlanNode) ((SingleInputPlanNode) (ssDelta)).getInput().getSource();
+            assertEquals(DEFAULT_PARALLELISM, ssFlatMap.getParallelism());
+            assertEquals(ShipStrategyType.FORWARD, ssFlatMap.getInput().getShipStrategy());
+
+            // check the computation coGroup
+            DualInputPlanNode computationCoGroup = (DualInputPlanNode) (ssFlatMap.getInput().getSource());
+            assertEquals(DEFAULT_PARALLELISM, computationCoGroup.getParallelism());
+            assertEquals(ShipStrategyType.FORWARD, computationCoGroup.getInput1().getShipStrategy());
+            assertEquals(ShipStrategyType.PARTITION_HASH, computationCoGroup.getInput2().getShipStrategy());
+            assertTrue(computationCoGroup.getInput2().getTempMode().isCached());
+
+            assertEquals(new FieldList(0), computationCoGroup.getInput2().getShipStrategyKeys());
+
+            // check that the initial partitioning is pushed out of the loop
+            assertEquals(ShipStrategyType.PARTITION_HASH, iteration.getInput1().getShipStrategy());
+            assertEquals(new FieldList(0), iteration.getInput1().getShipStrategyKeys());
+        }
+        catch (Exception e) {
+            System.err.println(e.getMessage());
+            e.printStackTrace();
+            fail(e.getMessage());
+        }
+    }
+
+    @SuppressWarnings("serial")
+    @Test
+    public void testPregelWithCombiner() {
+        try {
+            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+            env.setParallelism(DEFAULT_PARALLELISM);
+            // compose test program
+            {
+
+                DataSet<Vertex<Long, Long>> initialVertices = env.fromElements(
+                    new Tuple2<>(1L, 1L), new Tuple2<>(2L, 2L))
+                    .map(new Tuple2ToVertexMap<Long, Long>());
+
+                DataSet<Edge<Long, NullValue>> edges = env.fromElements(new Tuple2<>(1L, 2L))
+                    .map(new MapFunction<Tuple2<Long, Long>, Edge<Long, NullValue>>() {
+
+                        public Edge<Long, NullValue> map(Tuple2<Long, Long> edge) {
+                            return new Edge<>(edge.f0, edge.f1, NullValue.getInstance());
+                        }
+                    });
+
+                Graph<Long, Long, NullValue> graph = Graph.fromDataSet(initialVertices, edges, env);
+
+                DataSet<Vertex<Long, Long>> result = graph.runVertexCentricIteration(
+                    new CCCompute(), new CCCombiner(), 100).getVertices();
+
+                result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
+            }
+
+            Plan p = env.createProgramPlan("Pregel Connected Components");
+            OptimizedPlan op = compileNoStats(p);
+
+            // check the sink
+            SinkPlanNode sink = op.getDataSinks().iterator().next();
+            assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+            assertEquals(DEFAULT_PARALLELISM, sink.getParallelism());
+
+            // check the iteration
+            WorksetIterationPlanNode iteration = (WorksetIterationPlanNode) sink.getInput().getSource();
+            assertEquals(DEFAULT_PARALLELISM, iteration.getParallelism());
+
+            // check the combiner
+            SingleInputPlanNode combiner = (SingleInputPlanNode) iteration.getInput2().getSource();
+            assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
+
+            // check the solution set delta
+            PlanNode ssDelta = iteration.getSolutionSetDeltaPlanNode();
+            assertTrue(ssDelta instanceof SingleInputPlanNode);
+
+            SingleInputPlanNode ssFlatMap = (SingleInputPlanNode) ((SingleInputPlanNode) (ssDelta)).getInput().getSource();
+            assertEquals(DEFAULT_PARALLELISM, ssFlatMap.getParallelism());
+            assertEquals(ShipStrategyType.FORWARD, ssFlatMap.getInput().getShipStrategy());
+
+            // check the computation coGroup
+            DualInputPlanNode computationCoGroup = (DualInputPlanNode) (ssFlatMap.getInput().getSource());
+            assertEquals(DEFAULT_PARALLELISM, computationCoGroup.getParallelism());
+            assertEquals(ShipStrategyType.FORWARD, computationCoGroup.getInput1().getShipStrategy());
+            assertEquals(ShipStrategyType.PARTITION_HASH, computationCoGroup.getInput2().getShipStrategy());
+            assertTrue(computationCoGroup.getInput2().getTempMode().isCached());
+
+            assertEquals(new FieldList(0), computationCoGroup.getInput2().getShipStrategyKeys());
+
+            // check that the initial partitioning is pushed out of the loop
+            assertEquals(ShipStrategyType.PARTITION_HASH, iteration.getInput1().getShipStrategy());
+            assertEquals(new FieldList(0), iteration.getInput1().getShipStrategyKeys());
+
+        }
+        catch (Exception e) {
+            System.err.println(e.getMessage());
+            e.printStackTrace();
+            fail(e.getMessage());
+        }
+    }
+
+    @SuppressWarnings("serial")
+    private static final class CCCompute extends ComputeFunction<Long, Long, NullValue, Long> {
+
+        @Override
+        public void compute(Vertex<Long, Long> vertex, MessageIterator<Long> messages) throws Exception {
+            long currentComponent = vertex.getValue();
+
+            for (Long msg : messages) {
+                currentComponent = Math.min(currentComponent, msg);
+            }
+
+            if ((getSuperstepNumber() == 1) || (currentComponent < vertex.getValue())) {
+                setNewVertexValue(currentComponent);
+                for (Edge<Long, NullValue> edge : getEdges()) {
+                    sendMessageTo(edge.getTarget(), currentComponent);
+                }
+            }
+        }
+    }
+
+    @SuppressWarnings("serial")
+    public static final class CCCombiner extends MessageCombiner<Long, Long> {
+
+        public void combineMessages(MessageIterator<Long> messages) {
+
+            long minMessage = Long.MAX_VALUE;
+            for (Long msg : messages) {
+                minMessage = Math.min(minMessage, msg);
+            }
+            sendCombinedMessage(minMessage);
+        }
+    }
+
+}
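For context: run end to end, the CCCompute / CCCombiner pair above computes connected components, with every vertex converging to the smallest id of any vertex that can reach it over directed edges. The driver below is a minimal sketch of such a run, not part of this commit; the class name ConnectedComponentsExample is invented, and it assumes CCCompute and CCCombiner are copied out of the test as accessible classes. It uses only the API calls that already appear in PregelCompilerTest.

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.types.NullValue;

// Illustrative driver; CCCompute and CCCombiner are assumed to be copied
// from PregelCompilerTest above as accessible classes.
public class ConnectedComponentsExample {

    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // A directed chain 1 -> 2 -> 3; all three vertices form one component.
        DataSet<Vertex<Long, Long>> vertices = env.fromElements(
            new Vertex<>(1L, 1L), new Vertex<>(2L, 2L), new Vertex<>(3L, 3L));
        DataSet<Edge<Long, NullValue>> edges = env.fromElements(
            new Edge<>(1L, 2L, NullValue.getInstance()),
            new Edge<>(2L, 3L, NullValue.getInstance()));

        Graph<Long, Long, NullValue> graph = Graph.fromDataSet(vertices, edges, env);

        // Same call shape as in the tests: compute function, message combiner,
        // and the maximum number of supersteps.
        DataSet<Vertex<Long, Long>> components = graph
            .runVertexCentricIteration(new CCCompute(), new CCCombiner(), 100)
            .getVertices();

        // print() triggers execution; each vertex prints with its component id.
        components.print();
    }
}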
http://git-wip-us.apache.org/repos/asf/flink/blob/cc1a7979/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelTranslationTest.java b/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelTranslationTest.java
new file mode 100644
index 0000000..8f9552f
--- /dev/null
+++ b/flink-libraries/flink-gelly/src/test/java/org/apache/flink/graph/pregel/PregelTranslationTest.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.graph.pregel;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.junit.Test;
+import org.apache.flink.api.common.aggregators.LongSumAggregator;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.api.java.operators.DeltaIteration;
+import org.apache.flink.api.java.operators.DeltaIterationResultSet;
+import org.apache.flink.api.java.operators.SingleInputUdfOperator;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.operators.TwoInputUdfOperator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.graph.Graph;
+import org.apache.flink.graph.Vertex;
+import org.apache.flink.types.NullValue;
+
+@SuppressWarnings("serial")
+public class PregelTranslationTest {
+
+    @Test
+    public void testTranslationPlainEdges() {
+        try {
+            final String ITERATION_NAME = "Test Name";
+
+            final String AGGREGATOR_NAME = "AggregatorName";
+
+            final String BC_SET_NAME = "borat messages";
+
+            final int NUM_ITERATIONS = 13;
+
+            final int ITERATION_parallelism = 77;
+
+
+            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+            DataSet<Long> bcVar = env.fromElements(1L);
+
+            DataSet<Vertex<String, Double>> result;
+
+            // ------------ construct the test program ------------------
+            {
+
+                DataSet<Tuple2<String, Double>> initialVertices = env.fromElements(new Tuple2<>("abc", 3.44));
+
+                DataSet<Tuple2<String, String>> edges = env.fromElements(new Tuple2<>("a", "c"));
+
+                Graph<String, Double, NullValue> graph = Graph.fromTupleDataSet(initialVertices,
+                    edges.map(new MapFunction<Tuple2<String, String>, Tuple3<String, String, NullValue>>() {
+
+                        public Tuple3<String, String, NullValue> map(
+                                Tuple2<String, String> edge) {
+                            return new Tuple3<>(edge.f0, edge.f1,
+                                NullValue.getInstance());
+                        }
+                    }), env);
+
+                VertexCentricConfiguration parameters = new VertexCentricConfiguration();
+
+                parameters.addBroadcastSet(BC_SET_NAME, bcVar);
+                parameters.setName(ITERATION_NAME);
+                parameters.setParallelism(ITERATION_parallelism);
+                parameters.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator());
+
+                result = graph.runVertexCentricIteration(new MyCompute(), null,
+                    NUM_ITERATIONS, parameters).getVertices();
+
+                result.output(new DiscardingOutputFormat<Vertex<String, Double>>());
+            }
+
+
+            // ------------- validate the java program ----------------
+
+            assertTrue(result instanceof DeltaIterationResultSet);
+
+            DeltaIterationResultSet<?, ?> resultSet = (DeltaIterationResultSet<?, ?>) result;
+            DeltaIteration<?, ?> iteration = resultSet.getIterationHead();
+
+            // check the basic iteration properties
+            assertEquals(NUM_ITERATIONS, resultSet.getMaxIterations());
+            assertArrayEquals(new int[] {0}, resultSet.getKeyPositions());
+            assertEquals(ITERATION_parallelism, iteration.getParallelism());
+            assertEquals(ITERATION_NAME, iteration.getName());
+
+            assertEquals(AGGREGATOR_NAME, iteration.getAggregators().getAllRegisteredAggregators().iterator().next().getName());
+
+            TwoInputUdfOperator<?, ?, ?, ?> computationCoGroup =
+                (TwoInputUdfOperator<?, ?, ?, ?>) ((SingleInputUdfOperator<?, ?, ?>) resultSet.getNextWorkset()).getInput();
+
+            // validate that the broadcast sets are forwarded
+            assertEquals(bcVar, computationCoGroup.getBroadcastSets().get(BC_SET_NAME));
+        }
+        catch (Exception e) {
+            System.err.println(e.getMessage());
+            e.printStackTrace();
+            fail(e.getMessage());
+        }
+    }
+
+    // --------------------------------------------------------------------------------------------
+
+    private static final class MyCompute extends ComputeFunction<String, Double, NullValue, Double> {
+
+        @Override
+        public void compute(Vertex<String, Double> vertex,
+                MessageIterator<Double> messages) throws Exception {}
+    }
+}
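The configuration in this test registers a broadcast set and an aggregator, but MyCompute never reads them; the assertions only verify that they are wired into the plan. As a sketch of how a compute function could consume both: the variant below is hypothetical, the name MyRichCompute is invented, and it assumes ComputeFunction exposes preSuperstep(), getIterationAggregator(), and getBroadcastSet() in the same way the scatter-gather functions do.

// Hypothetical replacement for MyCompute (illustrative, not part of this commit).
private static final class MyRichCompute extends ComputeFunction<String, Double, NullValue, Double> {

    private LongSumAggregator aggregator;

    @Override
    public void preSuperstep() {
        // The aggregator registered under AGGREGATOR_NAME in the configuration.
        aggregator = getIterationAggregator("AggregatorName");
    }

    @Override
    public void compute(Vertex<String, Double> vertex, MessageIterator<Double> messages) {
        // The broadcast set registered under BC_SET_NAME in the configuration.
        java.util.Collection<Long> bcVar = getBroadcastSet("borat messages");

        // Count one compute() invocation per superstep across all vertices.
        aggregator.aggregate(1L);

        // Trivial use of the broadcast data: offset the vertex value by its size.
        setNewVertexValue(vertex.getValue() + bcVar.size());
    }
}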
