johnyangk closed pull request #72: [NEMO-59] Skewed data-aware executor allocation URL: https://github.com/apache/incubator-nemo/pull/72
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/SkewnessAwareSchedulingProperty.java b/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/SkewnessAwareSchedulingProperty.java new file mode 100644 index 000000000..9ec336471 --- /dev/null +++ b/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/SkewnessAwareSchedulingProperty.java @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.nemo.common.ir.vertex.executionproperty; + +import edu.snu.nemo.common.ir.executionproperty.VertexExecutionProperty; + +/** + * This property decides whether or not to handle skew when scheduling this vertex. + */ +public final class SkewnessAwareSchedulingProperty extends VertexExecutionProperty<Boolean> { + private static final SkewnessAwareSchedulingProperty HANDLE_SKEW + = new SkewnessAwareSchedulingProperty(true); + private static final SkewnessAwareSchedulingProperty DONT_HANDLE_SKEW + = new SkewnessAwareSchedulingProperty(false); + + /** + * Default constructor. + * + * @param value value of the ExecutionProperty + */ + private SkewnessAwareSchedulingProperty(final boolean value) { + super(value); + } + + /** + * Static method getting execution property. + * + * @param value value of the new execution property + * @return the execution property + */ + public static SkewnessAwareSchedulingProperty of(final boolean value) { + return value ? HANDLE_SKEW : DONT_HANDLE_SKEW; + } +} diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DataSkewVertexPass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DataSkewVertexPass.java index c9e7aa017..2f543cbbd 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DataSkewVertexPass.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DataSkewVertexPass.java @@ -20,6 +20,9 @@ import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.common.ir.vertex.MetricCollectionBarrierVertex; import edu.snu.nemo.common.ir.vertex.executionproperty.DynamicOptimizationProperty; +import edu.snu.nemo.common.ir.vertex.executionproperty.SkewnessAwareSchedulingProperty; + +import java.util.List; /** * Pass to annotate the DAG for a job to perform data skew. @@ -33,14 +36,28 @@ public DataSkewVertexPass() { super(DynamicOptimizationProperty.class); } + private boolean hasMetricCollectionBarrierVertexAsParent(final DAG<IRVertex, IREdge> dag, + final IRVertex v) { + List<IRVertex> parents = dag.getParents(v.getId()); + for (IRVertex parent : parents) { + if (parent instanceof MetricCollectionBarrierVertex) { + return true; + } + } + return false; + } + @Override public DAG<IRVertex, IREdge> apply(final DAG<IRVertex, IREdge> dag) { - dag.topologicalDo(v -> { - // we only care about metric collection barrier vertices. - if (v instanceof MetricCollectionBarrierVertex) { - v.setProperty(DynamicOptimizationProperty.of(DynamicOptimizationProperty.Value.DataSkewRuntimePass)); - } - }); + dag.getVertices().stream() + .filter(v -> v instanceof MetricCollectionBarrierVertex) + .forEach(v -> v.setProperty(DynamicOptimizationProperty + .of(DynamicOptimizationProperty.Value.DataSkewRuntimePass))); + dag.getVertices().stream() + .filter(v -> hasMetricCollectionBarrierVertexAsParent(dag, v) + && !v.getExecutionProperties().containsKey(SkewnessAwareSchedulingProperty.class)) + .forEach(v -> v.getExecutionProperties().put(SkewnessAwareSchedulingProperty.of(true))); + return dag; } } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java index c2f1f4a35..4226bc65e 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java @@ -41,6 +41,14 @@ public DataSkewPolicy() { .build(); } + public DataSkewPolicy(final int skewness) { + this.policy = new PolicyBuilder(true) + .registerRuntimePass(new DataSkewRuntimePass().setNumSkewedKeys(skewness), new DataSkewCompositePass()) + .registerCompileTimePass(new LoopOptimizationCompositePass()) + .registerCompileTimePass(new PrimitiveCompositePass()) + .build(); + } + @Override public List<CompileTimePass> getCompileTimePasses() { return this.policy.getCompileTimePasses(); diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/data/HashRange.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/data/HashRange.java index ae1ca3e20..50e43349e 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/data/HashRange.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/data/HashRange.java @@ -19,22 +19,23 @@ * Descriptor for hash range. */ public final class HashRange implements KeyRange<Integer> { - private static final HashRange ALL = new HashRange(0, Integer.MAX_VALUE); - + private static final HashRange ALL = new HashRange(0, Integer.MAX_VALUE, false); private final int rangeBeginInclusive; private final int rangeEndExclusive; + private boolean isSkewed; /** * Private constructor. * @param rangeBeginInclusive point at which the hash range starts (inclusive). * @param rangeEndExclusive point at which the hash range ends (exclusive). */ - private HashRange(final int rangeBeginInclusive, final int rangeEndExclusive) { + private HashRange(final int rangeBeginInclusive, final int rangeEndExclusive, final boolean isSkewed) { if (rangeBeginInclusive < 0 || rangeEndExclusive < 0) { throw new RuntimeException("Each boundary value of the range have to be non-negative."); } this.rangeBeginInclusive = rangeBeginInclusive; this.rangeEndExclusive = rangeEndExclusive; + this.isSkewed = isSkewed; } /** @@ -49,8 +50,8 @@ public static HashRange all() { * @param rangeEndExclusive the end of the range (exclusive) * @return A hash range descriptor representing [{@code rangeBeginInclusive}, {@code rangeEndExclusive}) */ - public static HashRange of(final int rangeStartInclusive, final int rangeEndExclusive) { - return new HashRange(rangeStartInclusive, rangeEndExclusive); + public static HashRange of(final int rangeStartInclusive, final int rangeEndExclusive, final boolean isSkewed) { + return new HashRange(rangeStartInclusive, rangeEndExclusive, isSkewed); } /** @@ -120,4 +121,11 @@ public int hashCode() { result = 31 * result + rangeEndExclusive; return result; } + + public void setAsSkewed() { + isSkewed = true; + } + public boolean isSkewed() { + return isSkewed; + } } diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RuntimeOptimizer.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RuntimeOptimizer.java index 5d3330498..888e6f58c 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RuntimeOptimizer.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RuntimeOptimizer.java @@ -47,8 +47,7 @@ public static synchronized PhysicalPlan dynamicOptimization( switch (dynamicOptimizationType) { case DataSkewRuntimePass: - // Metric data for DataSkewRuntimePass is - // a pair of blockIds and map of hashrange, partition size. + // Metric data for DataSkewRuntimePass is a pair of blockIds and map of hashrange, partition size. final Pair<List<String>, Map<Integer, Long>> metricData = Pair.of(metricCollectionBarrierVertex.getBlockIds(), (Map<Integer, Long>) metricCollectionBarrierVertex.getMetricData()); diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java index d8999b4d5..cf8f5b6e8 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java @@ -34,15 +34,19 @@ import java.util.*; import java.util.stream.Collectors; -import java.util.stream.IntStream; /** * Dynamic optimization pass for handling data skew. - * It receives pairs of the key index and the size of a partition for each output block. + * Using a map of key to partition size as a metric used for dynamic optimization, + * this RuntimePass identifies a number of keys with big partition sizes(skewed key) + * and evenly redistributes data via overwriting incoming edges of destination tasks. */ public final class DataSkewRuntimePass implements RuntimePass<Pair<List<String>, Map<Integer, Long>>> { private static final Logger LOG = LoggerFactory.getLogger(DataSkewRuntimePass.class.getName()); private final Set<Class<? extends RuntimeEventHandler>> eventHandlers; + // Skewed keys denote for top n keys in terms of partition size. + private static final int DEFAULT_NUM_SKEWED_KEYS = 3; + private int numSkewedKeys = DEFAULT_NUM_SKEWED_KEYS; /** * Constructor. @@ -52,6 +56,11 @@ public DataSkewRuntimePass() { DynamicOptimizationEventHandler.class); } + public DataSkewRuntimePass setNumSkewedKeys(final int numOfSkewedKeys) { + numSkewedKeys = numOfSkewedKeys; + return this; + } + @Override public Set<Class<? extends RuntimeEventHandler>> getEventHandlerClasses() { return this.eventHandlers; @@ -75,74 +84,117 @@ public PhysicalPlan apply(final PhysicalPlan originalPlan, .collect(Collectors.toList()); // Get number of evaluators of the next stage (number of blocks). - final Integer taskListSize = optimizationEdges.stream().findFirst().orElseThrow(() -> + final Integer numOfDstTasks = optimizationEdges.stream().findFirst().orElseThrow(() -> new RuntimeException("optimization edges are empty")).getDst().getTaskIds().size(); // Calculate keyRanges. - final List<KeyRange> keyRanges = calculateHashRanges(metricData.right(), taskListSize); + final List<KeyRange> keyRanges = calculateKeyRanges(metricData.right(), numOfDstTasks); - // Overwrite the previously assigned hash value range in the physical DAG with the new range. + // Overwrite the previously assigned key range in the physical DAG with the new range. optimizationEdges.forEach(optimizationEdge -> { // Update the information. - final List<KeyRange> taskIdxToHashRange = new ArrayList<>(); - IntStream.range(0, taskListSize).forEach(i -> taskIdxToHashRange.add(keyRanges.get(i))); + final Map<Integer, KeyRange> taskIdxToHashRange = new HashMap<>(); + for (int taskIdx = 0; taskIdx < numOfDstTasks; taskIdx++) { + taskIdxToHashRange.put(taskIdx, keyRanges.get(taskIdx)); + } optimizationEdge.setTaskIdxToKeyRange(taskIdxToHashRange); }); return new PhysicalPlan(originalPlan.getId(), physicalDAGBuilder.build()); } + public List<Integer> identifySkewedKeys(final Map<Integer, Long> keyValToPartitionSizeMap) { + // Identify skewed keyes. + List<Map.Entry<Integer, Long>> sortedMetricData = keyValToPartitionSizeMap.entrySet().stream() + .sorted((e1, e2) -> e2.getValue().compareTo(e1.getValue())) + .collect(Collectors.toList()); + List<Integer> skewedKeys = new ArrayList<>(); + for (int i = 0; i < numSkewedKeys; i++) { + skewedKeys.add(sortedMetricData.get(i).getKey()); + LOG.info("Skewed key: Key {} Size {}", sortedMetricData.get(i).getKey(), sortedMetricData.get(i).getValue()); + } + + return skewedKeys; + } + + private boolean containsSkewedKey(final List<Integer> skewedKeys, + final int startingKey, final int finishingKey) { + for (int k = startingKey; k < finishingKey; k++) { + if (skewedKeys.contains(k)) { + return true; + } + } + return false; + } + /** - * Method for calculating key ranges to evenly distribute the skewed metric data. + * Evenly distribute the skewed data to the destination tasks. + * Partition denotes for a keyed portion of a Task output, whose key is a key. + * Using a map of key to partition size, this method groups the given partitions + * to a key range of partitions with approximate size of (total size of partitions / the number of tasks). * - * @param aggregatedMetricData the metric data. - * @param taskListSize the size of the task list. + * @param keyToPartitionSizeMap a map of key to partition size. + * @param numOfDstTasks the number of tasks that receives this data as input. * @return the list of key ranges calculated. */ @VisibleForTesting - public List<KeyRange> calculateHashRanges(final Map<Integer, Long> aggregatedMetricData, - final Integer taskListSize) { - // NOTE: aggregatedMetricDataMap is made up of a map of (hash value, blockSize). - // Get the max hash value. - final int maxHashValue = aggregatedMetricData.keySet().stream() + public List<KeyRange> calculateKeyRanges(final Map<Integer, Long> keyToPartitionSizeMap, + final Integer numOfDstTasks) { + // Get the biggest key. + final int maxKey = keyToPartitionSizeMap.keySet().stream() .max(Integer::compareTo) - .orElseThrow(() -> new DynamicOptimizationException("Cannot find max hash value among blocks.")); - - // Do the optimization using the information derived above. - final Long totalSize = aggregatedMetricData.values().stream().mapToLong(n -> n).sum(); // get total size - final Long idealSizePerTask = totalSize / taskListSize; // and derive the ideal size per task - LOG.info("idealSizePerTask {} = {}(totalSize) / {}(taskListSize)", - idealSizePerTask, totalSize, taskListSize); - - // find HashRanges to apply (for each blocks of each block). - final List<KeyRange> keyRanges = new ArrayList<>(taskListSize); - int startingHashValue = 0; - int finishingHashValue = 1; // initial values - Long currentAccumulatedSize = aggregatedMetricData.getOrDefault(startingHashValue, 0L); - for (int i = 1; i <= taskListSize; i++) { - if (i != taskListSize) { - final Long idealAccumulatedSize = idealSizePerTask * i; // where we should end - // find the point while adding up one by one. + .orElseThrow(() -> new DynamicOptimizationException("Cannot find max key among blocks.")); + + // Identify skewed keys, which is top numSkewedKeys number of keys. + List<Integer> skewedKeys = identifySkewedKeys(keyToPartitionSizeMap); + + // Calculate the ideal size for each destination task. + final Long totalSize = keyToPartitionSizeMap.values().stream().mapToLong(n -> n).sum(); // get total size + final Long idealSizePerTask = totalSize / numOfDstTasks; // and derive the ideal size per task + + final List<KeyRange> keyRanges = new ArrayList<>(numOfDstTasks); + int startingKey = 0; + int finishingKey = 1; + Long currentAccumulatedSize = keyToPartitionSizeMap.getOrDefault(startingKey, 0L); + Long prevAccumulatedSize = 0L; + for (int i = 1; i <= numOfDstTasks; i++) { + if (i != numOfDstTasks) { + // Ideal accumulated partition size for this task. + final Long idealAccumulatedSize = idealSizePerTask * i; + // By adding partition sizes, find the accumulated size nearest to the given ideal size. while (currentAccumulatedSize < idealAccumulatedSize) { - currentAccumulatedSize += aggregatedMetricData.getOrDefault(finishingHashValue, 0L); - finishingHashValue++; + currentAccumulatedSize += keyToPartitionSizeMap.getOrDefault(finishingKey, 0L); + finishingKey++; } final Long oneStepBack = - currentAccumulatedSize - aggregatedMetricData.getOrDefault(finishingHashValue - 1, 0L); + currentAccumulatedSize - keyToPartitionSizeMap.getOrDefault(finishingKey - 1, 0L); final Long diffFromIdeal = currentAccumulatedSize - idealAccumulatedSize; final Long diffFromIdealOneStepBack = idealAccumulatedSize - oneStepBack; // Go one step back if we came too far. if (diffFromIdeal > diffFromIdealOneStepBack) { - finishingHashValue--; - currentAccumulatedSize -= aggregatedMetricData.getOrDefault(finishingHashValue, 0L); + finishingKey--; + currentAccumulatedSize -= keyToPartitionSizeMap.getOrDefault(finishingKey, 0L); } - // assign appropriately - keyRanges.add(i - 1, HashRange.of(startingHashValue, finishingHashValue)); - startingHashValue = finishingHashValue; + boolean isSkewedKey = containsSkewedKey(skewedKeys, startingKey, finishingKey); + keyRanges.add(i - 1, HashRange.of(startingKey, finishingKey, isSkewedKey)); + LOG.debug("KeyRange {}~{}, Size {}", startingKey, finishingKey - 1, + currentAccumulatedSize - prevAccumulatedSize); + + prevAccumulatedSize = currentAccumulatedSize; + startingKey = finishingKey; } else { // last one: we put the range of the rest. - keyRanges.add(i - 1, HashRange.of(startingHashValue, maxHashValue + 1)); + boolean isSkewedKey = containsSkewedKey(skewedKeys, startingKey, finishingKey); + keyRanges.add(i - 1, + HashRange.of(startingKey, maxKey + 1, isSkewedKey)); + + while (finishingKey <= maxKey) { + currentAccumulatedSize += keyToPartitionSizeMap.getOrDefault(finishingKey, 0L); + finishingKey++; + } + LOG.debug("KeyRange {}~{}, Size {}", startingKey, maxKey + 1, + currentAccumulatedSize - prevAccumulatedSize); } } return keyRanges; diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java index 504501bb9..ad62e6856 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java @@ -24,8 +24,8 @@ import edu.snu.nemo.runtime.common.data.KeyRange; import edu.snu.nemo.runtime.common.data.HashRange; -import java.util.ArrayList; -import java.util.List; +import java.util.HashMap; +import java.util.Map; /** * Edge of a stage that connects an IRVertex of the source stage to an IRVertex of the destination stage. @@ -47,7 +47,7 @@ /** * The list between the task idx and key range to read. */ - private List<KeyRange> taskIdxToKeyRange; + private Map<Integer, KeyRange> taskIdxToKeyRange; /** * Value for {@link DataCommunicationPatternProperty}. @@ -82,9 +82,9 @@ public StageEdge(final String runtimeEdgeId, this.srcVertex = srcVertex; this.dstVertex = dstVertex; // Initialize the key range of each dst task. - this.taskIdxToKeyRange = new ArrayList<>(); + this.taskIdxToKeyRange = new HashMap<>(); for (int taskIdx = 0; taskIdx < dstStage.getTaskIds().size(); taskIdx++) { - taskIdxToKeyRange.add(HashRange.of(taskIdx, taskIdx + 1)); + taskIdxToKeyRange.put(taskIdx, HashRange.of(taskIdx, taskIdx + 1, false)); } this.dataCommunicationPatternValue = edgeProperties.get(DataCommunicationPatternProperty.class) .orElseThrow(() -> new RuntimeException(String.format( @@ -122,7 +122,7 @@ public String propertiesToJSON() { /** * @return the list between the task idx and key range to read. */ - public List<KeyRange> getTaskIdxToKeyRange() { + public Map<Integer, KeyRange> getTaskIdxToKeyRange() { return taskIdxToKeyRange; } @@ -131,7 +131,7 @@ public String propertiesToJSON() { * * @param taskIdxToKeyRange the list to set. */ - public void setTaskIdxToKeyRange(final List<KeyRange> taskIdxToKeyRange) { + public void setTaskIdxToKeyRange(final Map<Integer, KeyRange> taskIdxToKeyRange) { this.taskIdxToKeyRange = taskIdxToKeyRange; } diff --git a/runtime/common/src/test/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java b/runtime/common/src/test/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java index cd2fc169a..319b8babe 100644 --- a/runtime/common/src/test/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java +++ b/runtime/common/src/test/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java @@ -15,6 +15,7 @@ */ package edu.snu.nemo.runtime.common.optimizer.pass.runtime; +import edu.snu.nemo.runtime.common.data.HashRange; import edu.snu.nemo.runtime.common.data.KeyRange; import org.junit.Before; import org.junit.Test; @@ -31,32 +32,41 @@ @Before public void setUp() { - // Sum is 30 for each hashRanges: 0-3, 3-5, 5-7, 7-9, 9-10. - buildPartitionSizeList(Arrays.asList(1L, 2L, 4L, 2L, 1L, 8L, 2L, 4L, 2L, 10L)); - buildPartitionSizeList(Arrays.asList(3L, 5L, 5L, 7L, 10L, 3L, 5L, 4L, 8L, 5L)); - buildPartitionSizeList(Arrays.asList(2L, 3L, 5L, 5L, 5L, 6L, 6L, 8L, 4L, 15L)); + // Skewed partition size lists + buildPartitionSizeList(Arrays.asList(5L, 5L, 10L, 50L, 100L)); + buildPartitionSizeList(Arrays.asList(5L, 10L, 5L, 0L, 0L)); + buildPartitionSizeList(Arrays.asList(10L, 5L, 5L, 0L, 0L)); } /** - * Test if the test case suggested above works correctly. + * Test DataSkewRuntimePass whether it redistributes skewed partitions + * to partitions with approximate size of (total size / the number of tasks). */ @Test public void testDataSkewDynamicOptimizationPass() { - final Integer taskListSize = 5; + final Integer taskNum = 5; final List<KeyRange> keyRanges = - new DataSkewRuntimePass().calculateHashRanges(testMetricData, taskListSize); + new DataSkewRuntimePass().setNumSkewedKeys(2).calculateKeyRanges(testMetricData, taskNum); + // Test whether it correctly redistributed hash ranges. assertEquals(0, keyRanges.get(0).rangeBeginInclusive()); - assertEquals(3, keyRanges.get(0).rangeEndExclusive()); - assertEquals(3, keyRanges.get(1).rangeBeginInclusive()); - assertEquals(5, keyRanges.get(1).rangeEndExclusive()); - assertEquals(5, keyRanges.get(2).rangeBeginInclusive()); - assertEquals(7, keyRanges.get(2).rangeEndExclusive()); - assertEquals(7, keyRanges.get(3).rangeBeginInclusive()); - assertEquals(9, keyRanges.get(3).rangeEndExclusive()); - assertEquals(9, keyRanges.get(4).rangeBeginInclusive()); - assertEquals(10, keyRanges.get(4).rangeEndExclusive()); + assertEquals(2, keyRanges.get(0).rangeEndExclusive()); + assertEquals(2, keyRanges.get(1).rangeBeginInclusive()); + assertEquals(3, keyRanges.get(1).rangeEndExclusive()); + assertEquals(3, keyRanges.get(2).rangeBeginInclusive()); + assertEquals(4, keyRanges.get(2).rangeEndExclusive()); + assertEquals(4, keyRanges.get(3).rangeBeginInclusive()); + assertEquals(5, keyRanges.get(3).rangeEndExclusive()); + assertEquals(5, keyRanges.get(4).rangeBeginInclusive()); + assertEquals(5, keyRanges.get(4).rangeEndExclusive()); + + // Test whether it caught the provided skewness. + assertEquals(false, ((HashRange)keyRanges.get(0)).isSkewed()); + assertEquals(false, ((HashRange)keyRanges.get(1)).isSkewed()); + assertEquals(true, ((HashRange)keyRanges.get(2)).isSkewed()); + assertEquals(true, ((HashRange)keyRanges.get(3)).isSkewed()); + assertEquals(false, ((HashRange)keyRanges.get(4)).isSkewed()); } /** @@ -68,7 +78,11 @@ public void testDataSkewDynamicOptimizationPass() { private void buildPartitionSizeList(final List<Long> partitionSizes) { int key = 0; for (final long partitionSize : partitionSizes) { - testMetricData.put(key, partitionSize); + if (testMetricData.containsKey(key)) { + testMetricData.compute(key, (existingKey, existingValue) -> existingValue + partitionSize); + } else { + testMetricData.put(key, partitionSize); + } key++; } } diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/data/BlockStoreTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/data/BlockStoreTest.java index b9f0c63fa..2af771ef1 100644 --- a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/data/BlockStoreTest.java +++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/data/BlockStoreTest.java @@ -189,9 +189,11 @@ public void setUp() throws Exception { // Generates the range of hash value to read for each read task. final int smallDataRangeEnd = 1 + NUM_READ_HASH_TASKS - NUM_WRITE_HASH_TASKS; - readKeyRangeList.add(HashRange.of(0, smallDataRangeEnd)); + readKeyRangeList.add(HashRange.of(0, smallDataRangeEnd, false)); IntStream.range(0, NUM_READ_HASH_TASKS - 1).forEach(readTaskIdx -> { - readKeyRangeList.add(HashRange.of(smallDataRangeEnd + readTaskIdx, smallDataRangeEnd + readTaskIdx + 1)); + readKeyRangeList.add(HashRange.of(smallDataRangeEnd + readTaskIdx, + smallDataRangeEnd + readTaskIdx + 1, + false)); }); // Generates the expected result of hash range retrieval for each read task. @@ -347,7 +349,8 @@ public Boolean call() { public Boolean call() { try { for (int writeTaskIdx = 0; writeTaskIdx < NUM_WRITE_VERTICES; writeTaskIdx++) { - readResultCheck(blockIdList.get(writeTaskIdx), HashRange.of(readTaskIdx, readTaskIdx + 1), + readResultCheck(blockIdList.get(writeTaskIdx), + HashRange.of(readTaskIdx, readTaskIdx + 1, false), readerSideStore, partitionsPerBlock.get(writeTaskIdx).get(readTaskIdx).getData()); } return true; diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java index 98928a397..43e46ccf8 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java @@ -76,7 +76,6 @@ private static final int REST_SERVER_PORT = 10101; private final ExecutorService runtimeMasterThread; - private final Scheduler scheduler; private final ContainerManager containerManager; private final BlockManagerMaster blockManagerMaster; @@ -84,22 +83,18 @@ private final MessageEnvironment masterMessageEnvironment; private final MetricStore metricStore; private final Map<Integer, Long> aggregatedMetricData; + private final ExecutorService metricAggregationService; private final ClientRPC clientRPC; private final MetricManagerMaster metricManagerMaster; - // For converting json data. This is a thread safe. private final ObjectMapper objectMapper; - private final String dagDirectory; private final Set<IRVertex> irVertices; - private final AtomicInteger resourceRequestCount; - private CountDownLatch metricCountDownLatch; // REST API server for web metric visualization ui. private final Server metricServer; - @Inject public RuntimeMaster(final Scheduler scheduler, final ContainerManager containerManager, @@ -127,7 +122,8 @@ public RuntimeMaster(final Scheduler scheduler, this.irVertices = new HashSet<>(); this.resourceRequestCount = new AtomicInteger(0); this.objectMapper = new ObjectMapper(); - this.aggregatedMetricData = new HashMap<>(); + this.aggregatedMetricData = new ConcurrentHashMap<>(); + this.metricAggregationService = Executors.newFixedThreadPool(10); this.metricStore = MetricStore.getStore(); this.metricServer = startRestMetricServer(); } @@ -155,7 +151,8 @@ private Server startRestMetricServer() { /** * Submits the {@link PhysicalPlan} to Runtime. - * @param plan to execute. + * + * @param plan to execute * @param maxScheduleAttempt the max number of times this plan/sub-part of the plan should be attempted. */ public Pair<JobStateManager, ScheduledExecutorService> execute(final PhysicalPlan plan, @@ -172,7 +169,6 @@ private Server startRestMetricServer() { throw new RuntimeException(e); } }; - try { return runtimeMasterThread.submit(jobExecutionCallable).get(); } catch (Exception e) { @@ -191,9 +187,7 @@ public void terminate() { } catch (final InterruptedException e) { LOG.warn("Waiting executor terminating process interrupted."); } - runtimeMasterThread.execute(() -> { - scheduler.terminate(); try { masterMessageEnvironment.close(); @@ -248,22 +242,22 @@ public void requestContainer(final String resourceSpecificationString) { /** * Called when a container is allocated for this runtime. * A wrapper function for {@link ContainerManager}. - * @param executorId to use for the executor to be launched on this container. - * @param allocatedEvaluator to be used as the container. + * + * @param executorId to use for the executor to be launched on this container. + * @param allocatedEvaluator to be used as the container. * @param executorConfiguration to use for the executor to be launched on this container. */ public void onContainerAllocated(final String executorId, final AllocatedEvaluator allocatedEvaluator, final Configuration executorConfiguration) { runtimeMasterThread.execute(() -> { - containerManager.onContainerAllocated(executorId, allocatedEvaluator, executorConfiguration); - }); } /** * Called when an executor is launched on a container for this runtime. + * * @param activeContext of the launched executor. * @return true if all requested executors have been launched, false otherwise. */ @@ -289,6 +283,7 @@ public boolean onExecutorLaunched(final ActiveContext activeContext) { /** * Called when an executor fails due to container failure on this runtime. + * * @param failedEvaluator that failed. */ public void onExecutorFailed(final FailedEvaluator failedEvaluator) { @@ -314,9 +309,7 @@ public void onExecutorFailed(final FailedEvaluator failedEvaluator) { @Override public void onMessage(final ControlMessage.Message message) { runtimeMasterThread.execute(() -> { - handleControlMessage(message); - }); } @@ -379,7 +372,6 @@ private void handleControlMessage(final ControlMessage.Message message) { } } - /** * Accumulates the metric data for a barrier vertex. * TODO #96: Modularize DataSkewPolicy to use MetricVertex and BarrierVertex. @@ -396,22 +388,24 @@ private void accumulateBarrierMetric(final List<ControlMessage.PartitionSizeEntr .filter(irVertex -> irVertex.getId().equals(srcVertexId)).findFirst() .orElseThrow(() -> new RuntimeException(srcVertexId + " doesn't exist in the submitted Physical Plan")); - // For each hash range index, aggregate the metric data as they arrive. - partitionSizeInfo.forEach(partitionSizeEntry -> { - final int key = partitionSizeEntry.getKey(); - final long size = partitionSizeEntry.getSize(); - if (aggregatedMetricData.containsKey(key)) { - aggregatedMetricData.compute(key, (existKey, existValue) -> existValue + size); - } else { - aggregatedMetricData.put(key, size); - } - }); - if (vertexToSendMetricDataTo instanceof MetricCollectionBarrierVertex) { final MetricCollectionBarrierVertex<Integer, Long> metricCollectionBarrierVertex = (MetricCollectionBarrierVertex) vertexToSendMetricDataTo; + metricCollectionBarrierVertex.addBlockId(blockId); - metricCollectionBarrierVertex.setMetricData(aggregatedMetricData); + metricAggregationService.submit(() -> { + // For each hash range index, we aggregate the metric data. + partitionSizeInfo.forEach(partitionSizeEntry -> { + final int key = partitionSizeEntry.getKey(); + final long size = partitionSizeEntry.getSize(); + if (aggregatedMetricData.containsKey(key)) { + aggregatedMetricData.compute(key, (existKey, existValue) -> existValue + size); + } else { + aggregatedMetricData.put(key, size); + } + }); + metricCollectionBarrierVertex.setMetricData(aggregatedMetricData); + }); } else { throw new RuntimeException("Something wrong happened at DataSkewCompositePass."); } diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ExecutorRepresenter.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ExecutorRepresenter.java index cca038297..e43107403 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ExecutorRepresenter.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/resource/ExecutorRepresenter.java @@ -25,11 +25,9 @@ import org.apache.reef.driver.context.ActiveContext; import javax.annotation.concurrent.NotThreadSafe; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; /** * (WARNING) This class is not thread-safe, and thus should only be accessed through ExecutorRegistry. @@ -47,10 +45,10 @@ private final String executorId; private final ResourceSpecification resourceSpecification; - private final Set<String> runningTasks; - private final Map<String, Integer> runningTaskToAttempt; - private final Set<String> completeTasks; - private final Set<String> failedTasks; + private final Set<Task> runningTasks; + private final Map<Task, Integer> runningTaskToAttempt; + private final Set<Task> completeTasks; + private final Set<Task> failedTasks; private final MessageSender<ControlMessage.Message> messageSender; private final ActiveContext activeContext; private final ExecutorService serializationExecutorService; @@ -88,7 +86,9 @@ public ExecutorRepresenter(final String executorId, */ public Set<String> onExecutorFailed() { failedTasks.addAll(runningTasks); - final Set<String> snapshot = new HashSet<>(runningTasks); + final Set<String> snapshot = runningTasks.stream() + .map(Task::getTaskId) + .collect(Collectors.toSet()); runningTasks.clear(); return snapshot; } @@ -98,9 +98,9 @@ public ExecutorRepresenter(final String executorId, * @param task */ public void onTaskScheduled(final Task task) { - runningTasks.add(task.getTaskId()); - runningTaskToAttempt.put(task.getTaskId(), task.getAttemptIdx()); - failedTasks.remove(task.getTaskId()); + runningTasks.add(task); + runningTaskToAttempt.put(task, task.getAttemptIdx()); + failedTasks.remove(task); serializationExecutorService.submit(new Runnable() { @Override @@ -133,9 +133,13 @@ public void sendControlMessage(final ControlMessage.Message message) { * */ public void onTaskExecutionComplete(final String taskId) { - runningTasks.remove(taskId); - runningTaskToAttempt.remove(taskId); - completeTasks.add(taskId); + Task completedTask = runningTasks.stream() + .filter(task -> task.getTaskId().equals(taskId)).findFirst() + .orElseThrow(() -> new RuntimeException("Completed task not found in its ExecutorRepresenter")); + + runningTasks.remove(completedTask); + runningTaskToAttempt.remove(completedTask); + completeTasks.add(completedTask); } /** @@ -143,9 +147,13 @@ public void onTaskExecutionComplete(final String taskId) { * @param taskId id of the Task */ public void onTaskExecutionFailed(final String taskId) { - runningTasks.remove(taskId); - runningTaskToAttempt.remove(taskId); - failedTasks.add(taskId); + Task failedTask = runningTasks.stream() + .filter(task -> task.getTaskId().equals(taskId)).findFirst() + .orElseThrow(() -> new RuntimeException("Failed task not found in its ExecutorRepresenter")); + + runningTasks.remove(failedTask); + runningTaskToAttempt.remove(failedTask); + failedTasks.add(failedTask); } /** @@ -158,11 +166,11 @@ public int getExecutorCapacity() { /** * @return set of ids of Tasks that are running in this executor */ - public Set<String> getRunningTasks() { + public Set<Task> getRunningTasks() { return runningTasks; } - public Map<String, Integer> getRunningTaskToAttempt() { + public Map<Task, Integer> getRunningTaskToAttempt() { return runningTaskToAttempt; } @@ -176,7 +184,7 @@ public int getExecutorCapacity() { /** * @return set of ids of Tasks that have been completed in this executor */ - public Set<String> getCompleteTasks() { + public Set<Task> getCompleteTasks() { return completeTasks; } diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/ExecutorRegistry.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/ExecutorRegistry.java index 8f052f0e9..98af0a1b1 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/ExecutorRegistry.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/ExecutorRegistry.java @@ -17,6 +17,7 @@ import com.google.common.annotations.VisibleForTesting; import edu.snu.nemo.common.Pair; +import edu.snu.nemo.runtime.common.plan.Task; import edu.snu.nemo.runtime.master.resource.ExecutorRepresenter; import org.apache.reef.annotations.audience.DriverSide; @@ -89,8 +90,10 @@ synchronized void terminate() { @VisibleForTesting synchronized Optional<ExecutorRepresenter> findExecutorForTask(final String taskId) { for (final ExecutorRepresenter executor : getRunningExecutors()) { - if (executor.getRunningTasks().contains(taskId) || executor.getCompleteTasks().contains(taskId)) { - return Optional.of(executor); + for (final Task runningTask : executor.getRunningTasks()) { + if (runningTask.getTaskId().equals(taskId)) { + return Optional.of(executor); + } } } return Optional.empty(); diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/MinOccupancyFirstSchedulingPolicy.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/MinOccupancyFirstSchedulingPolicy.java index e53f659a7..7023961ec 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/MinOccupancyFirstSchedulingPolicy.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/MinOccupancyFirstSchedulingPolicy.java @@ -23,8 +23,6 @@ import javax.annotation.concurrent.ThreadSafe; import javax.inject.Inject; import java.util.*; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * This policy chooses a set of Executors, on which have minimum running Tasks. @@ -32,7 +30,6 @@ @ThreadSafe @DriverSide public final class MinOccupancyFirstSchedulingPolicy implements SchedulingPolicy { - private static final Logger LOG = LoggerFactory.getLogger(MinOccupancyFirstSchedulingPolicy.class.getName()); @VisibleForTesting @Inject diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulingConstraintRegistry.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulingConstraintRegistry.java index 639e77579..4a774bf7f 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulingConstraintRegistry.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulingConstraintRegistry.java @@ -40,11 +40,13 @@ private SchedulingConstraintRegistry( final ContainerTypeAwareSchedulingConstraint containerTypeAwareSchedulingConstraint, final FreeSlotSchedulingConstraint freeSlotSchedulingConstraint, final SourceLocationAwareSchedulingConstraint sourceLocationAwareSchedulingConstraint, + final SkewnessAwareSchedulingConstraint skewnessAwareSchedulingConstraint, final NodeShareSchedulingConstraint nodeShareSchedulingConstraint) { registerSchedulingConstraint(containerTypeAwareSchedulingConstraint); registerSchedulingConstraint(freeSlotSchedulingConstraint); registerSchedulingConstraint(sourceLocationAwareSchedulingConstraint); registerSchedulingConstraint(nodeShareSchedulingConstraint); + registerSchedulingConstraint(skewnessAwareSchedulingConstraint); } /** diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraint.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraint.java new file mode 100644 index 000000000..6ac14cf23 --- /dev/null +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraint.java @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.nemo.runtime.master.scheduler; + +import com.google.common.annotations.VisibleForTesting; +import edu.snu.nemo.common.ir.executionproperty.AssociatedProperty; +import edu.snu.nemo.common.ir.vertex.executionproperty.SkewnessAwareSchedulingProperty; +import edu.snu.nemo.runtime.common.RuntimeIdGenerator; +import edu.snu.nemo.runtime.common.data.HashRange; +import edu.snu.nemo.runtime.common.data.KeyRange; +import edu.snu.nemo.runtime.common.plan.StageEdge; +import edu.snu.nemo.runtime.common.plan.Task; +import edu.snu.nemo.runtime.master.resource.ExecutorRepresenter; +import org.apache.reef.annotations.audience.DriverSide; + +import javax.annotation.concurrent.ThreadSafe; +import javax.inject.Inject; + +/** + * This policy aims to distribute partitions with skewed keys to different executors. + */ +@ThreadSafe +@DriverSide +@AssociatedProperty(SkewnessAwareSchedulingProperty.class) +public final class SkewnessAwareSchedulingConstraint implements SchedulingConstraint { + + @VisibleForTesting + @Inject + public SkewnessAwareSchedulingConstraint() { + } + + public boolean hasSkewedData(final Task task) { + final int taskIdx = RuntimeIdGenerator.getIndexFromTaskId(task.getTaskId()); + for (StageEdge inEdge : task.getTaskIncomingEdges()) { + final KeyRange hashRange = inEdge.getTaskIdxToKeyRange().get(taskIdx); + if (((HashRange) hashRange).isSkewed()) { + return true; + } + } + return false; + } + + @Override + public boolean testSchedulability(final ExecutorRepresenter executor, final Task task) { + // Check if this executor had already received heavy tasks + for (Task runningTask : executor.getRunningTasks()) { + if (hasSkewedData(runningTask) && hasSkewedData(task)) { + return false; + } + } + return true; + } +} diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraint.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraint.java index f18c90080..cb74f05ff 100644 --- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraint.java +++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraint.java @@ -22,8 +22,6 @@ import edu.snu.nemo.runtime.common.plan.Task; import edu.snu.nemo.runtime.master.resource.ExecutorRepresenter; import org.apache.reef.annotations.audience.DriverSide; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import javax.annotation.concurrent.ThreadSafe; import javax.inject.Inject; @@ -38,7 +36,6 @@ @DriverSide @AssociatedProperty(SourceLocationAwareSchedulingProperty.class) public final class SourceLocationAwareSchedulingConstraint implements SchedulingConstraint { - private static final Logger LOG = LoggerFactory.getLogger(SourceLocationAwareSchedulingConstraint.class); @VisibleForTesting @Inject diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FreeSlotSchedulingConstraintTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FreeSlotSchedulingConstraintTest.java index f2dd87856..7449732cd 100644 --- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FreeSlotSchedulingConstraintTest.java +++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FreeSlotSchedulingConstraintTest.java @@ -37,11 +37,17 @@ @PrepareForTest({ExecutorRepresenter.class, Task.class}) public final class FreeSlotSchedulingConstraintTest { + private static Task mockTask(final String taskId) { + final Task task = mock(Task.class); + when(task.getTaskId()).thenReturn(taskId); + return task; + } + private static ExecutorRepresenter mockExecutorRepresenter(final int numRunningTasks, final int capacity) { final ExecutorRepresenter executorRepresenter = mock(ExecutorRepresenter.class); - final Set<String> runningTasks = new HashSet<>(); - IntStream.range(0, numRunningTasks).forEach(i -> runningTasks.add(String.valueOf(i))); + final Set<Task> runningTasks = new HashSet<>(); + IntStream.range(0, numRunningTasks).forEach(i -> runningTasks.add(mockTask(String.valueOf(i)))); when(executorRepresenter.getRunningTasks()).thenReturn(runningTasks); when(executorRepresenter.getExecutorCapacity()).thenReturn(capacity); return executorRepresenter; diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/MinOccupancyFirstSchedulingPolicyTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/MinOccupancyFirstSchedulingPolicyTest.java index bf6ebc8a7..8831e6c8b 100644 --- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/MinOccupancyFirstSchedulingPolicyTest.java +++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/MinOccupancyFirstSchedulingPolicyTest.java @@ -35,10 +35,16 @@ @PrepareForTest({ExecutorRepresenter.class, Task.class}) public final class MinOccupancyFirstSchedulingPolicyTest { + private static Task mockTask(final String taskId) { + final Task task = mock(Task.class); + when(task.getTaskId()).thenReturn(taskId); + return task; + } + private static ExecutorRepresenter mockExecutorRepresenter(final int numRunningTasks) { final ExecutorRepresenter executorRepresenter = mock(ExecutorRepresenter.class); - final Set<String> runningTasks = new HashSet<>(); - IntStream.range(0, numRunningTasks).forEach(i -> runningTasks.add(String.valueOf(i))); + final Set<Task> runningTasks = new HashSet<>(); + IntStream.range(0, numRunningTasks).forEach(i -> runningTasks.add(mockTask(String.valueOf(i)))); when(executorRepresenter.getRunningTasks()).thenReturn(runningTasks); return executorRepresenter; } diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java new file mode 100644 index 000000000..87c933f4b --- /dev/null +++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SkewnessAwareSchedulingConstraintTest.java @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.nemo.runtime.master.scheduler; + +import edu.snu.nemo.runtime.common.RuntimeIdGenerator; +import edu.snu.nemo.runtime.common.data.HashRange; +import edu.snu.nemo.runtime.common.data.KeyRange; +import edu.snu.nemo.runtime.common.plan.StageEdge; +import edu.snu.nemo.runtime.common.plan.Task; +import edu.snu.nemo.runtime.master.resource.ExecutorRepresenter; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.*; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Test cases for {@link SkewnessAwareSchedulingConstraint}. + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest({ExecutorRepresenter.class, Task.class, HashRange.class, StageEdge.class}) +public final class SkewnessAwareSchedulingConstraintTest { + + private static StageEdge mockStageEdge() { + final Map<Integer, KeyRange> taskIdxToKeyRange = new HashMap<>(); + + final HashRange skewedHashRange1 = mock(HashRange.class); + when(skewedHashRange1.isSkewed()).thenReturn(true); + final HashRange skewedHashRange2 = mock(HashRange.class); + when(skewedHashRange2.isSkewed()).thenReturn(true); + final HashRange hashRange = mock(HashRange.class); + when(hashRange.isSkewed()).thenReturn(false); + + taskIdxToKeyRange.put(0, skewedHashRange1); + taskIdxToKeyRange.put(1, skewedHashRange2); + taskIdxToKeyRange.put(2, hashRange); + + final StageEdge inEdge = mock(StageEdge.class); + when(inEdge.getTaskIdxToKeyRange()).thenReturn(taskIdxToKeyRange); + + return inEdge; + } + + private static Task mockTask(final int taskIdx, final List<StageEdge> inEdges) { + final Task task = mock(Task.class); + when(task.getTaskId()).thenReturn(RuntimeIdGenerator.generateTaskId(taskIdx, "Stage-0")); + when(task.getTaskIncomingEdges()).thenReturn(inEdges); + return task; + } + + private static ExecutorRepresenter mockExecutorRepresenter(final Task task) { + final ExecutorRepresenter executorRepresenter = mock(ExecutorRepresenter.class); + final Set<Task> runningTasks = new HashSet<>(); + runningTasks.add(task); + when(executorRepresenter.getRunningTasks()).thenReturn(runningTasks); + return executorRepresenter; + } + + /** + * {@link SkewnessAwareSchedulingConstraint} should schedule Tasks assigned with skewed partitions + * to different executors. + */ + @Test + public void testScheduleSkewedTasks() { + final SchedulingConstraint schedulingConstraint = new SkewnessAwareSchedulingConstraint(); + final StageEdge inEdge = mockStageEdge(); + final Task task0 = mockTask(0, Arrays.asList(inEdge)); + final Task task1 = mockTask(1, Arrays.asList(inEdge)); + final Task task2 = mockTask(2, Arrays.asList(inEdge)); + final ExecutorRepresenter e0 = mockExecutorRepresenter(task0); + + assertEquals(true, schedulingConstraint.testSchedulability(e0, task2)); + assertEquals(false, schedulingConstraint.testSchedulability(e0, task1)); + } +} diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraintTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraintTest.java index 772c587ed..99abf883a 100644 --- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraintTest.java +++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SourceLocationAwareSchedulingConstraintTest.java @@ -68,7 +68,8 @@ public void testSourceLocationAwareSchedulingNotAvailable() { } /** - * {@link SourceLocationAwareSchedulingConstraint} should properly schedule TGs with multiple source locations. + * {@link SourceLocationAwareSchedulingConstraint} should properly schedule {@link Task}s + * with multiple source locations. */ @Test public void testSourceLocationAwareSchedulingWithMultiSource() { diff --git a/runtime/plangenerator/src/main/java/edu/snu/nemo/runtime/plangenerator/TestPlanGenerator.java b/runtime/plangenerator/src/main/java/edu/snu/nemo/runtime/plangenerator/TestPlanGenerator.java index 0af8c8d8a..25b26d3bd 100644 --- a/runtime/plangenerator/src/main/java/edu/snu/nemo/runtime/plangenerator/TestPlanGenerator.java +++ b/runtime/plangenerator/src/main/java/edu/snu/nemo/runtime/plangenerator/TestPlanGenerator.java @@ -190,4 +190,3 @@ private static PhysicalPlan convertIRToPhysical(final DAG<IRVertex, IREdge> irDA return dagBuilder.buildWithoutSourceSinkCheck(); } } - diff --git a/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java b/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java index 7f81154c8..8345b0570 100644 --- a/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java +++ b/tests/src/test/java/edu/snu/nemo/tests/compiler/optimizer/pass/compiletime/composite/DataSkewCompositePassTest.java @@ -24,6 +24,7 @@ import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.common.ir.vertex.MetricCollectionBarrierVertex; import edu.snu.nemo.common.ir.executionproperty.ExecutionProperty; +import edu.snu.nemo.common.ir.vertex.executionproperty.SkewnessAwareSchedulingProperty; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.AnnotatingPass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.CompositePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.DataSkewCompositePass; @@ -103,5 +104,11 @@ public void testDataSkewPass() throws Exception { .equals(e.getPropertyValue(MetricCollectionProperty.class))) .forEach(e -> assertEquals(PartitionerProperty.Value.DataSkewHashPartitioner, e.getPropertyValue(PartitionerProperty.class).get()))); + + processedDAG.filterVertices(v -> v instanceof MetricCollectionBarrierVertex) + .forEach(metricV -> { + List<IRVertex> reducerV = processedDAG.getChildren(metricV.getId()); + reducerV.forEach(rV -> assertTrue(rV.getPropertyValue(SkewnessAwareSchedulingProperty.class).get())); + }); } } ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services