TEZ-3269. Provide basic fair routing and scheduling functionality via custom VertexManager and EdgeManager.
Project: http://git-wip-us.apache.org/repos/asf/tez/repo Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/2c4ef9fe Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/2c4ef9fe Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/2c4ef9fe Branch: refs/heads/master Commit: 2c4ef9fe58395aa8e835a4c50cb65bbb26428638 Parents: 0d59844 Author: Ming Ma <[email protected]> Authored: Sat Nov 12 08:35:28 2016 -0800 Committer: Ming Ma <[email protected]> Committed: Sat Nov 12 08:35:28 2016 -0800 ---------------------------------------------------------------------- CHANGES.txt | 1 + tez-runtime-library/findbugs-exclude.xml | 42 + tez-runtime-library/pom.xml | 1 + .../DestinationTaskInputsProperty.java | 92 ++ .../vertexmanager/FairEdgeConfiguration.java | 111 ++ .../vertexmanager/FairShuffleEdgeManager.java | 154 ++ .../vertexmanager/FairShuffleVertexManager.java | 631 ++++++++ .../vertexmanager/ShuffleVertexManager.java | 37 +- .../vertexmanager/ShuffleVertexManagerBase.java | 135 +- .../src/main/proto/FairShufflePayloads.proto | 37 + .../TestFairShuffleVertexManager.java | 347 +++++ .../vertexmanager/TestShuffleVertexManager.java | 1424 ++---------------- .../TestShuffleVertexManagerBase.java | 1115 ++++++++++++++ .../TestShuffleVertexManagerUtils.java | 346 +++++ 14 files changed, 3080 insertions(+), 1393 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/CHANGES.txt ---------------------------------------------------------------------- diff --git a/CHANGES.txt b/CHANGES.txt index 8128c7b..0948862 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -7,6 +7,7 @@ INCOMPATIBLE CHANGES ALL CHANGES: + TEZ-3269. Provide basic fair routing and scheduling functionality via custom VertexManager and EdgeManager. TEZ-3534. Differentiate thread names on Fetchers, minor changes to shuffle shutdown code. TEZ-3491. Tez job can hang due to container priority inversion. TEZ-3533. 
ShuffleScheduler should shutdown threadpool on exit. http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/findbugs-exclude.xml ---------------------------------------------------------------------- diff --git a/tez-runtime-library/findbugs-exclude.xml b/tez-runtime-library/findbugs-exclude.xml index d3b6245..da7a013 100644 --- a/tez-runtime-library/findbugs-exclude.xml +++ b/tez-runtime-library/findbugs-exclude.xml @@ -152,6 +152,7 @@ <Match> <Class name="org.apache.tez.dag.library.vertexmanager.ShuffleVertexManagerBase"/> <Or> + <Field name="bipartiteSources"/> <Field name="numBipartiteSourceTasksCompleted"/> <Field name="totalNumBipartiteSourceTasks"/> <Field name="totalTasksToSchedule"/> @@ -159,4 +160,45 @@ <Bug pattern="IS2_INCONSISTENT_SYNC"/> </Match> + <Match> + <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$FairShuffleEdgeManagerConfigPayloadProto"/> + <Field name="unknownFields"/> + <Bug pattern="SE_BAD_FIELD"/> + </Match> + + <Match> + <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$FairShuffleEdgeManagerDestinationTaskPropProto"/> + <Field name="unknownFields"/> + <Bug pattern="SE_BAD_FIELD"/> + </Match> + + <Match> + <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$RangeProto"/> + <Field name="unknownFields"/> + <Bug pattern="SE_BAD_FIELD"/> + </Match> + + <Match> + <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$FairShuffleEdgeManagerConfigPayloadProto"/> + <Field name="PARSER"/> + <Bug pattern="MS_SHOULD_BE_FINAL"/> + </Match> + + <Match> + <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$FairShuffleEdgeManagerDestinationTaskPropProto"/> + <Field name="PARSER"/> + <Bug pattern="MS_SHOULD_BE_FINAL"/> + </Match> + + <Match> + <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$RangeProto"/> + <Field name="PARSER"/> + <Bug pattern="MS_SHOULD_BE_FINAL"/> + 
</Match> + + <Match> + <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$RangeProto$Builder"/> + <Method name="maybeForceBuilderInitialization"/> + <Bug pattern="UCF_USELESS_CONTROL_FLOW"/> + </Match> </FindBugsFilter> http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/pom.xml ---------------------------------------------------------------------- diff --git a/tez-runtime-library/pom.xml b/tez-runtime-library/pom.xml index b676933..2ccd65f 100644 --- a/tez-runtime-library/pom.xml +++ b/tez-runtime-library/pom.xml @@ -130,6 +130,7 @@ <includes> <include>ShufflePayloads.proto</include> <include>CartesianProductPayload.proto</include> + <include>FairShufflePayloads.proto</include> </includes> </source> <output>${project.build.directory}/generated-sources/java</output> http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/DestinationTaskInputsProperty.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/DestinationTaskInputsProperty.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/DestinationTaskInputsProperty.java new file mode 100644 index 0000000..bb23f19 --- /dev/null +++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/DestinationTaskInputsProperty.java @@ -0,0 +1,92 @@ +/** +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.tez.dag.library.vertexmanager; + +// Each destination task fetches data from numOfSourceTasks of consecutive +// source tasks with the first source task index being firstSourceTaskIndex. +// For any source task in that range, each destination task fetches +// numOfPartitions consecutive physical outputs with the first physical output +// index being firstPartitionId. +class DestinationTaskInputsProperty { + private final int firstPartitionId; + private final int numOfPartitions; + private final int firstSourceTaskIndex; + private final int numOfSourceTasks; + public DestinationTaskInputsProperty(int firstPartitionId, + int numOfPartitions, int firstSourceTaskIndex, int numOfSourceTasks) { + this.firstPartitionId = firstPartitionId; + this.numOfPartitions = numOfPartitions; + this.firstSourceTaskIndex = firstSourceTaskIndex; + this.numOfSourceTasks = numOfSourceTasks; + } + public int getFirstPartitionId() { + return firstPartitionId; + } + public int getNumOfPartitions() { + return numOfPartitions; + } + public int getFirstSourceTaskIndex() { + return firstSourceTaskIndex; + } + public int getNumOfSourceTasks() { + return numOfSourceTasks; + } + public boolean isSourceTaskInRange(int sourceTaskIndex) { + return firstSourceTaskIndex <= sourceTaskIndex && + sourceTaskIndex < firstSourceTaskIndex + + numOfSourceTasks; + } + public boolean isPartitionInRange(int partitionId) { + return firstPartitionId <= partitionId && + partitionId < firstPartitionId + numOfPartitions; + } + + // The first physical input index for the source task + 
public int getFirstPhysicalInputIndex(int sourceTaskIndex) { + return getPhysicalInputIndex(sourceTaskIndex, firstPartitionId); + } + + // The physical input index for the physical output index of the source task + public int getPhysicalInputIndex(int sourceTaskIndex, int partitionId) { + if (isSourceTaskInRange(sourceTaskIndex) && + isPartitionInRange(partitionId)) { + return (sourceTaskIndex - firstSourceTaskIndex) * numOfPartitions + + (partitionId - firstPartitionId); + } else { + return -1; + } + } + + public int getNumOfPhysicalInputs() { + return numOfPartitions * numOfSourceTasks; + } + + public int getSourceTaskIndex(int physicalInputIndex) { + return firstSourceTaskIndex + physicalInputIndex / numOfPartitions; + } + + @Override + public String toString() { + return "firstPartitionId = " + firstPartitionId + + " ,numOfPartitions = " + numOfPartitions + + " ,firstSourceTaskIndex = " + firstSourceTaskIndex + + " ,numOfSourceTasks = " + numOfSourceTasks; + } +} + http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairEdgeConfiguration.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairEdgeConfiguration.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairEdgeConfiguration.java new file mode 100644 index 0000000..846e0a3 --- /dev/null +++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairEdgeConfiguration.java @@ -0,0 +1,111 @@ +/** +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. 
The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.tez.dag.library.vertexmanager; + +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; + +import org.apache.tez.dag.api.UserPayload; +import org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads.FairShuffleEdgeManagerConfigPayloadProto; +import org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads.FairShuffleEdgeManagerDestinationTaskPropProto; +import org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads.RangeProto; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map.Entry; + + +/** + * Handles edge configuration serialization and de-serialization between + * {@link FairShuffleVertexManager} and {@link FairShuffleEdgeManager}. 
+ */ +class FairEdgeConfiguration { + private final int numBuckets; + private final HashMap<Integer, DestinationTaskInputsProperty> + destinationInputsProperties; + + public FairEdgeConfiguration(int numBuckets, + HashMap<Integer, DestinationTaskInputsProperty> routingTable) { + this.destinationInputsProperties = routingTable; + this.numBuckets = numBuckets; + } + + private FairShuffleEdgeManagerConfigPayloadProto getConfigPayload() { + FairShuffleEdgeManagerConfigPayloadProto.Builder builder = + FairShuffleEdgeManagerConfigPayloadProto.newBuilder(); + builder.setNumBuckets(numBuckets); + if (destinationInputsProperties != null) { + for (Entry<Integer, DestinationTaskInputsProperty> entry : + destinationInputsProperties.entrySet()) { + FairShuffleEdgeManagerDestinationTaskPropProto.Builder taskBuilder = + FairShuffleEdgeManagerDestinationTaskPropProto.newBuilder(); + taskBuilder. + setDestinationTaskIndex(entry.getKey()). + setPartitions(newRange(entry.getValue().getFirstPartitionId(), + entry.getValue().getNumOfPartitions())). + setSourceTasks(newRange(entry.getValue(). + getFirstSourceTaskIndex(), entry.getValue().getNumOfSourceTasks())); + builder.addDestinationTaskProps(taskBuilder.build()); + } + } + return builder.build(); + } + + private RangeProto newRange(int firstIndex, int numOfIndexes) { + return RangeProto.newBuilder(). 
+ setFirstIndex(firstIndex).setNumOfIndexes(numOfIndexes).build(); + } + + static FairEdgeConfiguration fromUserPayload(UserPayload payload) + throws InvalidProtocolBufferException { + HashMap<Integer, DestinationTaskInputsProperty> routingTable = new HashMap<>(); + FairShuffleEdgeManagerConfigPayloadProto proto = + FairShuffleEdgeManagerConfigPayloadProto.parseFrom( + ByteString.copyFrom(payload.getPayload())); + int numBuckets = proto.getNumBuckets(); + if (proto.getDestinationTaskPropsList() != null) { + for (int i = 0; i < proto.getDestinationTaskPropsList().size(); i++) { + FairShuffleEdgeManagerDestinationTaskPropProto propProto = + proto.getDestinationTaskPropsList().get(i); + routingTable.put( + propProto.getDestinationTaskIndex(), + new DestinationTaskInputsProperty( + propProto.getPartitions().getFirstIndex(), + propProto.getPartitions().getNumOfIndexes(), + propProto.getSourceTasks().getFirstIndex(), + propProto.getSourceTasks().getNumOfIndexes())); + } + } + return new FairEdgeConfiguration(numBuckets, routingTable); + } + + public HashMap<Integer, DestinationTaskInputsProperty> getRoutingTable() { + return destinationInputsProperties; + } + + // The number of partitions used by source vertex. 
+ int getNumBuckets() { + return numBuckets; + } + + UserPayload getBytePayload() { + return UserPayload.create(ByteBuffer.wrap( + getConfigPayload().toByteArray())); + } +} http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleEdgeManager.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleEdgeManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleEdgeManager.java new file mode 100644 index 0000000..ff1c032 --- /dev/null +++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleEdgeManager.java @@ -0,0 +1,154 @@ +/** +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.tez.dag.library.vertexmanager; + +import com.google.protobuf.InvalidProtocolBufferException; + +import org.apache.tez.dag.api.EdgeManagerPluginContext; +import org.apache.tez.dag.api.EdgeManagerPluginOnDemand; +import org.apache.tez.dag.api.UserPayload; + +import javax.annotation.Nullable; + +import java.util.HashMap; + +/** + * Edge manager for fair routing. 
Each destination task has its + * DestinationTaskInputsProperty used to decide how to do event routing + * between source and destination. + */ +public class FairShuffleEdgeManager extends EdgeManagerPluginOnDemand { + + private FairEdgeConfiguration conf = null; + // The key in the mapping is the destination task index. + // The value in the mapping is DestinationTaskInputsProperty of the + // destination task. + private HashMap<Integer, DestinationTaskInputsProperty> mapping; + + // used by the framework at runtime. initialize is the real initializer at runtime + public FairShuffleEdgeManager(EdgeManagerPluginContext context) { + super(context); + } + + @Override + public int getNumDestinationTaskPhysicalInputs(int destTaskIndex) { + return mapping.get(destTaskIndex).getNumOfPhysicalInputs(); + } + + @Override + public int getNumSourceTaskPhysicalOutputs(int sourceTaskIndex) { + return conf.getNumBuckets(); + } + + @Override + public int getNumDestinationConsumerTasks(int sourceTaskIndex) { + int numTasks = 0; + for(DestinationTaskInputsProperty entry: mapping.values()) { + if (entry.isSourceTaskInRange(sourceTaskIndex)) { + numTasks++; + } + } + return numTasks; + } + + // called at runtime to initialize the custom edge. 
+ @Override + public void initialize() { + UserPayload userPayload = getContext().getUserPayload(); + if (userPayload == null || userPayload.getPayload() == null || + userPayload.getPayload().limit() == 0) { + throw new RuntimeException("Could not initialize FairShuffleEdgeManager" + + " from provided user payload"); + } + try { + conf = FairEdgeConfiguration.fromUserPayload(userPayload); + mapping = conf.getRoutingTable(); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Could not initialize FairShuffleEdgeManager" + + " from provided user payload", e); + } + } + + @Override + public int routeInputErrorEventToSource(int destinationTaskIndex, + int destinationFailedInputIndex) { + return mapping.get(destinationTaskIndex).getSourceTaskIndex( + destinationFailedInputIndex); + } + + @Override + public void prepareForRouting() throws Exception { + } + + @Override + public EventRouteMetadata routeDataMovementEventToDestination( + int sourceTaskIndex, int sourceOutputIndex, int destTaskIndex) + throws Exception { + DestinationTaskInputsProperty property = mapping.get(destTaskIndex); + int targetIndex = property.getPhysicalInputIndex(sourceTaskIndex, + sourceOutputIndex); + if (targetIndex != -1) { + return EventRouteMetadata.create(1, new int[]{targetIndex}); + } else { + return null; + } + } + + // Create an array of "count" consecutive integers with starting + // value equal to "startValue". 
+ private int[] getRange(int startValue, int count) { + int[] values = new int[count]; + for (int i = 0; i < count; i++) { + values[i] = startValue + i; + } + return values; + } + + @Override + public @Nullable EventRouteMetadata + routeCompositeDataMovementEventToDestination(int sourceTaskIndex, + int destinationTaskIndex) { + DestinationTaskInputsProperty property = mapping.get(destinationTaskIndex); + int firstPhysicalInputIndex = + property.getFirstPhysicalInputIndex(sourceTaskIndex); + if (firstPhysicalInputIndex >= 0) { + return EventRouteMetadata.create(property.getNumOfPartitions(), + getRange(firstPhysicalInputIndex, property.getNumOfPartitions()), + getRange(property.getFirstPartitionId(), + property.getNumOfPartitions())); + } else { + return null; + } + } + + @Override + public EventRouteMetadata routeInputSourceTaskFailedEventToDestination( + int sourceTaskIndex, int destinationTaskIndex) throws Exception { + DestinationTaskInputsProperty property = mapping.get(destinationTaskIndex); + int firstPhysicalInputIndex = + property.getFirstPhysicalInputIndex(sourceTaskIndex); + if (firstPhysicalInputIndex >= 0) { + return EventRouteMetadata.create(property.getNumOfPartitions(), + getRange(firstPhysicalInputIndex, property.getNumOfPartitions())); + } else { + return null; + } + } +} + http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleVertexManager.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleVertexManager.java new file mode 100644 index 0000000..a8b336c --- /dev/null +++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleVertexManager.java @@ -0,0 +1,631 @@ +/** +* Licensed to the Apache Software 
Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.tez.dag.library.vertexmanager; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.collect.UnmodifiableIterator; + +import com.google.common.primitives.Ints; +import org.apache.tez.common.TezUtils; +import org.apache.tez.dag.api.EdgeManagerPluginDescriptor; +import org.apache.tez.dag.api.EdgeProperty; +import org.apache.tez.dag.api.TezUncheckedException; +import org.apache.tez.dag.api.VertexManagerPluginContext; +import org.apache.tez.dag.api.VertexManagerPluginDescriptor; +import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.apache.hadoop.classification.InterfaceAudience.Public; +import org.apache.hadoop.classification.InterfaceStability.Evolving; +import org.apache.hadoop.conf.Configuration; +import org.apache.tez.runtime.api.TaskAttemptIdentifier; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.math.BigInteger; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + + +/** + * Fair routing based on partition size 
distribution to achieve optimal + * input size for any destination task thus reduce data skewness. + * By default the feature is turned off and it supports the regular shuffle like + * ShuffleVertexManager. + * When the feature is turned on, there are two routing types as defined in + * {@link FairRoutingType}. One is {@link FairRoutingType#REDUCE_PARALLELISM} + * which is similar to ShuffleVertexManager's auto reduce functionality. + * Another one is {@link FairRoutingType#FAIR_PARALLELISM} where each + * destination task can process a range of consecutive partitions from a range + * of consecutive source tasks. + */ +@Public +@Evolving +public class FairShuffleVertexManager extends ShuffleVertexManagerBase { + + private static final Logger LOG = + LoggerFactory.getLogger(FairShuffleVertexManager.class); + + /** + * The desired size of input per task. Parallelism will be changed to meet + * this criteria. + */ + public static final String TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE = + "tez.fair-shuffle-vertex-manager.desired-task-input-size"; + public static final long + TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT = 100 * MB; + + /** + * Enables automatic parallelism determination for the vertex. Based on input data + * statistics the parallelism is adjusted to a desired level. 
+ */ + public static final String TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL = + "tez.fair-shuffle-vertex-manager.enable.auto-parallel"; + public static final String + TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT = + FairRoutingType.NONE.getType(); + + /** + * In case of a ScatterGather connection, the fraction of source tasks which + * should complete before tasks for the current vertex are scheduled + */ + public static final String TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION = + "tez.fair-shuffle-vertex-manager.min-src-fraction"; + public static final float TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT = 0.25f; + + /** + * In case of a ScatterGather connection, once this fraction of source tasks + * have completed, all tasks on the current vertex can be scheduled. Number of + * tasks ready for scheduling on the current vertex scales linearly between + * min-fraction and max-fraction. Defaults to the greater of the default value + * or tez.fair-shuffle-vertex-manager.min-src-fraction. + */ + public static final String TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION = + "tez.fair-shuffle-vertex-manager.max-src-fraction"; + public static final float TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT = 0.75f; + + /** + * Enables automatic parallelism determination for the vertex. Based on input data + * statistics the parallelism is adjusted to a desired level. + */ + public enum FairRoutingType { + /** + * Don't do any fair routing. + */ + NONE("none"), + + /** + * TEZ-2962 Based on input data statistics the parallelism is decreased + * to a desired level by having one destination task process multiple + * consecutive partitions. + */ + REDUCE_PARALLELISM("reduce_parallelism"), + + /** + * Based on input data statistics the parallelism is adjusted + * to a desired level by having one destination task process multiple + * small partitions and multiple destination tasks process one + * large partition. 
Only works when there is one bipartite edge. + */ + FAIR_PARALLELISM("fair_parallelism"); + + private final String type; + + private FairRoutingType(String type) { + this.type = type; + } + + public final String getType() { + return type; + } + + public boolean reduceParallelismEnabled() { + return equals(FairRoutingType.REDUCE_PARALLELISM); + } + + public boolean fairParallelismEnabled() { + return equals(FairRoutingType.FAIR_PARALLELISM); + } + + public boolean enabled() { + return !equals(FairRoutingType.NONE); + } + + public static FairRoutingType fromString(String type) { + if (type != null) { + for (FairRoutingType b : FairRoutingType.values()) { + if (type.equalsIgnoreCase(b.type)) { + return b; + } + } + } + throw new IllegalArgumentException("Invalid type " + type); + } + } + + static class FairSourceVertexInfo extends SourceVertexInfo { + // mapping from destination task id to DestinationTaskInputsProperty + private final HashMap<Integer, DestinationTaskInputsProperty> + destinationInputsProperties = new HashMap<>(); + + FairSourceVertexInfo(final EdgeProperty edgeProperty, + int totalTasksToSchedule) { + super(edgeProperty, totalTasksToSchedule); + } + public HashMap<Integer, DestinationTaskInputsProperty> + getDestinationInputsProperties() { + return destinationInputsProperties; + } + } + + @Override + SourceVertexInfo createSourceVertexInfo(EdgeProperty edgeProperty, + int numTasks) { + return new FairSourceVertexInfo(edgeProperty, numTasks); + } + + + FairShuffleVertexManagerConfig mgrConfig; + + public FairShuffleVertexManager(VertexManagerPluginContext context) { + super(context); + } + + @Override + protected void onVertexStartedCheck() { + super.onVertexStartedCheck(); + if (bipartiteSources > 1 && + (mgrConfig.getFairRoutingType().fairParallelismEnabled())) { + // TODO TEZ-3500 + throw new TezUncheckedException( + "Having more than one destination task process same partition(s) " + + "only works with one bipartite source."); + } + } + + static 
long ceil(long a, long b) { + return (a + (b - 1)) / b; + } + + public long[] estimatePartitionSize() { + boolean partitionStatsReported = false; + int numOfPartitions = pendingTasks.size(); + long[] estimatedPartitionOutputSize = new long[numOfPartitions]; + for (int i = 0; i < numOfPartitions; i++) { + if (getCurrentlyKnownStatsAtIndex(i) > 0) { + partitionStatsReported = true; + break; + } + } + + if (!partitionStatsReported) { + // partition stats reporting isn't enabled at the source. Use + // expected source output size and assume all partitions are evenly + // distributed. + if (numOfPartitions > 0) { + long estimatedPerPartitionSize = + getExpectedTotalBipartiteSourceTasksOutputSize().divide( + BigInteger.valueOf(numOfPartitions)).longValue(); + for (int i = 0; i < numOfPartitions; i++) { + estimatedPartitionOutputSize[i] = estimatedPerPartitionSize; + } + } + } else { + for (int i = 0; i < numOfPartitions; i++) { + estimatedPartitionOutputSize[i] = + MB * getExpectedStatsAtIndex(i); + } + } + return estimatedPartitionOutputSize; + } + + /* + * The class calculates how partitions and source tasks should be + * grouped together. It allows a destination task to fetch a consecutive + * range of partitions from a consecutive range of source tasks to achieve + * optimal physical input size specified by desiredTaskInputDataSize. + * First it estimates the size of each partition at job completion based + * on the partition and output size of the completed tasks. The estimation + * is stored in estimatedPartitionOutputSize. + * Then it walks the partitions starting from the beginning. + * If a partition is not greater than desiredTaskInputDataSize, it keeps + * accumulating the next partition until it is about to exceed + * desiredTaskInputDataSize. Then it will create a new destination task to + * fetch these small partitions in the range of + * {firstPartitionId, numOfPartitions} from all source tasks. 
+ * If a partition is larger than desiredTaskInputDataSize: + * For the FairRoutingType.REDUCE_PARALLELISM policy, it creates a new destination task + * to fetch this large partition from all source tasks. + * For the FairRoutingType.FAIR_PARALLELISM policy, it creates multiple destination tasks, + * each of which fetches the large partition from a range + * of source tasks. + */ + private class PartitionsGroupingCalculator + implements Iterable<DestinationTaskInputsProperty> { + + private final FairSourceVertexInfo sourceVertexInfo; + + // Estimated aggregated partition output size when the job is done. + private long[] estimatedPartitionOutputSize; + + // Intermediate states used to group partitions. + + // Total output size of partitions in current group. + private long sizeOfPartitions = 0; + // Total number of partitions in the current group. + private int numOfPartitions = 0; + // The first partition id in the current group. + private int firstPartitionId = 0; + // The # of source tasks a destination task consumes. + // When FAIR_PARALLELISM is enabled, there will be multiple destination + // tasks processing the same partition and each destination task will + // process a range of source tasks of that partition. For a given + // partition, the number of source tasks assigned to different destination + // tasks should differ by one at most and numOfBaseSourceTasks is the + // smaller value. numOfBaseDestinationTasks is the number of destination tasks that + // process numOfBaseSourceTasks source tasks. + // e.g. if 8 source tasks are assigned 3 destination tasks, the number of + // source tasks assigned to these 3 destination tasks are {2, 3, 3}. + // numOfBaseDestinationTasks == 1, numOfBaseSourceTasks == 2. 
+ private int numOfBaseSourceTasks = 0; + private int numOfBaseDestinationTasks = 0; + public PartitionsGroupingCalculator(long[] estimatedPartitionOutputSize, + FairSourceVertexInfo sourceVertexInfo) { + this.estimatedPartitionOutputSize = estimatedPartitionOutputSize; + this.sourceVertexInfo = sourceVertexInfo; + } + + // Start the processing of the next group of partitions + private void startNextPartitionsGroup() { + this.firstPartitionId += this.numOfPartitions; + this.sizeOfPartitions = 0; + this.numOfPartitions = 0; + this.numOfBaseSourceTasks = 0; + this.numOfBaseDestinationTasks = 0; + } + + private int getNextPartitionId() { + return this.firstPartitionId + this.numOfPartitions; + } + + private void addNextPartition() { + if (hasPartitionsLeft()) { + this.sizeOfPartitions += + estimatedPartitionOutputSize[getNextPartitionId()]; + this.numOfPartitions++; + } + } + + private boolean hasPartitionsLeft() { + return getNextPartitionId() < this.estimatedPartitionOutputSize.length; + } + + private long getCurrentAndNextPartitionSize() { + return hasPartitionsLeft() ? this.sizeOfPartitions + + estimatedPartitionOutputSize[getNextPartitionId()] : + this.sizeOfPartitions; + } + + // For the current source output partition(s), decide how + // source tasks should be grouped. + private boolean computeSourceTasksGrouping() { + boolean finalizeCurrentPartitions = true; + int groupCount = Ints.checkedCast(ceil(getCurrentAndNextPartitionSize(), + config.getDesiredTaskInputDataSize())); + if (groupCount <= 1) { + // There is not enough data so far to reach desiredTaskInputDataSize. + addNextPartition(); + if (!hasPartitionsLeft()) { + // We have reached the last partition. + // Consume from all source tasks. 
+ this.numOfBaseDestinationTasks = 1; + this.numOfBaseSourceTasks = this.sourceVertexInfo.numTasks; + } else { + finalizeCurrentPartitions = false; + } + } else if (numOfPartitions == 0) { + // The first partition in the current group exceeds + // desiredTaskInputDataSize. + addNextPartition(); + if (mgrConfig.getFairRoutingType().reduceParallelismEnabled()) { + // Consume from all source tasks + this.numOfBaseDestinationTasks = 1; + this.numOfBaseSourceTasks = this.sourceVertexInfo.numTasks; + } else { + // When groupCount > sourceVertexInfo.numTasks, it means + // sizeOfPartitions is too big so that even if + // we just have one destination task fetch from one source task the + // input size still exceeds desiredTaskInputDataSize. + if ((this.sourceVertexInfo.numTasks >= groupCount)) { + this.numOfBaseDestinationTasks = groupCount - + this.sourceVertexInfo.numTasks % groupCount; + this.numOfBaseSourceTasks = + this.sourceVertexInfo.numTasks / groupCount; + } else { + this.numOfBaseDestinationTasks = this.sourceVertexInfo.numTasks; + this.numOfBaseSourceTasks = 1; + } + } + } else { + // There are existing partitions in the current group. Adding the next + // partition causes the total size to exceed desiredTaskInputDataSize. + // Let us process the existing partitions in the current group. The + // next partition will be processed in the next group. + this.numOfBaseDestinationTasks = 1; + this.numOfBaseSourceTasks = this.sourceVertexInfo.numTasks; + } + return finalizeCurrentPartitions; + } + + @Override + public Iterator<DestinationTaskInputsProperty> iterator() { + return new UnmodifiableIterator<DestinationTaskInputsProperty>() { + private int j = 0; + private boolean visitedAtLeastOnce = false; + private int groupIndex = 0; + + // Get number of source tasks in the current group. + private int getNumOfSourceTasks() { + return groupIndex++ < numOfBaseDestinationTasks ? 
+ numOfBaseSourceTasks : numOfBaseSourceTasks + 1; + } + + @Override + public boolean hasNext() { + return j < sourceVertexInfo.numTasks || !visitedAtLeastOnce; + } + + @Override + public DestinationTaskInputsProperty next() { + if (hasNext()) { + visitedAtLeastOnce = true; + int start = j; + int numOfSourceTasks = getNumOfSourceTasks(); + j += numOfSourceTasks; + return new DestinationTaskInputsProperty(firstPartitionId, + numOfPartitions, start, numOfSourceTasks); + } + throw new NoSuchElementException(); + } + }; + } + + public void compute() { + int destinationIndex = 0; + while (hasPartitionsLeft()) { + if (!computeSourceTasksGrouping()) { + continue; + } + Iterator<DestinationTaskInputsProperty> it = iterator(); + while(it.hasNext()) { + sourceVertexInfo.getDestinationInputsProperties().put( + destinationIndex,it.next()); + destinationIndex++; + } + startNextPartitionsGroup(); + } + } + } + + public ReconfigVertexParams computeRouting() { + int currentParallelism = pendingTasks.size(); + int finalTaskParallelism = 0; + long[] estimatedPartitionOutputSize = estimatePartitionSize(); + for (Map.Entry<String, SourceVertexInfo> vInfo : getBipartiteInfo()) { + FairSourceVertexInfo info = (FairSourceVertexInfo)vInfo.getValue(); + computeParallelism(estimatedPartitionOutputSize, info); + if (finalTaskParallelism != 0) { + Preconditions.checkState( + finalTaskParallelism == info.getDestinationInputsProperties().size(), + "the parallelism shall be the same for source vertices"); + } + finalTaskParallelism = info.getDestinationInputsProperties().size(); + + FairEdgeConfiguration fairEdgeConfig = new FairEdgeConfiguration( + currentParallelism, info.getDestinationInputsProperties()); + EdgeManagerPluginDescriptor descriptor = + EdgeManagerPluginDescriptor.create( + FairShuffleEdgeManager.class.getName()); + descriptor.setUserPayload(fairEdgeConfig.getBytePayload()); + vInfo.getValue().newDescriptor = descriptor; + } + ReconfigVertexParams params = new 
ReconfigVertexParams( + finalTaskParallelism, null); + + return params; + } + + @Override + void postReconfigVertex() { + } + + @Override + void processPendingTasks() { + } + + private void computeParallelism(long[] estimatedPartitionOutputSize, + FairSourceVertexInfo sourceVertexInfo) { + PartitionsGroupingCalculator calculator = new PartitionsGroupingCalculator( + estimatedPartitionOutputSize, sourceVertexInfo); + calculator.compute(); + } + + @Override + List<ScheduleTaskRequest> getTasksToSchedule( + TaskAttemptIdentifier completedSourceAttempt) { + float minSourceVertexCompletedTaskFraction = + getMinSourceVertexCompletedTaskFraction(); + int numTasksToSchedule = getNumOfTasksToScheduleAndLog( + minSourceVertexCompletedTaskFraction); + if (numTasksToSchedule > 0) { + boolean scheduleAll = + (numTasksToSchedule == pendingTasks.size()); + List<ScheduleTaskRequest> tasksToSchedule = + Lists.newArrayListWithCapacity(numTasksToSchedule); + + Iterator<PendingTaskInfo> it = pendingTasks.iterator(); + FairSourceVertexInfo srcInfo = null; + int srcTaskId = 0; + if (completedSourceAttempt != null) { + srcTaskId = completedSourceAttempt.getTaskIdentifier().getIdentifier(); + String srcVertexName = completedSourceAttempt.getTaskIdentifier().getVertexIdentifier().getName(); + srcInfo = (FairSourceVertexInfo)getSourceVertexInfo(srcVertexName); + } + while (it.hasNext() && numTasksToSchedule > 0) { + Integer taskIndex = it.next().getIndex(); + // filter out those destination tasks that don't depend on + // this completed source task. + // destinationInputsProperties's size could be 0 if routing computation + // is skipped. 
+ if (!scheduleAll && config.isAutoParallelismEnabled() + && srcInfo != null && srcInfo.getDestinationInputsProperties().size() > 0) { + DestinationTaskInputsProperty property = + srcInfo.getDestinationInputsProperties().get(taskIndex); + if (!property.isSourceTaskInRange(srcTaskId)) { + LOG.debug("completedSourceTaskIndex {} and taskIndex {} don't " + + "connect.", srcTaskId, taskIndex); + continue; + } + } + tasksToSchedule.add(ScheduleTaskRequest.create(taskIndex, null)); + it.remove(); + numTasksToSchedule--; + } + return tasksToSchedule; + } + return null; + } + + static class FairShuffleVertexManagerConfig extends ShuffleVertexManagerBaseConfig { + final FairRoutingType fairRoutingType; + public FairShuffleVertexManagerConfig(final boolean enableAutoParallelism, + final long desiredTaskInputDataSize, final float slowStartMinFraction, + final float slowStartMaxFraction, final FairRoutingType fairRoutingType) { + super(enableAutoParallelism, desiredTaskInputDataSize, + slowStartMinFraction, slowStartMaxFraction); + this.fairRoutingType = fairRoutingType; + LOG.info("fairRoutingType {}", this.fairRoutingType); + } + FairRoutingType getFairRoutingType() { + return fairRoutingType; + } + } + + @Override + ShuffleVertexManagerBaseConfig initConfiguration() { + float slowStartMinFraction = conf.getFloat( + TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, + TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT); + FairRoutingType fairRoutingType = FairRoutingType.fromString( + conf.get(TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL, + TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT)); + + mgrConfig = new FairShuffleVertexManagerConfig( + fairRoutingType.enabled(), + conf.getLong( + TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, + TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT), + slowStartMinFraction, + conf.getFloat( + TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, + Math.max(slowStartMinFraction, + 
TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT)), + fairRoutingType); + return mgrConfig; + } + + /** + * Create a {@link VertexManagerPluginDescriptor} builder that can be used to + * configure the plugin. + * + * @param conf + * {@link Configuration} May be modified in place. May be null if the + * configuration parameters are to be set only via code. If + * configuration values may be changed at runtime via a config file + * then pass in a {@link Configuration} that is initialized from a + * config file. The parameters that are not overridden in code will + * be derived from the Configuration object. + * @return {@link FairShuffleVertexManagerConfigBuilder} + */ + public static FairShuffleVertexManagerConfigBuilder + createConfigBuilder(@Nullable Configuration conf) { + return new FairShuffleVertexManagerConfigBuilder(conf); + } + + /** + * Helper class to configure FairShuffleVertexManager + */ + public static final class FairShuffleVertexManagerConfigBuilder { + private final Configuration conf; + + private FairShuffleVertexManagerConfigBuilder(@Nullable Configuration conf) { + if (conf == null) { + this.conf = new Configuration(false); + } else { + this.conf = conf; + } + } + + public FairShuffleVertexManagerConfigBuilder setAutoParallelism( + FairRoutingType fairRoutingType) { + conf.set(TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL, + fairRoutingType.toString()); + return this; + } + + public FairShuffleVertexManagerConfigBuilder + setSlowStartMinSrcCompletionFraction(float minFraction) { + conf.setFloat(TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, + minFraction); + return this; + } + + public FairShuffleVertexManagerConfigBuilder + setSlowStartMaxSrcCompletionFraction(float maxFraction) { + conf.setFloat(TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, + maxFraction); + return this; + } + + public FairShuffleVertexManagerConfigBuilder setDesiredTaskInputSize( + long desiredTaskInputSize) { + 
conf.setLong(TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, + desiredTaskInputSize); + return this; + } + + public VertexManagerPluginDescriptor build() { + VertexManagerPluginDescriptor desc = + VertexManagerPluginDescriptor.create( + FairShuffleVertexManager.class.getName()); + + try { + return desc.setUserPayload(TezUtils.createUserPayloadFromConf( + this.conf)); + } catch (IOException e) { + throw new TezUncheckedException(e); + } + } + } +} http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java index 9937bd1..55a6ced 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java @@ -75,7 +75,7 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase { /** * Enables automatic parallelism determination for the vertex. Based on input data - * statisitics the parallelism is decreased to a desired level. + * statistics the parallelism is decreased to a desired level. */ public static final String TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL = "tez.shuffle-vertex-manager.enable.auto-parallel"; @@ -266,7 +266,6 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase { + sourceIndex % partitionRange; return EventRouteMetadata.create(1, new int[]{targetIndex}); } - @Override @@ -447,19 +446,8 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase { // Change this to use per partition stats for more accuracy TEZ-2962. 
// Instead of aggregating overall size and then dividing equally - coalesce partitions until // desired per partition size is achieved. - BigInteger expectedTotalSourceTasksOutputSize = BigInteger.ZERO; - for (Map.Entry<String, SourceVertexInfo> vInfo : getBipartiteInfo()) { - SourceVertexInfo srcInfo = vInfo.getValue(); - if (srcInfo.numTasks > 0 && srcInfo.numVMEventsReceived > 0) { - // this assumes that 1 vmEvent is received per completed task - TEZ-2961 - // Estimate total size by projecting based on the current average size per event - BigInteger srcOutputSize = BigInteger.valueOf(srcInfo.outputSize); - BigInteger srcNumTasks = BigInteger.valueOf(srcInfo.numTasks); - BigInteger srcNumVMEventsReceived = BigInteger.valueOf(srcInfo.numVMEventsReceived); - BigInteger expectedSrcOutputSize = srcOutputSize.multiply(srcNumTasks).divide(srcNumVMEventsReceived); - expectedTotalSourceTasksOutputSize = expectedTotalSourceTasksOutputSize.add(expectedSrcOutputSize); - } - } + BigInteger expectedTotalSourceTasksOutputSize = + getExpectedTotalBipartiteSourceTasksOutputSize(); LOG.info("Expected output: {} based on actual output: {} from {} vertex " + "manager events. 
desiredTaskInputSize: {} max slow start tasks: {} " + @@ -527,7 +515,13 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase { EdgeManagerPluginDescriptor descriptor = EdgeManagerPluginDescriptor.create(CustomShuffleEdgeManager.class.getName()); descriptor.setUserPayload(edgeManagerConfig.toUserPayload()); - ReconfigVertexParams params = new ReconfigVertexParams(finalTaskParallelism, null, descriptor); + + Iterable<Map.Entry<String, SourceVertexInfo>> bipartiteItr = getBipartiteInfo(); + for(Map.Entry<String, SourceVertexInfo> entry : bipartiteItr) { + entry.getValue().newDescriptor = descriptor; + } + ReconfigVertexParams params = + new ReconfigVertexParams(finalTaskParallelism, null); return params; } @@ -623,13 +617,14 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase { Preconditions.checkState(index < targetIndexes.length, "index=" + index +", targetIndexes length=" + targetIndexes.length); int[] mapping = targetIndexes[index]; - long totalStats = 0; + int partitionStats = 0; for (int i : mapping) { - totalStats += stats[i]; + partitionStats += getCurrentlyKnownStatsAtIndex(i); } - computedPartitionSizes |= taskInfo.setInputStats(totalStats); + computedPartitionSizes |= taskInfo.setInputStats(partitionStats); } else { - computedPartitionSizes |= taskInfo.setInputStats(stats[index]); + computedPartitionSizes |= taskInfo.setInputStats( + getCurrentlyKnownStatsAtIndex(index)); } } return computedPartitionSizes; @@ -637,8 +632,6 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase { - - /** * Create a {@link VertexManagerPluginDescriptor} builder that can be used to * configure the plugin. 
http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java index dc6cd3b..967d0ea 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java @@ -35,12 +35,12 @@ import org.apache.tez.dag.api.EdgeProperty; import org.apache.tez.dag.api.EdgeProperty.DataMovementType; import org.apache.tez.dag.api.InputDescriptor; import org.apache.tez.dag.api.TezUncheckedException; -import org.apache.tez.dag.api.VertexManagerPlugin; import org.apache.tez.dag.api.VertexManagerPluginContext; +import org.apache.tez.dag.api.VertexLocationHint; +import org.apache.tez.dag.api.VertexManagerPlugin; import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest; import org.apache.tez.dag.api.event.VertexState; import org.apache.tez.dag.api.event.VertexStateUpdate; -import org.apache.tez.dag.api.VertexLocationHint; import org.apache.tez.runtime.library.utils.DATA_RANGE_IN_MB; import org.roaringbitmap.RoaringBitmap; import org.slf4j.Logger; @@ -58,6 +58,8 @@ import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexMan import java.io.DataInputStream; import java.io.IOException; +import java.math.BigInteger; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.BitSet; import java.util.EnumSet; import java.util.HashMap; @@ -65,13 +67,11 @@ import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; import 
java.util.zip.Inflater; /** - * Starts scheduling tasks when number of completed source tasks crosses - * <code>slowStartMinFraction</code> and schedules all tasks - * when <code>slowStartMaxFraction</code> is reached + * It provides common functions used by ShuffleVertexManager and + * FairShuffleVertexManager. */ @Private @Evolving @@ -102,7 +102,6 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { int totalTasksToSchedule = 0; @VisibleForTesting - long[] stats; //approximate amount of data to be fetched Configuration conf; ShuffleVertexManagerBaseConfig config; // requires synchronized access @@ -132,10 +131,14 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { int numTasks; int numVMEventsReceived; long outputSize; + int[] statsInMB; + EdgeManagerPluginDescriptor newDescriptor; - SourceVertexInfo(final EdgeProperty edgeProperty) { + SourceVertexInfo(final EdgeProperty edgeProperty, + int totalTasksToSchedule) { this.edgeProperty = edgeProperty; this.finishedTaskSet = new BitSet(); + this.statsInMB = new int[totalTasksToSchedule]; } int getNumTasks() { @@ -145,11 +148,20 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { int getNumCompletedTasks() { return finishedTaskSet.cardinality(); } + int getExpectedStatsInMBAtIndex(int index) { + return (numVMEventsReceived == 0) ? + 0: statsInMB[index] * numTasks / numVMEventsReceived; + } + } + + SourceVertexInfo createSourceVertexInfo(EdgeProperty edgeProperty, + int numTasks) { + return new SourceVertexInfo(edgeProperty, numTasks); } static class PendingTaskInfo { final private int index; - private long inputStats; + private int inputStats; public PendingTaskInfo(int index) { this.index = index; @@ -161,11 +173,11 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { public int getIndex() { return index; } - public long getInputStats() { + public int getInputStats() { return inputStats; } // return true if stat is set. 
- public boolean setInputStats(long inputStats) { + public boolean setInputStats(int inputStats) { if (inputStats > 0 && this.inputStats != inputStats) { this.inputStats = inputStats; return true; @@ -178,14 +190,11 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { static class ReconfigVertexParams { final private int finalParallelism; final private VertexLocationHint locationHint; - final private EdgeManagerPluginDescriptor descriptor; public ReconfigVertexParams(final int finalParallelism, - final VertexLocationHint locationHint, - final EdgeManagerPluginDescriptor descriptor) { + final VertexLocationHint locationHint) { this.finalParallelism = finalParallelism; this.locationHint = locationHint; - this.descriptor = descriptor; } public int getFinalParallelism() { @@ -194,9 +203,6 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { public VertexLocationHint getLocationHint() { return locationHint; } - public EdgeManagerPluginDescriptor getDescriptor() { - return descriptor; - } } public ShuffleVertexManagerBase(VertexManagerPluginContext context) { @@ -209,7 +215,8 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { // examine edges after vertex started because until then these may not have been defined Map<String, EdgeProperty> inputs = getContext().getInputVertexEdgeProperties(); for(Map.Entry<String, EdgeProperty> entry : inputs.entrySet()) { - srcVertexInfo.put(entry.getKey(), new SourceVertexInfo(entry.getValue())); + srcVertexInfo.put(entry.getKey(), createSourceVertexInfo(entry.getValue(), + getContext().getVertexNumTasks(getContext().getVertexName()))); // TODO what if derived class has already called this // register for status update from all source vertices getContext().registerForVertexStateUpdates(entry.getKey(), @@ -218,9 +225,7 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { bipartiteSources++; } } - if(bipartiteSources == 0) { - throw new 
TezUncheckedException("Atleast 1 bipartite source should exist"); - } + onVertexStartedCheck(); for (VertexStateUpdate stateUpdate : pendingStateUpdates) { handleVertexStateUpdate(stateUpdate); @@ -249,6 +254,11 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { processPendingTasks(null); } + protected void onVertexStartedCheck() { + if(bipartiteSources == 0) { + throw new TezUncheckedException("At least 1 bipartite source should exist"); + } + } @Override public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) { @@ -274,8 +284,10 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { } @VisibleForTesting - void parsePartitionStats(RoaringBitmap partitionStats) { - Preconditions.checkState(stats != null, "Stats should be initialized"); + void parsePartitionStats(SourceVertexInfo srcInfo, + RoaringBitmap partitionStats) { + Preconditions.checkState(srcInfo.statsInMB != null, + "Stats should be initialized"); Iterator<Integer> it = partitionStats.iterator(); final DATA_RANGE_IN_MB[] RANGES = DATA_RANGE_IN_MB.values(); final int RANGE_LEN = RANGES.length; @@ -285,14 +297,15 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { int rangeIndex = ((pos) % RANGE_LEN); //Add to aggregated stats and normalize to DATA_RANGE_IN_MB. 
if (RANGES[rangeIndex].getSizeInMB() > 0) { - stats[index] += RANGES[rangeIndex].getSizeInMB(); + srcInfo.statsInMB[index] += RANGES[rangeIndex].getSizeInMB(); } } } - void parseDetailedPartitionStats(List<Integer> partitionStats) { + void parseDetailedPartitionStats(SourceVertexInfo srcInfo, + List<Integer> partitionStats) { for (int i=0; i<partitionStats.size(); i++) { - stats[i] += partitionStats.get(i); + srcInfo.statsInMB[i] += partitionStats.get(i); } } @@ -344,7 +357,7 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { NonSyncByteArrayInputStream bin = new NonSyncByteArrayInputStream(rawData); partitionStats.deserialize(new DataInputStream(bin)); - parsePartitionStats(partitionStats); + parsePartitionStats(srcInfo, partitionStats); } catch (IOException e) { throw new TezUncheckedException(e); @@ -352,7 +365,7 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { } else if (proto.hasDetailedPartitionStats()) { List<Integer> detailedPartitionStats = proto.getDetailedPartitionStats().getSizeInMbList(); - parseDetailedPartitionStats(detailedPartitionStats); + parseDetailedPartitionStats(srcInfo, detailedPartitionStats); } srcInfo.numVMEventsReceived++; srcInfo.outputSize += sourceTaskOutputSize; @@ -361,11 +374,11 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { if (LOG.isDebugEnabled()) { LOG.debug("For attempt: {} received info of output size: {}" - + " vertex numEventsReceived: {} vertex output size: {}" - + " total numEventsReceived: {} total output size: {}", - vmEvent.getProducerAttemptIdentifier(), sourceTaskOutputSize, - srcInfo.numVMEventsReceived, srcInfo.outputSize, - numVertexManagerEventsReceived, completedSourceTasksOutputSize); + + " vertex numEventsReceived: {} vertex output size: {}" + + " total numEventsReceived: {} total output size: {}", + vmEvent.getProducerAttemptIdentifier(), sourceTaskOutputSize, + srcInfo.numVMEventsReceived, srcInfo.outputSize, + 
numVertexManagerEventsReceived, completedSourceTasksOutputSize); } } @@ -379,9 +392,6 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { pendingTasks.add(new PendingTaskInfo(i)); } totalTasksToSchedule = pendingTasks.size(); - if (stats == null) { - stats = new long[totalTasksToSchedule]; // TODO lost previous data - } } /** @@ -427,6 +437,41 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { } } + BigInteger getExpectedTotalBipartiteSourceTasksOutputSize() { + BigInteger expectedTotalSourceTasksOutputSize = BigInteger.ZERO; + for (Map.Entry<String, SourceVertexInfo> vInfo : getBipartiteInfo()) { + SourceVertexInfo srcInfo = vInfo.getValue(); + if (srcInfo.numTasks > 0 && srcInfo.numVMEventsReceived > 0) { + // this assumes that 1 vmEvent is received per completed task - TEZ-2961 + // Estimate total size by projecting based on the current average size per event + BigInteger srcOutputSize = BigInteger.valueOf(srcInfo.outputSize); + BigInteger srcNumTasks = BigInteger.valueOf(srcInfo.numTasks); + BigInteger srcNumVMEventsReceived = BigInteger.valueOf(srcInfo.numVMEventsReceived); + BigInteger expectedSrcOutputSize = srcOutputSize.multiply( + srcNumTasks).divide(srcNumVMEventsReceived); + expectedTotalSourceTasksOutputSize = + expectedTotalSourceTasksOutputSize.add(expectedSrcOutputSize); + } + } + return expectedTotalSourceTasksOutputSize; + } + + int getCurrentlyKnownStatsAtIndex(int index) { + int stats = 0; + for(SourceVertexInfo entry : getAllSourceVertexInfo()) { + stats += entry.statsInMB[index]; + } + return stats; + } + + int getExpectedStatsAtIndex(int index) { + int stats = 0; + for(SourceVertexInfo entry : getAllSourceVertexInfo()) { + stats += entry.getExpectedStatsInMBAtIndex(index); + } + return stats; + } + /** * Subclass might return null to indicate there is no new routing. 
*/ @@ -447,7 +492,7 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { if (computeRoutingAction.equals(computeRoutingAction.COMPUTE)) { ReconfigVertexParams params = computeRouting(); if (params != null) { - reconfigVertex(params.getFinalParallelism(), params.getDescriptor()); + reconfigVertex(params.getFinalParallelism()); updatePendingTasks(); postReconfigVertex(); } @@ -489,6 +534,14 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { } } + Iterable<SourceVertexInfo> getAllSourceVertexInfo() { + return srcVertexInfo.values(); + } + + SourceVertexInfo getSourceVertexInfo(String vertextName) { + return srcVertexInfo.get(vertextName); + } + Iterable<Map.Entry<String, SourceVertexInfo>> getBipartiteInfo() { return Iterables.filter(srcVertexInfo.entrySet(), new Predicate<Map.Entry<String,SourceVertexInfo>>() { @@ -753,20 +806,18 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin { // Not allowing this for now. Nothing to do. } - private void reconfigVertex(final int finalTaskParallelism, - final EdgeManagerPluginDescriptor edgeManagerDescriptor) { + private void reconfigVertex(final int finalTaskParallelism) { Map<String, EdgeProperty> edgeProperties = new HashMap<String, EdgeProperty>(bipartiteSources); Iterable<Map.Entry<String, SourceVertexInfo>> bipartiteItr = getBipartiteInfo(); for(Map.Entry<String, SourceVertexInfo> entry : bipartiteItr) { String vertex = entry.getKey(); EdgeProperty oldEdgeProp = entry.getValue().edgeProperty; - EdgeProperty newEdgeProp = EdgeProperty.create(edgeManagerDescriptor, + EdgeProperty newEdgeProp = EdgeProperty.create(entry.getValue().newDescriptor, oldEdgeProp.getDataSourceType(), oldEdgeProp.getSchedulingType(), oldEdgeProp.getEdgeSource(), oldEdgeProp.getEdgeDestination()); edgeProperties.put(vertex, newEdgeProp); } - getContext().reconfigureVertex(finalTaskParallelism, null, edgeProperties); } } 
http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/proto/FairShufflePayloads.proto ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/main/proto/FairShufflePayloads.proto b/tez-runtime-library/src/main/proto/FairShufflePayloads.proto new file mode 100644 index 0000000..334cbc9 --- /dev/null +++ b/tez-runtime-library/src/main/proto/FairShufflePayloads.proto @@ -0,0 +1,37 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +option java_package = "org.apache.tez.dag.library.vertexmanager"; +option java_outer_classname = "FairShuffleUserPayloads"; +option java_generate_equals_and_hash = true; + +message RangeProto { + optional int32 first_index = 1; + optional int32 num_of_indexes = 2; +} + +message FairShuffleEdgeManagerDestinationTaskPropProto { + optional int32 destination_task_index = 1; + optional RangeProto partitions = 2; + optional RangeProto source_tasks = 3; +} + +message FairShuffleEdgeManagerConfigPayloadProto { + optional int32 num_buckets = 1; + repeated FairShuffleEdgeManagerDestinationTaskPropProto destinationTaskProps = 2; +} http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestFairShuffleVertexManager.java ---------------------------------------------------------------------- diff --git a/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestFairShuffleVertexManager.java b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestFairShuffleVertexManager.java new file mode 100644 index 0000000..9c94c14 --- /dev/null +++ b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestFairShuffleVertexManager.java @@ -0,0 +1,347 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.dag.library.vertexmanager;

import com.google.common.collect.Lists;

import org.apache.hadoop.conf.Configuration;
import org.apache.tez.dag.api.EdgeManagerPlugin;
import org.apache.tez.dag.api.EdgeManagerPluginOnDemand;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.EdgeProperty.SchedulingType;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.VertexLocationHint;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.library.vertexmanager.FairShuffleVertexManager.FairRoutingType;
import org.apache.tez.runtime.api.TaskAttemptIdentifier;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyList;
import static org.mockito.Mockito.anyMap;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Tests for {@link FairShuffleVertexManager}: configuration handling, rejection
 * of an unsupported setup, and the scheduling/routing produced for the
 * {@code REDUCE_PARALLELISM} and {@code FAIR_PARALLELISM} routing types.
 */
@SuppressWarnings({ "unchecked", "rawtypes" })
public class TestFairShuffleVertexManager
    extends TestShuffleVertexManagerUtils {
  // Passed to onVertexStarted(); these tests start the vertex with no
  // completion list at all.
  List<TaskAttemptIdentifier> emptyCompletions = null;

  /**
   * Checks the config values the manager ends up with, first when
   * auto-parallelism is enabled with explicit values and then when every
   * setting is left unset so the defaults apply.
   */
  @Test(timeout = 5000)
  public void testAutoParallelismConfig() throws Exception {
    FairShuffleVertexManager manager;

    final List<Integer> scheduledTasks = Lists.newLinkedList();

    final VertexManagerPluginContext mockContext = createVertexManagerContext(
        "Vertex1", 2, "Vertex2", 2, "Vertex3", 2,
        "Vertex4", 4, scheduledTasks, null);

    manager = createManager(null, mockContext, null, 0.5f);
    verify(mockContext, times(1)).vertexReconfigurationPlanned(); // Tez notified of reconfig
    Assert.assertTrue(manager.config.isAutoParallelismEnabled());
    Assert.assertTrue(manager.config.getDesiredTaskInputDataSize() == 1000L * MB);
    Assert.assertTrue(manager.config.getMinFraction() == 0.25f);
    Assert.assertTrue(manager.config.getMaxFraction() == 0.5f);

    manager = createManager(null, mockContext, null, null, null, null);
    // The invocation count is still 1: the second manager did not plan a
    // reconfiguration, i.e. Tez was not notified of a reconfig.
    verify(mockContext, times(1)).vertexReconfigurationPlanned();

    Assert.assertFalse(manager.config.isAutoParallelismEnabled());
    Assert.assertTrue(manager.config.getDesiredTaskInputDataSize() ==
        FairShuffleVertexManager.TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT);
    Assert.assertTrue(manager.config.getMinFraction() ==
        FairShuffleVertexManager.TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT);
    Assert.assertTrue(manager.config.getMaxFraction() ==
        FairShuffleVertexManager.TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT);
  }

  /**
   * Starting a {@code FAIR_PARALLELISM} manager on a vertex with more than one
   * bipartite source must be rejected with a {@link TezUncheckedException}.
   */
  @Test(timeout = 5000)
  public void testInvalidSetup() {
    Configuration conf = new Configuration();
    ShuffleVertexManagerBase manager;

    final List<Integer> scheduledTasks = Lists.newLinkedList();

    final VertexManagerPluginContext mockContext = createVertexManagerContext(
        "Vertex1", 2, "Vertex2", 2, "Vertex3", 2,
        "Vertex4", 4, scheduledTasks, null);

    // Fail if there is more than one bipartite source for FAIR_PARALLELISM.
    try {
      manager = createFairShuffleVertexManager(conf, mockContext,
          FairRoutingType.FAIR_PARALLELISM, 1000 * MB, 0.001f, 0.001f);
      manager.onVertexStarted(emptyCompletions);
      // AssertionError is not a TezUncheckedException, so it is not swallowed
      // by the catch clause below.
      Assert.fail("Expected TezUncheckedException from onVertexStarted");
    } catch (TezUncheckedException e) {
      Assert.assertTrue(e.getMessage().contains(
          "Having more than one destination task process same partition(s) " +
              "only works with one bipartite source."));
    }
  }

  /**
   * Runs the shared scheduling scenario with {@code REDUCE_PARALLELISM} and
   * verifies the routing of the first destination task.
   */
  @Test(timeout = 5000)
  public void testReduceSchedulingWithPartitionStats() throws Exception {
    final Map<String, EdgeManagerPlugin> newEdgeManagers =
        new HashMap<String, EdgeManagerPlugin>();
    testSchedulingWithPartitionStats(FairRoutingType.REDUCE_PARALLELISM,
        2, 2, newEdgeManagers);
    EdgeManagerPluginOnDemand edgeManager =
        (EdgeManagerPluginOnDemand)newEdgeManagers.values().iterator().next();

    verifyFirstDestinationTaskRouting(edgeManager);
  }

  /**
   * Runs the shared scheduling scenario with {@code FAIR_PARALLELISM} and
   * verifies the routing of all three scheduled destination tasks.
   */
  @Test(timeout = 5000)
  public void testFairSchedulingWithPartitionStats() throws Exception {
    final Map<String, EdgeManagerPlugin> newEdgeManagers =
        new HashMap<String, EdgeManagerPlugin>();
    testSchedulingWithPartitionStats(FairRoutingType.FAIR_PARALLELISM,
        3, 2, newEdgeManagers);

    // Get the first edgeManager which is SCATTER_GATHER.
    EdgeManagerPluginOnDemand edgeManager =
        (EdgeManagerPluginOnDemand)newEdgeManagers.values().iterator().next();

    verifyFirstDestinationTaskRouting(edgeManager);

    EdgeManagerPluginOnDemand.EventRouteMetadata routeMetadata;

    // The 2nd destination task fetches one partition from the first source
    // task.
    Assert.assertEquals(1, edgeManager.getNumDestinationTaskPhysicalInputs(1));
    routeMetadata =
        edgeManager.routeCompositeDataMovementEventToDestination(0, 1);
    Assert.assertEquals(1, routeMetadata.getNumEvents());
    Assert.assertEquals(2, routeMetadata.getSourceIndices()[0]);
    Assert.assertEquals(0, routeMetadata.getTargetIndices()[0]);

    // Source indices are only checked for the data movement route.
    routeMetadata =
        edgeManager.routeInputSourceTaskFailedEventToDestination(0, 1);
    Assert.assertEquals(1, routeMetadata.getNumEvents());
    Assert.assertEquals(0, routeMetadata.getTargetIndices()[0]);

    // The 3rd destination task fetches one partition from the 2nd and 3rd
    // source task.
    Assert.assertEquals(2, edgeManager.getNumDestinationTaskPhysicalInputs(2));
    for (int sourceTaskIndex = 1; sourceTaskIndex < 3; sourceTaskIndex++) {
      routeMetadata = edgeManager.routeCompositeDataMovementEventToDestination(
          sourceTaskIndex, 2);
      Assert.assertEquals(1, routeMetadata.getNumEvents());
      Assert.assertEquals(2, routeMetadata.getSourceIndices()[0]);
      Assert.assertEquals(sourceTaskIndex - 1,
          routeMetadata.getTargetIndices()[0]);

      routeMetadata = edgeManager.routeInputSourceTaskFailedEventToDestination(
          sourceTaskIndex, 2);
      Assert.assertEquals(1, routeMetadata.getNumEvents());
      Assert.assertEquals(sourceTaskIndex - 1,
          routeMetadata.getTargetIndices()[0]);
    }
  }

  /**
   * Verifies the routing of destination task 0, which is the same in both
   * routing scenarios: it fetches two partitions from each of the three
   * source tasks.
   *
   * @param edgeManager the edge manager installed by the vertex reconfiguration
   */
  private static void verifyFirstDestinationTaskRouting(
      EdgeManagerPluginOnDemand edgeManager) throws Exception {
    // The first destination task fetches two partitions from all source tasks.
    // 6 == 3 source tasks * 2 merged partitions
    Assert.assertEquals(6, edgeManager.getNumDestinationTaskPhysicalInputs(0));
    for (int sourceTaskIndex = 0; sourceTaskIndex < 3; sourceTaskIndex++) {
      final int[] expectedTargets =
          new int[]{0 + sourceTaskIndex * 2, 1 + sourceTaskIndex * 2};

      EdgeManagerPluginOnDemand.EventRouteMetadata routeMetadata =
          edgeManager.routeCompositeDataMovementEventToDestination(
              sourceTaskIndex, 0);
      Assert.assertEquals(2, routeMetadata.getNumEvents());
      Assert.assertArrayEquals(new int[]{0, 1},
          routeMetadata.getSourceIndices());
      Assert.assertArrayEquals(expectedTargets,
          routeMetadata.getTargetIndices());

      // Source indices are only checked for the data movement route.
      routeMetadata = edgeManager.routeInputSourceTaskFailedEventToDestination(
          sourceTaskIndex, 0);
      Assert.assertEquals(2, routeMetadata.getNumEvents());
      Assert.assertArrayEquals(expectedTargets,
          routeMetadata.getTargetIndices());
    }
  }

  /**
   * Shared scenario: creates a DAG with one destination vertex connected to
   * 3 source vertices. There are 3 tasks for each vertex. One edge is of type
   * SCATTER_GATHER. The other edges are BROADCAST.
   *
   * @param fairRoutingType routing type the manager is created with
   * @param expectedScheduledTasks number of destination tasks expected to be
   *     scheduled once all sources have reported
   * @param expectedNumDestinationConsumerTasks expected number of destination
   *     consumer tasks for every source task
   * @param newEdgeManagers out-parameter handed to the reconfigureVertex
   *     answer; expected to hold exactly one edge manager afterwards
   */
  private void testSchedulingWithPartitionStats(
      FairRoutingType fairRoutingType, int expectedScheduledTasks,
      int expectedNumDestinationConsumerTasks,
      Map<String, EdgeManagerPlugin> newEdgeManagers)
      throws Exception {
    Configuration conf = new Configuration();
    FairShuffleVertexManager manager;

    HashMap<String, EdgeProperty> mockInputVertices = new HashMap<String, EdgeProperty>();
    String r1 = "R1";
    final int numOfTasksInr1 = 3;
    EdgeProperty eProp1 = EdgeProperty.create(
        EdgeProperty.DataMovementType.SCATTER_GATHER,
        EdgeProperty.DataSourceType.PERSISTED,
        SchedulingType.SEQUENTIAL,
        OutputDescriptor.create("out"),
        InputDescriptor.create("in"));
    String m2 = "M2";
    final int numOfTasksInM2 = 3;
    EdgeProperty eProp2 = EdgeProperty.create(
        EdgeProperty.DataMovementType.BROADCAST,
        EdgeProperty.DataSourceType.PERSISTED,
        SchedulingType.SEQUENTIAL,
        OutputDescriptor.create("out"),
        InputDescriptor.create("in"));
    String m3 = "M3";
    final int numOfTasksInM3 = 3;
    EdgeProperty eProp3 = EdgeProperty.create(
        EdgeProperty.DataMovementType.BROADCAST,
        EdgeProperty.DataSourceType.PERSISTED,
        SchedulingType.SEQUENTIAL,
        OutputDescriptor.create("out"),
        InputDescriptor.create("in"));

    final String mockManagedVertexId = "R2";
    final int numOfTasksInDestination = 3;

    mockInputVertices.put(r1, eProp1);
    mockInputVertices.put(m2, eProp2);
    mockInputVertices.put(m3, eProp3);

    final VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
    when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
    when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
    when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(numOfTasksInDestination);
    when(mockContext.getVertexNumTasks(r1)).thenReturn(numOfTasksInr1);
    when(mockContext.getVertexNumTasks(m2)).thenReturn(numOfTasksInM2);
    when(mockContext.getVertexNumTasks(m3)).thenReturn(numOfTasksInM3);

    final List<Integer> scheduledTasks = Lists.newLinkedList();
    doAnswer(new ScheduledTasksAnswer(scheduledTasks)).when(
        mockContext).scheduleTasks(anyList());

    doAnswer(new reconfigVertexAnswer(mockContext, mockManagedVertexId,
        newEdgeManagers)).when(mockContext).reconfigureVertex(
        anyInt(), any(VertexLocationHint.class), anyMap());

    // check initialization
    manager = createFairShuffleVertexManager(conf, mockContext,
        fairRoutingType, 1000 * MB, 0.001f, 0.001f);
    manager.onVertexStarted(emptyCompletions);
    Assert.assertTrue(manager.bipartiteSources == 1);

    manager.onVertexStateUpdated(new VertexStateUpdate(r1,
        VertexState.CONFIGURED));
    manager.onVertexStateUpdated(new VertexStateUpdate(m2,
        VertexState.CONFIGURED));

    Assert.assertEquals(numOfTasksInDestination,
        manager.pendingTasks.size()); // no tasks scheduled
    Assert.assertEquals(numOfTasksInr1,
        manager.totalNumBipartiteSourceTasks);
    Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);

    // Send an event for r1.
    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(r1, 0));
    Assert.assertEquals(numOfTasksInDestination,
        manager.pendingTasks.size()); // no tasks scheduled
    Assert.assertEquals(numOfTasksInr1,
        manager.totalNumBipartiteSourceTasks);

    long[] sizes = new long[]{(50 * MB), (200 * MB), (500 * MB)};
    VertexManagerEvent vmEvent = getVertexManagerEvent(sizes, 800 * MB,
        r1, true);
    manager.onVertexManagerEventReceived(vmEvent); // send VM event

    // stats from another task
    sizes = new long[]{(60 * MB), (300 * MB), (600 * MB)};
    vmEvent = getVertexManagerEvent(sizes, 1200 * MB, r1, true);
    manager.onVertexManagerEventReceived(vmEvent); // send VM event

    // Send an event for m2.
    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(m2, 0));
    Assert.assertEquals(numOfTasksInDestination,
        manager.pendingTasks.size()); // no tasks scheduled
    Assert.assertEquals(numOfTasksInr1,
        manager.totalNumBipartiteSourceTasks);

    // Send an event for m3.
    manager.onVertexStateUpdated(new VertexStateUpdate(m3, VertexState.CONFIGURED));
    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(m3, 0));
    Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
    Assert.assertEquals(expectedScheduledTasks, scheduledTasks.size());

    Assert.assertEquals(1, newEdgeManagers.size());
    EdgeManagerPluginOnDemand edgeManager =
        (EdgeManagerPluginOnDemand)newEdgeManagers.values().iterator().next();
    // For each source task, there are 3 outputs,
    // the same as original number of partitions.
    for (int sourceTaskIndex = 0; sourceTaskIndex < numOfTasksInr1;
        sourceTaskIndex++) {
      // Query every source task, not just task 0.
      Assert.assertEquals(numOfTasksInDestination,
          edgeManager.getNumSourceTaskPhysicalOutputs(sourceTaskIndex));
    }

    for (int sourceTaskIndex = 0; sourceTaskIndex < numOfTasksInr1;
        sourceTaskIndex++) {
      Assert.assertEquals(expectedNumDestinationConsumerTasks,
          edgeManager.getNumDestinationConsumerTasks(sourceTaskIndex));
    }
  }

  /**
   * Creates a manager with auto-parallelism enabled and a desired task input
   * size of {@code 1000 * MB}.
   *
   * @param min min source fraction, or {@code null} to leave it unset
   * @param max max source fraction, or {@code null} to leave it unset
   */
  private static FairShuffleVertexManager createManager(Configuration conf,
      VertexManagerPluginContext context, Float min, Float max) {
    return createManager(conf, context, true, 1000L * MB, min, max);
  }

  /**
   * Creates a {@link FairShuffleVertexManager} through
   * {@code TestShuffleVertexManagerBase.createManager}, forwarding every
   * argument unchanged.
   */
  private static FairShuffleVertexManager createManager(Configuration conf,
      VertexManagerPluginContext context,
      Boolean enableAutoParallelism, Long desiredTaskInputSize, Float min,
      Float max) {
    return (FairShuffleVertexManager)TestShuffleVertexManagerBase.createManager(
        FairShuffleVertexManager.class, conf, context, enableAutoParallelism,
        desiredTaskInputSize, min, max);
  }
}
