TEZ-2193. Check returned value from EdgeManagerPlugin before using it (zjffdu)
Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/f2e8e01f
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/f2e8e01f
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/f2e8e01f

Branch: refs/heads/TEZ-2003
Commit: f2e8e01f45a05000d593221bba08802d60db6894
Parents: 685f6f0
Author: Jeff Zhang <[email protected]>
Authored: Tue Mar 17 10:11:05 2015 +0800
Committer: Jeff Zhang <[email protected]>
Committed: Tue Mar 17 10:11:05 2015 +0800

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 .../org/apache/tez/dag/app/dag/impl/Edge.java   |  20 +-
 .../app/dag/impl/ScatterGatherEdgeManager.java  |   8 +-
 .../apache/tez/dag/app/dag/impl/TestEdge.java   | 240 +++++++++++++++++++
 4 files changed, 265 insertions(+), 4 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/tez/blob/f2e8e01f/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 371b87e..0afd2b3 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -244,6 +244,7 @@ TEZ-UI CHANGES (TEZ-8):
 Release 0.5.4: Unreleased
 
 ALL CHANGES:
+  TEZ-2193. Check returned value from EdgeManagerPlugin before using it
   TEZ-2133. Secured Impersonation: Failed to delete tez scratch data dir
   TEZ-2058. Flaky test: TestTezJobs::testInvalidQueueSubmission.
   TEZ-2037. Should log TaskAttemptFinishedEvent if taskattempt is recovered to KILLED.
