Repository: flink Updated Branches: refs/heads/master a9dc4307d -> a137321ac
[tests] Remove subsumed low-level API tests for broadcast variables - BroadcastVarsNepheleITCase is subsumed by BroadcastBranchingITCase - KMeansIterativeNepheleITCase is subsumed by the Java/Scala API KMeans example programs Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1b975058 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1b975058 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1b975058 Branch: refs/heads/master Commit: 1b9750588a1ef0c4767ebad8aafaf50e67f5e7ee Parents: a9dc430 Author: Stephan Ewen <[email protected]> Authored: Sun Jun 14 15:26:53 2015 +0200 Committer: Stephan Ewen <[email protected]> Committed: Wed Jul 1 16:09:00 2015 +0200 ---------------------------------------------------------------------- .../BroadcastVarsNepheleITCase.java | 340 ------------------- .../KMeansIterativeNepheleITCase.java | 324 ------------------ 2 files changed, 664 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/1b975058/flink-tests/src/test/java/org/apache/flink/test/broadcastvars/BroadcastVarsNepheleITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/broadcastvars/BroadcastVarsNepheleITCase.java b/flink-tests/src/test/java/org/apache/flink/test/broadcastvars/BroadcastVarsNepheleITCase.java deleted file mode 100644 index 9366058..0000000 --- a/flink-tests/src/test/java/org/apache/flink/test/broadcastvars/BroadcastVarsNepheleITCase.java +++ /dev/null @@ -1,340 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.test.broadcastvars; - -import java.io.BufferedReader; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Random; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; -import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; -import org.apache.flink.api.common.typeutils.TypeSerializerFactory; -import org.apache.flink.api.common.typeutils.record.RecordSerializerFactory; -import org.apache.flink.api.java.record.functions.MapFunction; -import org.apache.flink.api.java.record.io.CsvInputFormat; -import org.apache.flink.api.java.record.io.CsvOutputFormat; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.core.fs.Path; -import org.apache.flink.runtime.jobgraph.JobVertex; -import org.apache.flink.runtime.jobgraph.DistributionPattern; -import org.apache.flink.runtime.jobgraph.InputFormatVertex; -import org.apache.flink.runtime.jobgraph.JobGraph; -import org.apache.flink.runtime.jobgraph.OutputFormatVertex; -import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; -import org.apache.flink.runtime.operators.CollectorMapDriver; -import org.apache.flink.runtime.operators.DriverStrategy; -import org.apache.flink.runtime.operators.RegularPactTask; -import 
org.apache.flink.runtime.operators.shipping.ShipStrategyType; -import org.apache.flink.runtime.operators.util.LocalStrategy; -import org.apache.flink.runtime.operators.util.TaskConfig; -import org.apache.flink.test.iterative.nephele.JobGraphUtils; -import org.apache.flink.test.util.RecordAPITestBase; -import org.apache.flink.types.LongValue; -import org.apache.flink.types.Record; -import org.apache.flink.util.Collector; -import org.junit.Assert; - -@SuppressWarnings("deprecation") -public class BroadcastVarsNepheleITCase extends RecordAPITestBase { - - private static final long SEED_POINTS = 0xBADC0FFEEBEEFL; - - private static final long SEED_MODELS = 0x39134230AFF32L; - - private static final int NUM_POINTS = 10000; - - private static final int NUM_MODELS = 42; - - private static final int NUM_FEATURES = 3; - - private static final int parallelism = 4; - - protected String pointsPath; - - protected String modelsPath; - - protected String resultPath; - - public BroadcastVarsNepheleITCase(){ - setTaskManagerNumSlots(parallelism); - } - - - public static final String getInputPoints(int numPoints, int numDimensions, long seed) { - if (numPoints < 1 || numPoints > 1000000) - throw new IllegalArgumentException(); - - Random r = new Random(); - - StringBuilder bld = new StringBuilder(3 * (1 + numDimensions) * numPoints); - for (int i = 1; i <= numPoints; i++) { - bld.append(i); - bld.append(' '); - - r.setSeed(seed + 1000 * i); - for (int j = 1; j <= numDimensions; j++) { - bld.append(r.nextInt(1000)); - bld.append(' '); - } - bld.append('\n'); - } - return bld.toString(); - } - - public static final String getInputModels(int numModels, int numDimensions, long seed) { - if (numModels < 1 || numModels > 100) - throw new IllegalArgumentException(); - - Random r = new Random(); - - StringBuilder bld = new StringBuilder(3 * (1 + numDimensions) * numModels); - for (int i = 1; i <= numModels; i++) { - bld.append(i); - bld.append(' '); - - r.setSeed(seed + 1000 * i); - for 
(int j = 1; j <= numDimensions; j++) { - bld.append(r.nextInt(100)); - bld.append(' '); - } - bld.append('\n'); - } - return bld.toString(); - } - - @Override - protected void preSubmit() throws Exception { - this.pointsPath = createTempFile("points.txt", getInputPoints(NUM_POINTS, NUM_FEATURES, SEED_POINTS)); - this.modelsPath = createTempFile("models.txt", getInputModels(NUM_MODELS, NUM_FEATURES, SEED_MODELS)); - this.resultPath = getTempFilePath("results"); - } - - @Override - protected JobGraph getJobGraph() throws Exception { - return createJobGraphV1(this.pointsPath, this.modelsPath, this.resultPath, parallelism); - } - - @Override - protected void postSubmit() throws Exception { - final Random randPoints = new Random(); - final Random randModels = new Random(); - final Pattern p = Pattern.compile("(\\d+) (\\d+) (\\d+)"); - - long [][] results = new long[NUM_POINTS][NUM_MODELS]; - boolean [][] occurs = new boolean[NUM_POINTS][NUM_MODELS]; - for (int i = 0; i < NUM_POINTS; i++) { - for (int j = 0; j < NUM_MODELS; j++) { - long actDotProd = 0; - randPoints.setSeed(SEED_POINTS + 1000 * (i+1)); - randModels.setSeed(SEED_MODELS + 1000 * (j+1)); - for (int z = 1; z <= NUM_FEATURES; z++) { - actDotProd += randPoints.nextInt(1000) * randModels.nextInt(100); - } - results[i][j] = actDotProd; - occurs[i][j] = false; - } - } - - for (BufferedReader reader : getResultReader(this.resultPath)) { - String line = null; - while (null != (line = reader.readLine())) { - final Matcher m = p.matcher(line); - Assert.assertTrue(m.matches()); - - int modelId = Integer.parseInt(m.group(1)); - int pointId = Integer.parseInt(m.group(2)); - long expDotProd = Long.parseLong(m.group(3)); - - Assert.assertFalse("Dot product for record (" + pointId + ", " + modelId + ") occurs more than once", occurs[pointId-1][modelId-1]); - Assert.assertEquals(String.format("Bad product for (%04d, %04d)", pointId, modelId), expDotProd, results[pointId-1][modelId-1]); - - occurs[pointId-1][modelId-1] = 
true; - } - } - - for (int i = 0; i < NUM_POINTS; i++) { - for (int j = 0; j < NUM_MODELS; j++) { - Assert.assertTrue("Dot product for record (" + (i+1) + ", " + (j+1) + ") does not occur", occurs[i][j]); - } - } - } - - // ------------------------------------------------------------------------------------------------------------- - // UDFs - // ------------------------------------------------------------------------------------------------------------- - - public static final class DotProducts extends MapFunction { - - private static final long serialVersionUID = 1L; - - private final Record result = new Record(3); - - private final LongValue lft = new LongValue(); - - private final LongValue rgt = new LongValue(); - - private final LongValue prd = new LongValue(); - - private Collection<Record> models; - - @Override - public void open(Configuration parameters) throws Exception { - List<Record> shared = this.getRuntimeContext().getBroadcastVariable("models"); - this.models = new ArrayList<Record>(shared.size()); - synchronized (shared) { - for (Record r : shared) { - this.models.add(r.createCopy()); - } - } - } - - @Override - public void map(Record record, Collector<Record> out) throws Exception { - - for (Record model : this.models) { - // compute dot product between model and pair - long product = 0; - for (int i = 1; i <= NUM_FEATURES; i++) { - product += model.getField(i, this.lft).getValue() * record.getField(i, this.rgt).getValue(); - } - this.prd.setValue(product); - - // construct result - this.result.copyFrom(model, new int[] { 0 }, new int[] { 0 }); - this.result.copyFrom(record, new int[] { 0 }, new int[] { 1 }); - this.result.setField(2, this.prd); - - // emit result - out.collect(this.result); - } - } - } - - // ------------------------------------------------------------------------------------------------------------- - // Job vertex builder methods - // 
------------------------------------------------------------------------------------------------------------- - - @SuppressWarnings("unchecked") - private static InputFormatVertex createPointsInput(JobGraph jobGraph, String pointsPath, int numSubTasks, TypeSerializerFactory<?> serializer) { - CsvInputFormat pointsInFormat = new CsvInputFormat(' ', LongValue.class, LongValue.class, LongValue.class, LongValue.class); - InputFormatVertex pointsInput = JobGraphUtils.createInput(pointsInFormat, pointsPath, "Input[Points]", jobGraph, numSubTasks); - - { - TaskConfig taskConfig = new TaskConfig(pointsInput.getConfiguration()); - taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); - taskConfig.setOutputSerializer(serializer); - } - - return pointsInput; - } - - @SuppressWarnings("unchecked") - private static InputFormatVertex createModelsInput(JobGraph jobGraph, String pointsPath, int numSubTasks, TypeSerializerFactory<?> serializer) { - CsvInputFormat modelsInFormat = new CsvInputFormat(' ', LongValue.class, LongValue.class, LongValue.class, LongValue.class); - InputFormatVertex modelsInput = JobGraphUtils.createInput(modelsInFormat, pointsPath, "Input[Models]", jobGraph, numSubTasks); - - { - TaskConfig taskConfig = new TaskConfig(modelsInput.getConfiguration()); - taskConfig.addOutputShipStrategy(ShipStrategyType.BROADCAST); - taskConfig.setOutputSerializer(serializer); - } - - return modelsInput; - } - - private static JobVertex createMapper(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> serializer) { - JobVertex pointsInput = JobGraphUtils.createTask(RegularPactTask.class, "Map[DotProducts]", jobGraph, numSubTasks); - - { - TaskConfig taskConfig = new TaskConfig(pointsInput.getConfiguration()); - - taskConfig.setStubWrapper(new UserCodeClassWrapper<DotProducts>(DotProducts.class)); - taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); - taskConfig.setOutputSerializer(serializer); - taskConfig.setDriver(CollectorMapDriver.class); - 
taskConfig.setDriverStrategy(DriverStrategy.COLLECTOR_MAP); - - taskConfig.addInputToGroup(0); - taskConfig.setInputLocalStrategy(0, LocalStrategy.NONE); - taskConfig.setInputSerializer(serializer, 0); - - taskConfig.setBroadcastInputName("models", 0); - taskConfig.addBroadcastInputToGroup(0); - taskConfig.setBroadcastInputSerializer(serializer, 0); - } - - return pointsInput; - } - - private static OutputFormatVertex createOutput(JobGraph jobGraph, String resultPath, int numSubTasks, TypeSerializerFactory<?> serializer) { - OutputFormatVertex output = JobGraphUtils.createFileOutput(jobGraph, "Output", numSubTasks); - - { - TaskConfig taskConfig = new TaskConfig(output.getConfiguration()); - taskConfig.addInputToGroup(0); - taskConfig.setInputSerializer(serializer, 0); - - @SuppressWarnings("unchecked") - CsvOutputFormat outFormat = new CsvOutputFormat("\n", " ", LongValue.class, LongValue.class, LongValue.class); - outFormat.setOutputFilePath(new Path(resultPath)); - - taskConfig.setStubWrapper(new UserCodeObjectWrapper<CsvOutputFormat>(outFormat)); - } - - return output; - } - - // ------------------------------------------------------------------------------------------------------------- - // Unified solution set and workset tail update - // ------------------------------------------------------------------------------------------------------------- - - private JobGraph createJobGraphV1(String pointsPath, String centersPath, String resultPath, int numSubTasks) { - - // -- init ------------------------------------------------------------------------------------------------- - final TypeSerializerFactory<?> serializer = RecordSerializerFactory.get(); - - JobGraph jobGraph = new JobGraph("Distance Builder"); - - // -- vertices --------------------------------------------------------------------------------------------- - InputFormatVertex points = createPointsInput(jobGraph, pointsPath, numSubTasks, serializer); - InputFormatVertex models = 
createModelsInput(jobGraph, centersPath, numSubTasks, serializer); - JobVertex mapper = createMapper(jobGraph, numSubTasks, serializer); - OutputFormatVertex output = createOutput(jobGraph, resultPath, numSubTasks, serializer); - - // -- edges ------------------------------------------------------------------------------------------------ - JobGraphUtils.connect(points, mapper, DistributionPattern.POINTWISE); - JobGraphUtils.connect(models, mapper, DistributionPattern.ALL_TO_ALL); - JobGraphUtils.connect(mapper, output, DistributionPattern.POINTWISE); - - // -- instance sharing ------------------------------------------------------------------------------------- - - SlotSharingGroup sharing = new SlotSharingGroup(); - - points.setSlotSharingGroup(sharing); - models.setSlotSharingGroup(sharing); - mapper.setSlotSharingGroup(sharing); - output.setSlotSharingGroup(sharing); - - return jobGraph; - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/1b975058/flink-tests/src/test/java/org/apache/flink/test/broadcastvars/KMeansIterativeNepheleITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/broadcastvars/KMeansIterativeNepheleITCase.java b/flink-tests/src/test/java/org/apache/flink/test/broadcastvars/KMeansIterativeNepheleITCase.java deleted file mode 100644 index 61ba59a..0000000 --- a/flink-tests/src/test/java/org/apache/flink/test/broadcastvars/KMeansIterativeNepheleITCase.java +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.test.broadcastvars; - -import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; -import org.apache.flink.api.common.typeutils.TypeComparatorFactory; -import org.apache.flink.api.common.typeutils.TypeSerializerFactory; -import org.apache.flink.api.common.typeutils.record.RecordComparatorFactory; -import org.apache.flink.api.common.typeutils.record.RecordSerializerFactory; -import org.apache.flink.api.java.record.io.CsvInputFormat; -import org.apache.flink.api.java.record.operators.ReduceOperator.WrappingReduceFunction; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.core.fs.Path; -import org.apache.flink.runtime.iterative.task.IterationHeadPactTask; -import org.apache.flink.runtime.iterative.task.IterationIntermediatePactTask; -import org.apache.flink.runtime.iterative.task.IterationTailPactTask; -import org.apache.flink.runtime.jobgraph.JobVertex; -import org.apache.flink.runtime.jobgraph.DistributionPattern; -import org.apache.flink.runtime.jobgraph.InputFormatVertex; -import org.apache.flink.runtime.jobgraph.JobGraph; -import org.apache.flink.runtime.jobgraph.OutputFormatVertex; -import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; -import org.apache.flink.runtime.operators.CollectorMapDriver; -import org.apache.flink.runtime.operators.DriverStrategy; -import org.apache.flink.runtime.operators.GroupReduceDriver; -import org.apache.flink.runtime.operators.NoOpDriver; -import org.apache.flink.runtime.operators.chaining.ChainedCollectorMapDriver; -import 
org.apache.flink.runtime.operators.shipping.ShipStrategyType; -import org.apache.flink.runtime.operators.util.LocalStrategy; -import org.apache.flink.runtime.operators.util.TaskConfig; -import org.apache.flink.test.iterative.nephele.JobGraphUtils; -import org.apache.flink.test.recordJobs.kmeans.KMeansBroadcast.PointBuilder; -import org.apache.flink.test.recordJobs.kmeans.KMeansBroadcast.PointOutFormat; -import org.apache.flink.test.recordJobs.kmeans.KMeansBroadcast.RecomputeClusterCenter; -import org.apache.flink.test.recordJobs.kmeans.KMeansBroadcast.SelectNearestCenter; -import org.apache.flink.test.testdata.KMeansData; -import org.apache.flink.test.util.RecordAPITestBase; -import org.apache.flink.types.DoubleValue; -import org.apache.flink.types.IntValue; - -public class KMeansIterativeNepheleITCase extends RecordAPITestBase { - - private static final int ITERATION_ID = 42; - - private static final int MEMORY_PER_CONSUMER = 2; - - private static final int parallelism = 4; - - private static final double MEMORY_FRACTION_PER_CONSUMER = (double)MEMORY_PER_CONSUMER/TASK_MANAGER_MEMORY_SIZE*parallelism; - - protected String dataPath; - protected String clusterPath; - protected String resultPath; - - - public KMeansIterativeNepheleITCase() { - setTaskManagerNumSlots(parallelism); - } - - @Override - protected void preSubmit() throws Exception { - dataPath = createTempFile("datapoints.txt", KMeansData.DATAPOINTS); - clusterPath = createTempFile("initial_centers.txt", KMeansData.INITIAL_CENTERS); - resultPath = getTempDirPath("result"); - } - - @Override - protected void postSubmit() throws Exception { - compareResultsByLinesInMemory(KMeansData.CENTERS_AFTER_20_ITERATIONS_SINGLE_DIGIT, resultPath); - } - - @Override - protected JobGraph getJobGraph() throws Exception { - return createJobGraph(dataPath, clusterPath, this.resultPath, parallelism, 20); - } - - // ------------------------------------------------------------------------------------------------------------- - 
// Job vertex builder methods - // ------------------------------------------------------------------------------------------------------------- - - private static InputFormatVertex createPointsInput(JobGraph jobGraph, String pointsPath, int numSubTasks, TypeSerializerFactory<?> serializer) { - @SuppressWarnings("unchecked") - CsvInputFormat pointsInFormat = new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class); - InputFormatVertex pointsInput = JobGraphUtils.createInput(pointsInFormat, pointsPath, "[Points]", jobGraph, numSubTasks); - { - TaskConfig taskConfig = new TaskConfig(pointsInput.getConfiguration()); - taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); - taskConfig.setOutputSerializer(serializer); - - TaskConfig chainedMapper = new TaskConfig(new Configuration()); - chainedMapper.setDriverStrategy(DriverStrategy.COLLECTOR_MAP); - chainedMapper.setStubWrapper(new UserCodeObjectWrapper<PointBuilder>(new PointBuilder())); - chainedMapper.addOutputShipStrategy(ShipStrategyType.FORWARD); - chainedMapper.setOutputSerializer(serializer); - - taskConfig.addChainedTask(ChainedCollectorMapDriver.class, chainedMapper, "Build points"); - } - - return pointsInput; - } - - private static InputFormatVertex createCentersInput(JobGraph jobGraph, String centersPath, int numSubTasks, TypeSerializerFactory<?> serializer) { - @SuppressWarnings("unchecked") - CsvInputFormat modelsInFormat = new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class); - InputFormatVertex modelsInput = JobGraphUtils.createInput(modelsInFormat, centersPath, "[Models]", jobGraph, numSubTasks); - - { - TaskConfig taskConfig = new TaskConfig(modelsInput.getConfiguration()); - taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); - taskConfig.setOutputSerializer(serializer); - - TaskConfig chainedMapper = new TaskConfig(new Configuration()); - chainedMapper.setDriverStrategy(DriverStrategy.COLLECTOR_MAP); - 
chainedMapper.setStubWrapper(new UserCodeObjectWrapper<PointBuilder>(new PointBuilder())); - chainedMapper.addOutputShipStrategy(ShipStrategyType.FORWARD); - chainedMapper.setOutputSerializer(serializer); - - taskConfig.addChainedTask(ChainedCollectorMapDriver.class, chainedMapper, "Build centers"); - } - - return modelsInput; - } - - private static OutputFormatVertex createOutput(JobGraph jobGraph, String resultPath, int numSubTasks, TypeSerializerFactory<?> serializer) { - - OutputFormatVertex output = JobGraphUtils.createFileOutput(jobGraph, "Output", numSubTasks); - - { - TaskConfig taskConfig = new TaskConfig(output.getConfiguration()); - taskConfig.addInputToGroup(0); - taskConfig.setInputSerializer(serializer, 0); - - PointOutFormat outFormat = new PointOutFormat(); - outFormat.setOutputFilePath(new Path(resultPath)); - - taskConfig.setStubWrapper(new UserCodeObjectWrapper<PointOutFormat>(outFormat)); - } - - return output; - } - - private static JobVertex createIterationHead(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> serializer) { - JobVertex head = JobGraphUtils.createTask(IterationHeadPactTask.class, "Iteration Head", jobGraph, numSubTasks); - - TaskConfig headConfig = new TaskConfig(head.getConfiguration()); - headConfig.setIterationId(ITERATION_ID); - - // initial input / partial solution - headConfig.addInputToGroup(0); - headConfig.setIterationHeadPartialSolutionOrWorksetInputIndex(0); - headConfig.setInputSerializer(serializer, 0); - - // back channel / iterations - headConfig.setRelativeBackChannelMemory(MEMORY_FRACTION_PER_CONSUMER); - - // output into iteration. 
broadcasting the centers - headConfig.setOutputSerializer(serializer); - headConfig.addOutputShipStrategy(ShipStrategyType.BROADCAST); - - // final output - TaskConfig headFinalOutConfig = new TaskConfig(new Configuration()); - headFinalOutConfig.setOutputSerializer(serializer); - headFinalOutConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); - headConfig.setIterationHeadFinalOutputConfig(headFinalOutConfig); - - // the sync - headConfig.setIterationHeadIndexOfSyncOutput(2); - - // the driver - headConfig.setDriver(NoOpDriver.class); - headConfig.setDriverStrategy(DriverStrategy.UNARY_NO_OP); - - return head; - } - - private static JobVertex createMapper(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> inputSerializer, - TypeSerializerFactory<?> broadcastVarSerializer, TypeSerializerFactory<?> outputSerializer, - TypeComparatorFactory<?> outputComparator) - { - JobVertex mapper = JobGraphUtils.createTask(IterationIntermediatePactTask.class, - "Map (Select nearest center)", jobGraph, numSubTasks); - - TaskConfig intermediateConfig = new TaskConfig(mapper.getConfiguration()); - intermediateConfig.setIterationId(ITERATION_ID); - - intermediateConfig.setDriver(CollectorMapDriver.class); - intermediateConfig.setDriverStrategy(DriverStrategy.COLLECTOR_MAP); - intermediateConfig.addInputToGroup(0); - intermediateConfig.setInputSerializer(inputSerializer, 0); - - intermediateConfig.setOutputSerializer(outputSerializer); - intermediateConfig.addOutputShipStrategy(ShipStrategyType.PARTITION_HASH); - intermediateConfig.setOutputComparator(outputComparator, 0); - - intermediateConfig.setBroadcastInputName("centers", 0); - intermediateConfig.addBroadcastInputToGroup(0); - intermediateConfig.setBroadcastInputSerializer(broadcastVarSerializer, 0); - - // the udf - intermediateConfig.setStubWrapper(new UserCodeObjectWrapper<SelectNearestCenter>(new SelectNearestCenter())); - - return mapper; - } - - private static JobVertex createReducer(JobGraph jobGraph, int 
numSubTasks, TypeSerializerFactory<?> inputSerializer, - TypeComparatorFactory<?> inputComparator, TypeSerializerFactory<?> outputSerializer) - { - // ---------------- the tail (reduce) -------------------- - - JobVertex tail = JobGraphUtils.createTask(IterationTailPactTask.class, "Reduce / Iteration Tail", jobGraph, - numSubTasks); - - TaskConfig tailConfig = new TaskConfig(tail.getConfiguration()); - tailConfig.setIterationId(ITERATION_ID); - tailConfig.setIsWorksetUpdate(); - - // inputs and driver - tailConfig.setDriver(GroupReduceDriver.class); - tailConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE); - tailConfig.addInputToGroup(0); - tailConfig.setInputSerializer(inputSerializer, 0); - tailConfig.setDriverComparator(inputComparator, 0); - - tailConfig.setInputLocalStrategy(0, LocalStrategy.SORT); - tailConfig.setInputComparator(inputComparator, 0); - tailConfig.setRelativeMemoryInput(0, MEMORY_FRACTION_PER_CONSUMER); - tailConfig.setFilehandlesInput(0, 128); - tailConfig.setSpillingThresholdInput(0, 0.9f); - - // output - tailConfig.setOutputSerializer(outputSerializer); - - // the udf - tailConfig.setStubWrapper(new UserCodeObjectWrapper<WrappingReduceFunction>(new WrappingReduceFunction(new RecomputeClusterCenter()))); - - return tail; - } - - private static JobVertex createSync(JobGraph jobGraph, int numIterations, int parallelism) { - JobVertex sync = JobGraphUtils.createSync(jobGraph, parallelism); - TaskConfig syncConfig = new TaskConfig(sync.getConfiguration()); - syncConfig.setNumberOfIterations(numIterations); - syncConfig.setIterationId(ITERATION_ID); - return sync; - } - - // ------------------------------------------------------------------------------------------------------------- - // Unified solution set and workset tail update - // ------------------------------------------------------------------------------------------------------------- - - private static JobGraph createJobGraph(String pointsPath, String centersPath, String 
resultPath, int numSubTasks, int numIterations) { - - // -- init ------------------------------------------------------------------------------------------------- - final TypeSerializerFactory<?> serializer = RecordSerializerFactory.get(); - @SuppressWarnings("unchecked") - final TypeComparatorFactory<?> int0Comparator = new RecordComparatorFactory(new int[] { 0 }, new Class[] { IntValue.class }); - - JobGraph jobGraph = new JobGraph("KMeans Iterative"); - - // -- vertices --------------------------------------------------------------------------------------------- - InputFormatVertex points = createPointsInput(jobGraph, pointsPath, numSubTasks, serializer); - InputFormatVertex centers = createCentersInput(jobGraph, centersPath, numSubTasks, serializer); - - JobVertex head = createIterationHead(jobGraph, numSubTasks, serializer); - JobVertex mapper = createMapper(jobGraph, numSubTasks, serializer, serializer, serializer, int0Comparator); - - JobVertex reducer = createReducer(jobGraph, numSubTasks, serializer, int0Comparator, serializer); - - JobVertex sync = createSync(jobGraph, numIterations, numSubTasks); - - OutputFormatVertex output = createOutput(jobGraph, resultPath, numSubTasks, serializer); - - // -- edges ------------------------------------------------------------------------------------------------ - JobGraphUtils.connect(points, mapper, DistributionPattern.POINTWISE); - - JobGraphUtils.connect(centers, head, DistributionPattern.POINTWISE); - - JobGraphUtils.connect(head, mapper, DistributionPattern.ALL_TO_ALL); - new TaskConfig(mapper.getConfiguration()).setBroadcastGateIterativeWithNumberOfEventsUntilInterrupt(0, numSubTasks); - new TaskConfig(mapper.getConfiguration()).setInputCached(0, true); - new TaskConfig(mapper.getConfiguration()).setRelativeInputMaterializationMemory(0, - MEMORY_FRACTION_PER_CONSUMER); - - JobGraphUtils.connect(mapper, reducer, DistributionPattern.ALL_TO_ALL); - new 
TaskConfig(reducer.getConfiguration()).setGateIterativeWithNumberOfEventsUntilInterrupt(0, numSubTasks); - - JobGraphUtils.connect(head, output, DistributionPattern.POINTWISE); - - JobGraphUtils.connect(head, sync, DistributionPattern.ALL_TO_ALL); - - // -- instance sharing ------------------------------------------------------------------------------------- - - SlotSharingGroup sharingGroup = new SlotSharingGroup(); - - points.setSlotSharingGroup(sharingGroup); - centers.setSlotSharingGroup(sharingGroup); - head.setSlotSharingGroup(sharingGroup); - mapper.setSlotSharingGroup(sharingGroup); - reducer.setSlotSharingGroup(sharingGroup); - sync.setSlotSharingGroup(sharingGroup); - output.setSlotSharingGroup(sharingGroup); - - mapper.setStrictlyCoLocatedWith(head); - reducer.setStrictlyCoLocatedWith(head); - - return jobGraph; - } -}