http://git-wip-us.apache.org/repos/asf/tez/blob/f2e8e01f/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/Edge.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/Edge.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/Edge.java index aeb94d7..d6f4477 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/Edge.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/Edge.java @@ -209,9 +209,13 @@ public class Edge { Preconditions.checkState(edgeManager != null, "Edge Manager must be initialized by this time"); try { + int physicalInputCount = edgeManager.getNumDestinationTaskPhysicalInputs(destinationTaskIndex); + Preconditions.checkArgument(physicalInputCount >= 0, + "PhysicalInputCount should not be negative, " + + "physicalInputCount=" + physicalInputCount); return new InputSpec(sourceVertex.getName(), edgeProperty.getEdgeDestination(), - edgeManager.getNumDestinationTaskPhysicalInputs(destinationTaskIndex)); + physicalInputCount); } catch (Exception e) { throw new AMUserCodeException(Source.EdgeManager, "Fail to getDestinationSpec, destinationTaskIndex=" @@ -223,9 +227,13 @@ public class Edge { Preconditions.checkState(edgeManager != null, "Edge Manager must be initialized by this time"); try { + int physicalOutputCount = edgeManager.getNumSourceTaskPhysicalOutputs( + sourceTaskIndex); + Preconditions.checkArgument(physicalOutputCount >= 0, + "PhysicalOutputCount should not be negative," + + "physicalOutputCount=" + physicalOutputCount); return new OutputSpec(destinationVertex.getName(), - edgeProperty.getEdgeSource(), edgeManager.getNumSourceTaskPhysicalOutputs( - sourceTaskIndex)); + edgeProperty.getEdgeSource(), physicalOutputCount); } catch (Exception e) { throw new AMUserCodeException(Source.EdgeManager, "Fail to getSourceSpec, sourceTaskIndex=" @@ -265,8 +273,14 @@ public class Edge { try { srcTaskIndex = 
edgeManager.routeInputErrorEventToSource(event, destTaskIndex, event.getIndex()); + Preconditions.checkArgument(srcTaskIndex >= 0, + "SourceTaskIndex should not be negative," + + "srcTaskIndex=" + srcTaskIndex); numConsumers = edgeManager.getNumDestinationConsumerTasks( srcTaskIndex); + Preconditions.checkArgument(numConsumers > 0, + "ConsumerTaskNum must be positive," + + "numConsumers=" + numConsumers); } catch (Exception e) { throw new AMUserCodeException(Source.EdgeManager, "Fail to sendTezEventToSourceTasks, " http://git-wip-us.apache.org/repos/asf/tez/blob/f2e8e01f/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/ScatterGatherEdgeManager.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/ScatterGatherEdgeManager.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/ScatterGatherEdgeManager.java index a8ee795..e2608cd 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/ScatterGatherEdgeManager.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/ScatterGatherEdgeManager.java @@ -27,6 +27,8 @@ import org.apache.tez.dag.api.EdgeManagerPluginContext; import org.apache.tez.runtime.api.events.DataMovementEvent; import org.apache.tez.runtime.api.events.InputReadErrorEvent; +import com.google.common.base.Preconditions; + public class ScatterGatherEdgeManager extends EdgeManagerPlugin { public ScatterGatherEdgeManager(EdgeManagerPluginContext context) { @@ -35,6 +37,7 @@ public class ScatterGatherEdgeManager extends EdgeManagerPlugin { @Override public void initialize() { + } @Override @@ -44,7 +47,10 @@ public class ScatterGatherEdgeManager extends EdgeManagerPlugin { @Override public int getNumSourceTaskPhysicalOutputs(int sourceTaskIndex) { - return getContext().getDestinationVertexNumTasks(); + int physicalOutputs = getContext().getDestinationVertexNumTasks(); + Preconditions.checkArgument(physicalOutputs >= 0, + "ScatteGather edge 
manager must have destination vertex task parallelism specified"); + return physicalOutputs; } @Override http://git-wip-us.apache.org/repos/asf/tez/blob/f2e8e01f/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestEdge.java ---------------------------------------------------------------------- diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestEdge.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestEdge.java index a607f1b..5718b17 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestEdge.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestEdge.java @@ -28,6 +28,11 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.ByteArrayOutputStream; +import java.io.DataInput; +import java.io.DataOutput; +import java.io.DataOutputStream; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collection; @@ -36,14 +41,19 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import org.apache.hadoop.io.DataInputByteBuffer; +import org.apache.hadoop.io.Writable; import org.apache.hadoop.yarn.event.EventHandler; +import org.apache.tez.dag.api.EdgeManagerPlugin; import org.apache.tez.dag.api.EdgeManagerPluginContext; +import org.apache.tez.dag.api.EdgeManagerPluginDescriptor; import org.apache.tez.dag.api.EdgeProperty; import org.apache.tez.dag.api.EdgeProperty.DataMovementType; import org.apache.tez.dag.api.EdgeProperty.DataSourceType; import org.apache.tez.dag.api.EdgeProperty.SchedulingType; import org.apache.tez.dag.api.InputDescriptor; import org.apache.tez.dag.api.OutputDescriptor; +import org.apache.tez.dag.api.UserPayload; import org.apache.tez.dag.app.dag.Task; import org.apache.tez.dag.app.dag.Vertex; import org.apache.tez.dag.records.TezDAGID; @@ -52,6 +62,7 @@ import org.apache.tez.dag.records.TezTaskID; import 
org.apache.tez.dag.records.TezVertexID; import org.apache.tez.runtime.api.events.CompositeDataMovementEvent; import org.apache.tez.runtime.api.events.DataMovementEvent; +import org.apache.tez.runtime.api.events.InputReadErrorEvent; import org.apache.tez.runtime.api.impl.EventMetaData; import org.apache.tez.runtime.api.impl.EventMetaData.EventProducerConsumerType; import org.apache.tez.runtime.api.impl.TezEvent; @@ -93,6 +104,27 @@ public class TestEdge { .get(0).intValue()); } + @Test(timeout = 5000) + public void testScatterGatherManager() { + EdgeManagerPluginContext mockContext = mock(EdgeManagerPluginContext.class); + when(mockContext.getSourceVertexName()).thenReturn("Source"); + when(mockContext.getDestinationVertexName()).thenReturn("Destination"); + ScatterGatherEdgeManager manager = new ScatterGatherEdgeManager(mockContext); + manager.initialize(); + + when(mockContext.getDestinationVertexNumTasks()).thenReturn(-1); + try { + manager.getNumSourceTaskPhysicalOutputs(0); + Assert.fail(); + } catch (IllegalArgumentException e) { + e.printStackTrace(); + Assert.assertTrue(e.getMessage() + .contains("ScatteGather edge manager must have destination vertex task parallelism specified")); + } + when(mockContext.getDestinationVertexNumTasks()).thenReturn(0); + manager.getNumSourceTaskPhysicalOutputs(0); + } + @SuppressWarnings({ "rawtypes" }) @Test (timeout = 5000) public void testCompositeEventHandling() throws AMUserCodeException { @@ -204,4 +236,212 @@ public class TestEdge { TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, taId); return taskAttemptID; } + + @Test(timeout = 5000) + public void testInvalidPhysicalInputCount() throws Exception { + EventHandler mockEventHandler = mock(EventHandler.class); + Edge edge = new Edge(EdgeProperty.create( + EdgeManagerPluginDescriptor.create(CustomEdgeManagerWithInvalidReturnValue.class.getName()) + .setUserPayload(new 
CustomEdgeManagerWithInvalidReturnValue.EdgeManagerConfig(-1,1,1,1).toUserPayload()), + DataSourceType.PERSISTED, + SchedulingType.SEQUENTIAL, + OutputDescriptor.create(""), + InputDescriptor.create("")), mockEventHandler); + TezVertexID v1Id = createVertexID(1); + TezVertexID v2Id = createVertexID(2); + edge.setSourceVertex(mockVertex("v1", v1Id, new LinkedHashMap<TezTaskID, Task>())); + edge.setDestinationVertex(mockVertex("v2", v2Id, new LinkedHashMap<TezTaskID, Task>())); + edge.initialize(); + try { + edge.getDestinationSpec(0); + Assert.fail(); + } catch (AMUserCodeException e) { + e.printStackTrace(); + assertTrue(e.getCause().getMessage().contains("PhysicalInputCount should not be negative")); + } + } + + @Test(timeout = 5000) + public void testInvalidPhysicalOutputCount() throws Exception { + EventHandler mockEventHandler = mock(EventHandler.class); + Edge edge = new Edge(EdgeProperty.create( + EdgeManagerPluginDescriptor.create(CustomEdgeManagerWithInvalidReturnValue.class.getName()) + .setUserPayload(new CustomEdgeManagerWithInvalidReturnValue.EdgeManagerConfig(1,-1,1,1).toUserPayload()), + DataSourceType.PERSISTED, + SchedulingType.SEQUENTIAL, + OutputDescriptor.create(""), + InputDescriptor.create("")), mockEventHandler); + TezVertexID v1Id = createVertexID(1); + TezVertexID v2Id = createVertexID(2); + edge.setSourceVertex(mockVertex("v1", v1Id, new LinkedHashMap<TezTaskID, Task>())); + edge.setDestinationVertex(mockVertex("v2", v2Id, new LinkedHashMap<TezTaskID, Task>())); + edge.initialize(); + try { + edge.getSourceSpec(0); + Assert.fail(); + } catch (AMUserCodeException e) { + e.printStackTrace(); + assertTrue(e.getCause().getMessage().contains("PhysicalOutputCount should not be negative")); + } + } + + @Test(timeout = 5000) + public void testInvalidConsumerNumber() throws Exception { + EventHandler mockEventHandler = mock(EventHandler.class); + Edge edge = new Edge(EdgeProperty.create( + 
EdgeManagerPluginDescriptor.create(CustomEdgeManagerWithInvalidReturnValue.class.getName()) + .setUserPayload(new CustomEdgeManagerWithInvalidReturnValue.EdgeManagerConfig(1,1,0,1).toUserPayload()), + DataSourceType.PERSISTED, + SchedulingType.SEQUENTIAL, + OutputDescriptor.create(""), + InputDescriptor.create("")), mockEventHandler); + TezVertexID v1Id = createVertexID(1); + TezVertexID v2Id = createVertexID(2); + edge.setSourceVertex(mockVertex("v1", v1Id, new LinkedHashMap<TezTaskID, Task>())); + edge.setDestinationVertex(mockVertex("v2", v2Id, new LinkedHashMap<TezTaskID, Task>())); + edge.initialize(); + try { + TezEvent ireEvent = new TezEvent(InputReadErrorEvent.create("diag", 0, 1), + new EventMetaData(EventProducerConsumerType.INPUT, "v2", "v1", + TezTaskAttemptID.getInstance(TezTaskID.getInstance(v2Id, 1), 1))); + edge.sendTezEventToSourceTasks(ireEvent); + Assert.fail(); + } catch (AMUserCodeException e) { + e.printStackTrace(); + assertTrue(e.getCause().getMessage().contains("ConsumerTaskNum must be positive")); + } + } + + @Test(timeout = 5000) + public void testInvalidSourceTaskIndex() throws Exception { + EventHandler mockEventHandler = mock(EventHandler.class); + Edge edge = new Edge(EdgeProperty.create( + EdgeManagerPluginDescriptor.create(CustomEdgeManagerWithInvalidReturnValue.class.getName()) + .setUserPayload(new CustomEdgeManagerWithInvalidReturnValue.EdgeManagerConfig(1,1,1,-1).toUserPayload()), + DataSourceType.PERSISTED, + SchedulingType.SEQUENTIAL, + OutputDescriptor.create(""), + InputDescriptor.create("")), mockEventHandler); + TezVertexID v1Id = createVertexID(1); + TezVertexID v2Id = createVertexID(2); + edge.setSourceVertex(mockVertex("v1", v1Id, new LinkedHashMap<TezTaskID, Task>())); + edge.setDestinationVertex(mockVertex("v2", v2Id, new LinkedHashMap<TezTaskID, Task>())); + edge.initialize(); + try { + TezEvent ireEvent = new TezEvent(InputReadErrorEvent.create("diag", 0, 1), + new EventMetaData(EventProducerConsumerType.INPUT, 
"v2", "v1", + TezTaskAttemptID.getInstance(TezTaskID.getInstance(v2Id, 1), 1))); + edge.sendTezEventToSourceTasks(ireEvent); + Assert.fail(); + } catch (AMUserCodeException e) { + e.printStackTrace(); + assertTrue(e.getCause().getMessage().contains("SourceTaskIndex should not be negative")); + } + } + + public static class CustomEdgeManagerWithInvalidReturnValue extends EdgeManagerPlugin { + + public static class EdgeManagerConfig implements Writable { + int physicalInput = 1 ; + int physicalOutput = 1; + int consumerNumber = 1; + int sourceTaskIndex = 1; + + public EdgeManagerConfig() { + + } + + public EdgeManagerConfig(int physicalInput, int physicalOutput, + int consumerNumber, int sourceTaskIndex) { + super(); + this.physicalInput = physicalInput; + this.physicalOutput = physicalOutput; + this.consumerNumber = consumerNumber; + this.sourceTaskIndex = sourceTaskIndex; + } + + public UserPayload toUserPayload() throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + DataOutput out = new DataOutputStream(bos); + write(out); + return UserPayload.create(ByteBuffer.wrap(bos.toByteArray())); + } + + public static EdgeManagerConfig fromUserPayload(UserPayload payload) + throws IOException { + EdgeManagerConfig emConf = new EdgeManagerConfig(); + DataInputByteBuffer in = new DataInputByteBuffer(); + in.reset(payload.getPayload()); + emConf.readFields(in); + return emConf; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(physicalInput); + out.writeInt(physicalOutput); + out.writeInt(consumerNumber); + out.writeInt(sourceTaskIndex); + } + + @Override + public void readFields(DataInput in) throws IOException { + physicalInput = in.readInt(); + physicalOutput = in.readInt(); + consumerNumber = in.readInt(); + sourceTaskIndex = in.readInt(); + } + } + + EdgeManagerConfig emConf; + + public CustomEdgeManagerWithInvalidReturnValue( + EdgeManagerPluginContext context) { + super(context); + } + + @Override + 
public void initialize() throws Exception { + emConf = EdgeManagerConfig.fromUserPayload(getContext().getUserPayload()); + } + + @Override + public int getNumDestinationTaskPhysicalInputs(int destinationTaskIndex) + throws Exception { + return emConf.physicalInput; + } + + @Override + public int getNumSourceTaskPhysicalOutputs(int sourceTaskIndex) + throws Exception { + return emConf.physicalOutput; + } + + @Override + public void routeDataMovementEventToDestination(DataMovementEvent event, + int sourceTaskIndex, int sourceOutputIndex, + Map<Integer, List<Integer>> destinationTaskAndInputIndices) + throws Exception { + } + + @Override + public void routeInputSourceTaskFailedEventToDestination( + int sourceTaskIndex, + Map<Integer, List<Integer>> destinationTaskAndInputIndices) + throws Exception { + } + + @Override + public int getNumDestinationConsumerTasks(int sourceTaskIndex) + throws Exception { + return emConf.consumerNumber; + } + + @Override + public int routeInputErrorEventToSource(InputReadErrorEvent event, + int destinationTaskIndex, int destinationFailedInputIndex) + throws Exception { + return emConf.sourceTaskIndex; + } + } }
