YARN-8882. [YARN-8851] Add a shared device mapping manager (scheduler) for 
device plugins. (Zhankun Tang via wangda)

Change-Id: I9435136642c3d556971a357bf687f69df90bb45e

Project: http://git-wip-us.apache.org/repos/asf/hadoop/repo
Commit: http://git-wip-us.apache.org/repos/asf/hadoop/commit/579ef4be
Tree: http://git-wip-us.apache.org/repos/asf/hadoop/tree/579ef4be
Diff: http://git-wip-us.apache.org/repos/asf/hadoop/diff/579ef4be

Branch: refs/heads/trunk
Commit: 579ef4be063745c5211127eca83a393ceddc8b79
Parents: 9de8e8d
Author: Wangda Tan <wan...@apache.org>
Authored: Wed Nov 28 14:09:52 2018 -0800
Committer: Wangda Tan <wan...@apache.org>
Committed: Wed Nov 28 14:09:52 2018 -0800

 .../resourceplugin/ResourcePluginManager.java   |  14 +-
 .../deviceframework/DeviceMappingManager.java   | 324 ++++++++++++++++
 .../deviceframework/DevicePluginAdapter.java    |  20 +-
 .../DeviceResourceHandlerImpl.java              | 145 +++++++
 .../TestDeviceMappingManager.java               | 366 +++++++++++++++++
 .../TestDevicePluginAdapter.java                | 388 ++++++++++++++++++-
 6 files changed, 1245 insertions(+), 12 deletions(-)

diff --git 
index 9741b12..6dfe817 100644
@@ -30,6 +30,7 @@ import org.apache.hadoop.yarn.exceptions.YarnRuntimeException;
 import org.apache.hadoop.yarn.server.nodemanager.Context;
 import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
@@ -52,12 +53,13 @@ import static 
 public class ResourcePluginManager {
   private static final Logger LOG =
-  private static final Set<String> SUPPORTED_RESOURCE_PLUGINS = 
-      GPU_URI, FPGA_URI);
+  private static final Set<String> SUPPORTED_RESOURCE_PLUGINS =
+      ImmutableSet.of(GPU_URI, FPGA_URI);
   private Map<String, ResourcePlugin> configuredPlugins =
+  private DeviceMappingManager deviceMappingManager = null;
   public synchronized void initialize(Context context)
       throws YarnException, ClassNotFoundException {
@@ -123,7 +125,7 @@ public class ResourcePluginManager {
       throws YarnRuntimeException, ClassNotFoundException {
     LOG.info("The pluggable device framework enabled," +
         "trying to load the vendor plugins");
+    deviceMappingManager = new DeviceMappingManager(context);
     String[] pluginClassNames = configuration.getStrings(
     if (null == pluginClassNames) {
@@ -174,7 +176,7 @@ public class ResourcePluginManager {
       DevicePluginAdapter pluginAdapter = new DevicePluginAdapter(
-          resourceName, dpInstance);
+          resourceName, dpInstance, deviceMappingManager);
       LOG.info("Adapter of {} created. Initializing..", pluginClassName);
       try {
@@ -235,6 +237,10 @@ public class ResourcePluginManager {
     return true;
+  /**
+   * @return the {@link DeviceMappingManager} handed to every device plugin
+   *         adapter, or {@code null} if the pluggable device framework has
+   *         not been initialized
+   */
+  public DeviceMappingManager getDeviceMappingManager() {
+    return this.deviceMappingManager;
+  }
   public synchronized void cleanup() throws YarnException {
     for (ResourcePlugin plugin : configuredPlugins.values()) {

diff --git 
new file mode 100644
index 0000000..b8b711b
--- /dev/null
@@ -0,0 +1,324 @@
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.util.StringUtils;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.Collections;
+import java.util.concurrent.ConcurrentHashMap;
+ * Schedules device resources based on container requirements and does the
+ * bookkeeping. It holds all device-type resources and acts as the default
+ * scheduler.
+ * */
+public class DeviceMappingManager {
+  static final Log LOG = LogFactory.getLog(DeviceMappingManager.class);
+  /** Pause between allocation retries while waiting for devices. */
+  private static final int WAIT_MS_PER_LOOP = 1000;
+  /** NM context: used for container lookups and the NM state store. */
+  private Context nmContext;
+  /**
+   * Every registered device, keyed by device resource name; each value is a
+   * sorted set of {@link Device}.
+   */
+  private Map<String, Set<Device>> allAllowedDevices =
+      new ConcurrentHashMap<>();
+  /**
+   * Devices currently handed out, keyed by device resource name; each value
+   * is a sorted map from {@link Device} to the owning {@link ContainerId}.
+   */
+  private Map<String, Map<Device, ContainerId>> allUsedDevices =
+      new ConcurrentHashMap<>();
+  public DeviceMappingManager(Context context) {
+    this.nmContext = context;
+  }
+  /** @return the live internal map of registered devices (not a copy). */
+  @VisibleForTesting
+  public Map<String, Set<Device>> getAllAllowedDevices() {
+    return this.allAllowedDevices;
+  }
+  /** @return the live internal map of used devices (not a copy). */
+  @VisibleForTesting
+  public Map<String, Map<Device, ContainerId>> getAllUsedDevices() {
+    return this.allUsedDevices;
+  }
+  /**
+   * Registers (or replaces) the device set of a resource type and resets the
+   * type's usage bookkeeping to "nothing in use".
+   *
+   * @param resourceName device resource name
+   * @param deviceSet devices reported by the plugin; copied into a sorted set
+   */
+  public synchronized void addDeviceSet(String resourceName,
+      Set<Device> deviceSet) {
+    LOG.info("Adding new resource: " + "type:"
+        + resourceName + "," + deviceSet);
+    Set<Device> sortedDevices = new TreeSet<>(deviceSet);
+    Map<Device, ContainerId> noUsage = new TreeMap<>();
+    allAllowedDevices.put(resourceName, sortedDevices);
+    allUsedDevices.put(resourceName, noUsage);
+  }
+  /**
+   * Assigns devices of the given type to a container. If the request can be
+   * met once devices held by finishing containers are released, retries
+   * every {@code WAIT_MS_PER_LOOP} ms for up to 120 seconds.
+   *
+   * @param resourceName device resource name
+   * @param container container requesting devices
+   * @return the allocation for the container
+   * @throws ResourceHandlerException if no allocation is possible, the wait
+   *         times out, or the waiting thread is interrupted
+   */
+  public DeviceAllocation assignDevices(String resourceName,
+      Container container)
+      throws ResourceHandlerException {
+    DeviceAllocation allocation = internalAssignDevices(resourceName,
+        container);
+    // Wait for a maximum of 120 seconds if no available Devices are there
+    // which are yet to be released.
+    final int timeoutMsecs = 120 * WAIT_MS_PER_LOOP;
+    int timeWaiting = 0;
+    while (allocation == null && timeWaiting < timeoutMsecs) {
+      LOG.info("Container : " + container.getContainerId()
+          + " is waiting for free " + resourceName + " devices.");
+      try {
+        // This method is not synchronized, so other threads can release
+        // devices while we sleep.
+        Thread.sleep(WAIT_MS_PER_LOOP);
+      } catch (InterruptedException e) {
+        // Restore the interrupt status for our caller, then stop waiting.
+        Thread.currentThread().interrupt();
+        break;
+      }
+      timeWaiting += WAIT_MS_PER_LOOP;
+      allocation = internalAssignDevices(resourceName, container);
+    }
+    if (allocation == null) {
+      String message = "Could not get valid " + resourceName
+          + " device for container '" + container.getContainerId()
+          + "' as some other containers might not releasing them.";
+      LOG.warn(message);
+      throw new ResourceHandlerException(message);
+    }
+    return allocation;
+  }
+  /**
+   * Makes a single attempt to assign the requested number of devices to a
+   * container and persists the assignment in the NM state store.
+   *
+   * @return an allocation with no allowed devices if none are requested;
+   *         {@code null} if the request can be met once devices held by
+   *         containers in a final state are released; otherwise the new
+   *         allocation
+   * @throws ResourceHandlerException if the resource type is not
+   *         registered, not enough devices exist, or storing fails
+   */
+  private synchronized DeviceAllocation internalAssignDevices(
+      String resourceName, Container container)
+      throws ResourceHandlerException {
+    Resource requestedResource = container.getResource();
+    ContainerId containerId = container.getContainerId();
+    int requestedDeviceCount = getRequestedDeviceCount(resourceName,
+        requestedResource);
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Try allocating " + requestedDeviceCount
+          + " " + resourceName);
+    }
+    Set<Device> allowedDevices = allAllowedDevices.get(resourceName);
+    if (requestedDeviceCount <= 0) {
+      // Nothing requested: all registered devices of this type are denied.
+      return new DeviceAllocation(resourceName, null, allowedDevices);
+    }
+    Map<Device, ContainerId> usedDevices = allUsedDevices.get(resourceName);
+    if (allowedDevices == null || usedDevices == null) {
+      // The type was never registered via addDeviceSet (for example because
+      // bootstrap got no device set); fail clearly instead of with an NPE.
+      throw new ResourceHandlerException("No " + resourceName
+          + " devices are registered, cannot allocate to container="
+          + containerId);
+    }
+    // Usage maps only change under this lock, so one read is consistent.
+    int availableDeviceCount = getAvailableDevices(resourceName);
+    if (requestedDeviceCount > availableDeviceCount
+        && requestedDeviceCount <= getReleasingDevices(resourceName)
+            + availableDeviceCount) {
+      // Enough devices are held by finishing containers; let the caller
+      // retry after they are released.
+      return null;
+    }
+    if (requestedDeviceCount > availableDeviceCount) {
+      throw new ResourceHandlerException("Failed to find enough "
+          + resourceName
+          + ", requestor=" + containerId
+          + ", #Requested=" + requestedDeviceCount + ", #available="
+          + availableDeviceCount);
+    }
+    Set<Device> assignedDevices = new TreeSet<>();
+    defaultScheduleAction(allowedDevices, usedDevices,
+        assignedDevices, containerId, requestedDeviceCount);
+    // Record in state store if we allocated anything
+    if (!assignedDevices.isEmpty()) {
+      try {
+        nmContext.getNMStateStore().storeAssignedResources(container,
+            resourceName,
+            new ArrayList<>(assignedDevices));
+      } catch (IOException e) {
+        // Roll back the in-memory assignment so the devices are not leaked.
+        cleanupAssignedDevices(resourceName, containerId);
+        throw new ResourceHandlerException(e);
+      }
+    }
+    return new DeviceAllocation(resourceName, assignedDevices,
+        Sets.difference(allowedDevices, assignedDevices));
+  }
+  /**
+   * Marks the devices recorded in a container's resource mappings as used by
+   * that container again (called when a container is reacquired).
+   *
+   * @param resourceName device resource name
+   * @param containerId container whose devices are recovered
+   * @throws ResourceHandlerException if the container is unknown, a recorded
+   *         entry is not a {@link Device}, the resource type is not
+   *         registered, or a device is not allowed or is already in use
+   */
+  public synchronized void recoverAssignedDevices(String resourceName,
+      ContainerId containerId)
+      throws ResourceHandlerException {
+    Container c = nmContext.getContainers().get(containerId);
+    if (null == c) {
+      throw new ResourceHandlerException(
+          "This shouldn't happen, cannot find container with id="
+              + containerId);
+    }
+    Map<Device, ContainerId> usedDevices = allUsedDevices.get(resourceName);
+    Set<Device> allowedDevices = allAllowedDevices.get(resourceName);
+    for (Serializable deviceSerializable : c.getResourceMappings()
+        .getAssignedResources(resourceName)) {
+      if (!(deviceSerializable instanceof Device)) {
+        throw new ResourceHandlerException(
+            "Trying to recover device id, however it"
+                + " is not Device instance, this shouldn't happen");
+      }
+      Device device = (Device) deviceSerializable;
+      // The type may never have been registered; report that clearly
+      // instead of failing with an NPE below.
+      if (allowedDevices == null || usedDevices == null) {
+        throw new ResourceHandlerException(
+            "Try to recover device = " + device
+                + " however no " + resourceName + " devices are registered");
+      }
+      // Make sure it is in allowed device.
+      if (!allowedDevices.contains(device)) {
+        throw new ResourceHandlerException(
+            "Try to recover device = " + device
+                + " however it is not in allowed device list:" + StringUtils
+                .join(",", allowedDevices));
+      }
+      // Make sure it is not occupied by anybody else
+      if (usedDevices.containsKey(device)) {
+        throw new ResourceHandlerException(
+            "Try to recover device id = " + device
+                + " however it is already assigned to container="
+                + usedDevices.get(device)
+                + ", please double check what happened.");
+      }
+      usedDevices.put(device, containerId);
+    }
+  }
+  /**
+   * Releases every device of the given type held by a container. Does
+   * nothing if the resource type was never registered.
+   *
+   * @param resourceName device resource name
+   * @param containerId container whose devices are released
+   */
+  public synchronized void cleanupAssignedDevices(String resourceName,
+      ContainerId containerId) {
+    Map<Device, ContainerId> usedDevices = allUsedDevices.get(resourceName);
+    if (usedDevices == null) {
+      // Unregistered type: nothing can have been assigned, and this runs on
+      // every container completion, so it must not throw.
+      return;
+    }
+    usedDevices.values().removeIf(owner -> owner.equals(containerId));
+  }
+  /**
+   * Reads how many devices of one type a resource asks for.
+   *
+   * @param resourceName device resource name
+   * @param requestedResource the container's requested resource
+   * @return the requested count, or 0 if the resource has no such type
+   */
+  public static int getRequestedDeviceCount(String resourceName,
+      Resource requestedResource) {
+    long requested;
+    try {
+      requested = requestedResource.getResourceValue(resourceName);
+    } catch (ResourceNotFoundException e) {
+      return 0;
+    }
+    return (int) requested;
+  }
+  /**
+   * @param resourceName device resource name (must be registered)
+   * @return registered devices of this type not assigned to any container
+   */
+  public int getAvailableDevices(String resourceName) {
+    int total = allAllowedDevices.get(resourceName).size();
+    int inUse = allUsedDevices.get(resourceName).size();
+    return total - inUse;
+  }
+  /**
+   * Counts devices of this type that should soon be freed. Sums the requested
+   * device count of each container that still holds devices but is already
+   * in a final state.
+   *
+   * @param resourceName device resource name (must be registered)
+   * @return number of devices expected to be released
+   */
+  private long getReleasingDevices(String resourceName) {
+    long releasingDevices = 0;
+    Map<Device, ContainerId> used = allUsedDevices.get(resourceName);
+    // A container holding N devices appears N times in the map. Count each
+    // owner once, otherwise its request would be multiplied by N.
+    Set<ContainerId> owners = Sets.newHashSet(used.values());
+    for (ContainerId containerId : owners) {
+      Container container = nmContext.getContainers().get(containerId);
+      if (container != null && container.isContainerInFinalStates()) {
+        releasingDevices += container.getResource()
+            .getResourceInformation(resourceName).getValue();
+      }
+    }
+    return releasingDevices;
+  }
+  /**
+   * Default scheduling policy. Walks the allowed devices in set order and
+   * gives every unused one to the container, marking it as used, until
+   * {@code assigned} holds {@code count} devices.
+   */
+  private void defaultScheduleAction(Set<Device> allowed,
+      Map<Device, ContainerId> used, Set<Device> assigned,
+      ContainerId containerId, int count) {
+    LOG.debug("Using default scheduler. Allowed:" + allowed
+        + ",Used:" + used + ", containerId:" + containerId);
+    Iterator<Device> candidates = allowed.iterator();
+    while (candidates.hasNext()) {
+      Device candidate = candidates.next();
+      if (used.containsKey(candidate)) {
+        continue;
+      }
+      used.put(candidate, containerId);
+      assigned.add(candidate);
+      if (assigned.size() == count) {
+        break;
+      }
+    }
+  }
+  /**
+   * Immutable result of an allocation: the devices a container may use and
+   * the devices it may not. A null input set becomes an empty set.
+   */
+  static class DeviceAllocation {
+    private final String resourceName;
+    private final Set<Device> allowed;
+    private final Set<Device> denied;
+    DeviceAllocation(String resName, Set<Device> a,
+        Set<Device> d) {
+      this.resourceName = resName;
+      this.allowed = (a == null)
+          ? Collections.<Device>emptySet() : ImmutableSet.copyOf(a);
+      this.denied = (d == null)
+          ? Collections.<Device>emptySet() : ImmutableSet.copyOf(d);
+    }
+    /** @return devices the container is allowed to use. */
+    public Set<Device> getAllowed() {
+      return this.allowed;
+    }
+    @Override
+    public String toString() {
+      return "ResourceType: " + resourceName
+          + ", Allowed Devices: " + allowed
+          + ", Denied Devices: " + denied;
+    }
+  }

diff --git 
index 18a6992..1636cb8 100644
@@ -33,7 +33,7 @@ import 
- * The {@link DevicePluginAdapter} will adapt existing hooks
+ * The {@link DevicePluginAdapter} will adapt existing hooks
  * into vendor plugin's logic.
  * It decouples the vendor plugin from YARN's device framework
@@ -43,13 +43,21 @@ public class DevicePluginAdapter implements ResourcePlugin {
   private final String resourceName;
   private final DevicePlugin devicePlugin;
+  private DeviceMappingManager deviceMappingManager;
   private DeviceResourceUpdaterImpl deviceResourceUpdater;
+  private DeviceResourceHandlerImpl deviceResourceHandler;
-  public DevicePluginAdapter(String name, DevicePlugin dp) {
+  public DevicePluginAdapter(String name, DevicePlugin dp,
+      DeviceMappingManager dmm) {
+    deviceMappingManager = dmm;
     resourceName = name;
     devicePlugin = dp;
+  /** @return the device mapping manager passed in at construction. */
+  public DeviceMappingManager getDeviceMappingManager() {
+    return this.deviceMappingManager;
+  }
   public void initialize(Context context) throws YarnException {
     deviceResourceUpdater = new DeviceResourceUpdaterImpl(
@@ -62,7 +70,10 @@ public class DevicePluginAdapter implements ResourcePlugin {
   public ResourceHandler createResourceHandler(Context nmContext,
       CGroupsHandler cGroupsHandler,
       PrivilegedOperationExecutor privilegedOperationExecutor) {
-    return null;
+    this.deviceResourceHandler = new DeviceResourceHandlerImpl(resourceName,
+        devicePlugin, this, deviceMappingManager,
+        cGroupsHandler, privilegedOperationExecutor);
+    return deviceResourceHandler;
@@ -85,4 +96,7 @@ public class DevicePluginAdapter implements ResourcePlugin {
     return null;
+  /**
+   * @return the handler built by the last createResourceHandler call, or
+   *         {@code null} if none has been created yet
+   */
+  public DeviceResourceHandlerImpl getDeviceResourceHandler() {
+    return this.deviceResourceHandler;
+  }

diff --git 
new file mode 100644
index 0000000..d33b8da
--- /dev/null
@@ -0,0 +1,145 @@
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
+import java.util.List;
+import java.util.Set;
+ * Hooks into the container lifecycle:
+ * gets the device list from the device plugin in {@code bootstrap},
+ * assigns devices to a container in {@code preStart},
+ * restores state in {@code reacquireContainer}, and
+ * recycles a container's devices in {@code postComplete}.
+ * */
+public class DeviceResourceHandlerImpl implements ResourceHandler {
+  static final Log LOG = LogFactory.getLog(DeviceResourceHandlerImpl.class);
+  // Device resource type handled by this instance.
+  private final String resourceName;
+  // Vendor plugin queried for devices and notified of allocations.
+  private final DevicePlugin devicePlugin;
+  // Adapter that created this handler (may be null, as in unit tests).
+  private final DevicePluginAdapter devicePluginAdapter;
+  // Device bookkeeping shared by all device types.
+  private final DeviceMappingManager deviceMappingManager;
+  // Stored for future cgroups support; not used yet.
+  private final CGroupsHandler cGroupsHandler;
+  private final PrivilegedOperationExecutor privilegedOperationExecutor;
+  public DeviceResourceHandlerImpl(String reseName,
+      DevicePlugin devPlugin,
+      DevicePluginAdapter devPluginAdapter,
+      DeviceMappingManager devMappingManager,
+      CGroupsHandler cgHandler,
+      PrivilegedOperationExecutor operation) {
+    this.resourceName = reseName;
+    this.devicePlugin = devPlugin;
+    this.devicePluginAdapter = devPluginAdapter;
+    this.deviceMappingManager = devMappingManager;
+    this.cGroupsHandler = cgHandler;
+    this.privilegedOperationExecutor = operation;
+  }
+  /**
+   * Asks the plugin for its devices and registers them with the shared
+   * {@link DeviceMappingManager}.
+   *
+   * @param configuration NM configuration (unused)
+   * @return always {@code null}; no privileged operations are needed yet
+   * @throws ResourceHandlerException if the plugin's getDevices throws
+   */
+  @Override
+  public List<PrivilegedOperation> bootstrap(Configuration configuration)
+      throws ResourceHandlerException {
+    Set<Device> availableDevices;
+    try {
+      availableDevices = devicePlugin.getDevices();
+    } catch (Exception e) {
+      // Keep the plugin's exception as the cause so its stack trace survives.
+      throw new ResourceHandlerException("Exception thrown from"
+          + " plugin's \"getDevices\": " + e.getMessage(), e);
+    }
+    // We won't fail the NM if plugin returns invalid value here; the
+    // resource type is simply left unregistered.
+    if (availableDevices == null) {
+      LOG.error("Bootstrap " + resourceName
+          + " failed. Null value got from plugin's getDevices method");
+      return null;
+    }
+    // Add device set. Here we trust the plugin's return value
+    deviceMappingManager.addDeviceSet(resourceName, availableDevices);
+    // TODO: Init cgroups
+    return null;
+  }
+  /**
+   * Assigns devices to the container and then notifies the plugin of the
+   * allocation.
+   *
+   * <p>The method itself is not synchronized. assignDevices may sleep for up
+   * to 120 seconds waiting for devices, and those devices are only freed by
+   * postComplete, which locks this handler. Holding the handler lock while
+   * waiting would stop any device from being released. The
+   * DeviceMappingManager synchronizes its own state. Only the plugin
+   * callback runs under this handler's lock.
+   *
+   * @return always {@code null}; no privileged operations are issued yet
+   * @throws ResourceHandlerException if devices cannot be assigned or the
+   *         plugin's onDevicesAllocated throws
+   */
+  @Override
+  public List<PrivilegedOperation> preStart(Container container)
+      throws ResourceHandlerException {
+    String containerIdStr = container.getContainerId().toString();
+    DeviceMappingManager.DeviceAllocation allocation =
+        deviceMappingManager.assignDevices(resourceName, container);
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Allocated to "
+          + containerIdStr + ": " + allocation);
+    }
+    synchronized (this) {
+      try {
+        devicePlugin.onDevicesAllocated(
+            allocation.getAllowed(), YarnRuntimeType.RUNTIME_DEFAULT);
+      } catch (Exception e) {
+        throw new ResourceHandlerException("Exception thrown from"
+            + " plugin's \"onDevicesAllocated\": " + e.getMessage(), e);
+      }
+    }
+    // cgroups operation based on allocation
+    // TODO: implement a general container-executor device module
+    return null;
+  }
+  /**
+   * Re-registers the device assignments recorded for a reacquired container.
+   *
+   * @return always {@code null}
+   */
+  @Override
+  public synchronized List<PrivilegedOperation> reacquireContainer(
+      ContainerId containerId) throws ResourceHandlerException {
+    final DeviceMappingManager manager = this.deviceMappingManager;
+    manager.recoverAssignedDevices(this.resourceName, containerId);
+    return null;
+  }
+  /**
+   * Currently a no-op for devices.
+   *
+   * @return always {@code null}
+   */
+  @Override
+  public List<PrivilegedOperation> updateContainer(Container container)
+      throws ResourceHandlerException {
+    final List<PrivilegedOperation> noOperations = null;
+    return noOperations;
+  }
+  /**
+   * Returns every device of this type held by the finished container to the
+   * shared pool.
+   *
+   * @return always {@code null}
+   */
+  @Override
+  public synchronized List<PrivilegedOperation> postComplete(
+      ContainerId containerId) throws ResourceHandlerException {
+    final DeviceMappingManager manager = this.deviceMappingManager;
+    manager.cleanupAssignedDevices(this.resourceName, containerId);
+    return null;
+  }
+  /**
+   * Currently a no-op.
+   *
+   * @return always {@code null}
+   */
+  @Override
+  public List<PrivilegedOperation> teardown()
+      throws ResourceHandlerException {
+    final List<PrivilegedOperation> noOperations = null;
+    return noOperations;
+  }

diff --git 
new file mode 100644
index 0000000..d69ab42
--- /dev/null
@@ -0,0 +1,366 @@
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.server.nodemanager.NodeManager;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
+import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
+import org.apache.hadoop.yarn.util.resource.ResourceUtils;
+import org.apache.hadoop.yarn.util.resource.TestResourceUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import static org.mockito.Mockito.*;
+ * Tests for DeviceMappingManager.
+ * Note that it is exercised from multiple threads concurrently.
+ * */
+public class TestDeviceMappingManager {
+  protected static final Logger LOG =
+      LoggerFactory.getLogger(TestDeviceMappingManager.class);
+  private String tempResourceTypesFile;
+  private DeviceMappingManager dmm;
+  private ExecutorService containerLauncher;
+  private Configuration conf;
+  /**
+   * Registers two device types with 600 devices each in a manager backed
+   * by a mocked NM context whose state store accepts every assignment.
+   */
+  @Before
+  public void setup() throws Exception {
+    // Point resource types at the pluggable-device test definitions.
+    conf = new YarnConfiguration();
+    ResourceUtils.resetResourceTypes();
+    tempResourceTypesFile = TestResourceUtils.setupResourceTypes(conf,
+        "resource-types-pluggable-devices.xml");
+    NodeManager.NMContext context = mock(NodeManager.NMContext.class);
+    NMStateStoreService storeService = mock(NMStateStoreService.class);
+    when(context.getNMStateStore()).thenReturn(storeService);
+    doNothing().when(storeService).storeAssignedResources(isA(Container.class),
+        isA(String.class),
+        isA(ArrayList.class));
+    dmm = new DeviceMappingManager(context);
+    final int deviceCount = 600;
+    dmm.addDeviceSet("cmpA.com/hdwA",
+        buildDevices(deviceCount, "/dev/hdwA", 243, "0000:65:00."));
+    dmm.addDeviceSet("cmp.com/cmp",
+        buildDevices(deviceCount, "/dev/cmp", 100, "0000:11:00."));
+    containerLauncher = Executors.newFixedThreadPool(10);
+  }
+  /**
+   * Builds {@code count} healthy devices whose id and minor number run from
+   * 0 to count - 1.
+   */
+  private static TreeSet<Device> buildDevices(int count, String devPathPrefix,
+      int majorNumber, String busIdPrefix) {
+    TreeSet<Device> devices = new TreeSet<>();
+    for (int idx = 0; idx < count; idx++) {
+      devices.add(Device.Builder.newInstance()
+          .setId(idx)
+          .setDevPath(devPathPrefix + idx)
+          .setMajorNumber(majorNumber)
+          .setMinorNumber(idx)
+          .setBusID(busIdPrefix + idx)
+          .setHealthy(true)
+          .build());
+    }
+    return devices;
+  }
+  /**
+   * Stops leftover launcher threads and deletes the temporary
+   * resource-types file.
+   */
+  @After
+  public void tearDown() throws IOException {
+    // A test that fails before calling shutdown() would otherwise leak the
+    // pool threads; shutdownNow() is a no-op on a terminated pool.
+    if (containerLauncher != null) {
+      containerLauncher.shutdownNow();
+    }
+    // cleanup resource-types.xml
+    File dest = new File(this.tempResourceTypesFile);
+    if (dest.exists() && !dest.delete()) {
+      LOG.warn("Failed to delete temporary resource types file {}", dest);
+    }
+  }
+  /**
+   * Launches containers concurrently, each asking for 1-5 devices of a
+   * randomly chosen type. Checks that the total number of used devices
+   * matches the requests and that each container got exactly its share.
+   */
+  @Test
+  public void testAllocation()
+      throws InterruptedException, ResourceHandlerException {
+    final int totalContainerCount = 100;
+    final String resourceName1 = "cmpA.com/hdwA";
+    final String resourceName2 = "cmp.com/cmp";
+    DeviceMappingManager dmmSpy = spy(dmm);
+    // resource name -> (container -> number of devices it requested)
+    Map<String, Map<Container, Integer>> containerSet = new HashMap<>();
+    containerSet.put(resourceName1, new HashMap<>());
+    containerSet.put(resourceName2, new HashMap<>());
+    long startTime = System.currentTimeMillis();
+    for (int i = 0; i < totalContainerCount; i++) {
+      int num = new Random().nextInt(5) + 1;
+      String resourceName =
+          new Random().nextInt(5) % 2 == 0 ? resourceName1 : resourceName2;
+      Container c = mockContainerWithDeviceRequest(i, resourceName, num,
+          false);
+      containerSet.get(resourceName).put(c, num);
+      DeviceResourceHandlerImpl dri = new DeviceResourceHandlerImpl(
+          resourceName, new MyTestPlugin(), null, dmmSpy, null, null);
+      containerLauncher.submit(new MyContainerLaunch(dri, c, i, false));
+    }
+    containerLauncher.shutdown();
+    while (!containerLauncher.awaitTermination(10, TimeUnit.SECONDS)) {
+      LOG.info("Wait for the threads to finish");
+    }
+    long endTime = System.currentTimeMillis();
+    LOG.info("Each container allocation spends roughly: {} ms",
+        (endTime - startTime) / totalContainerCount);
+    verify(dmmSpy, times(totalContainerCount)).assignDevices(
+        anyString(), any(Container.class));
+    String[] resourceNames = {resourceName1, resourceName2};
+    Map<String, Map<Device, ContainerId>> allUsed = dmm.getAllUsedDevices();
+    // The total number of used devices must equal the total requested.
+    int totalRequested = 0;
+    int totalUsed = 0;
+    for (String name : resourceNames) {
+      for (int wanted : containerSet.get(name).values()) {
+        totalRequested += wanted;
+      }
+      totalUsed += allUsed.get(name).size();
+    }
+    Assert.assertEquals(totalRequested, totalUsed);
+    // Each container must own exactly the number of devices it asked for.
+    for (String name : resourceNames) {
+      Map<Device, ContainerId> used = allUsed.get(name);
+      for (Map.Entry<Container, Integer> request :
+          containerSet.get(name).entrySet()) {
+        ContainerId expectedOwner = request.getKey().getContainerId();
+        int actualAllocated = 0;
+        for (ContainerId owner : used.values()) {
+          if (owner.equals(expectedOwner)) {
+            actualAllocated++;
+          }
+        }
+        Assert.assertEquals(request.getValue().intValue(), actualAllocated);
+      }
+    }
+  }
+  /**
+   * Launches containers concurrently. Each releases its devices after a
+   * random pause of up to 4 seconds. Checks that every device ends up free
+   * again.
+   */
+  @Test
+  public void testAllocationAndCleanup()
+      throws InterruptedException, ResourceHandlerException, IOException {
+    final int totalContainerCount = 10;
+    final String resourceName1 = "cmpA.com/hdwA";
+    final String resourceName2 = "cmp.com/cmp";
+    DeviceMappingManager dmmSpy = spy(dmm);
+    // resource name -> (container -> number of devices it requested)
+    Map<String, Map<Container, Integer>> containerSet = new HashMap<>();
+    containerSet.put(resourceName1, new HashMap<>());
+    containerSet.put(resourceName2, new HashMap<>());
+    for (int i = 0; i < totalContainerCount; i++) {
+      int num = new Random().nextInt(5) + 1;
+      String resourceName =
+          new Random().nextInt(5) % 2 == 0 ? resourceName1 : resourceName2;
+      Container c = mockContainerWithDeviceRequest(i, resourceName, num,
+          false);
+      containerSet.get(resourceName).put(c, num);
+      DeviceResourceHandlerImpl dri = new DeviceResourceHandlerImpl(
+          resourceName, new MyTestPlugin(), null, dmmSpy, null, null);
+      containerLauncher.submit(new MyContainerLaunch(dri, c, i, true));
+    }
+    containerLauncher.shutdown();
+    while (!containerLauncher.awaitTermination(10, TimeUnit.SECONDS)) {
+      LOG.info("Wait for the threads to finish");
+    }
+    verify(dmmSpy, times(totalContainerCount)).assignDevices(
+        anyString(), any(Container.class));
+    verify(dmmSpy, times(totalContainerCount)).cleanupAssignedDevices(
+        anyString(), any(ContainerId.class));
+    // Every device of both types must be free again.
+    for (String name : new String[] {resourceName1, resourceName2}) {
+      Assert.assertEquals(0,
+          dmm.getAllUsedDevices().get(name).size());
+    }
+  }
+  /**
+   * Builds a mock container that requests 1024 MB, 1 vcore and
+   * {@code numDeviceRequest} devices of the given type. When
+   * {@code dockerContainerEnabled} is set, the launch environment marks it
+   * as a Docker container.
+   */
+  private static Container mockContainerWithDeviceRequest(int id,
+      String resourceName,
+      int numDeviceRequest,
+      boolean dockerContainerEnabled) {
+    Resource res = Resource.newInstance(1024, 1);
+    res.setResourceValue(resourceName, numDeviceRequest);
+    Map<String, String> env = new HashMap<>();
+    if (dockerContainerEnabled) {
+      env.put(ContainerRuntimeConstants.ENV_CONTAINER_TYPE,
+          ContainerRuntimeConstants.CONTAINER_RUNTIME_DOCKER);
+    }
+    ContainerLaunchContext clc = mock(ContainerLaunchContext.class);
+    when(clc.getEnvironment()).thenReturn(env);
+    Container c = mock(Container.class);
+    when(c.getContainerId()).thenReturn(getContainerId(id));
+    when(c.getResource()).thenReturn(res);
+    when(c.getResourceMappings()).thenReturn(new ResourceMappings());
+    when(c.getLaunchContext()).thenReturn(clc);
+    return c;
+  }
+  /** @return container {@code id} of attempt 1 of application 1234_1. */
+  private static ContainerId getContainerId(int id) {
+    ApplicationId appId = ApplicationId.newInstance(1234L, 1);
+    ApplicationAttemptId attemptId =
+        ApplicationAttemptId.newInstance(appId, 1);
+    return ContainerId.newContainerId(attemptId, id);
+  }
+  /**
+   * Simulates a container launch. Runs preStart and, when cleanup is
+   * requested, runs postComplete after a random pause of 0-4 seconds.
+   */
+  private static class MyContainerLaunch implements Callable<Integer> {
+    private final DeviceResourceHandlerImpl deviceResourceHandler;
+    private final Container container;
+    private final boolean doCleanup;
+    private final int cId;
+    MyContainerLaunch(DeviceResourceHandlerImpl dri,
+        Container c, int id, boolean cleanup) {
+      deviceResourceHandler = dri;
+      container = c;
+      doCleanup = cleanup;
+      cId = id;
+    }
+    @Override
+    public Integer call() throws Exception {
+      try {
+        deviceResourceHandler.preStart(container);
+        if (doCleanup) {
+          int seconds = new Random().nextInt(5);
+          LOG.info("sleep " + seconds);
+          Thread.sleep(seconds * 1000);
+          deviceResourceHandler.postComplete(getContainerId(cId));
+        }
+      } catch (ResourceHandlerException e) {
+        // Failures surface as count mismatches in the test assertions. Log
+        // the cause through the test logger rather than stderr.
+        LOG.error("Device handling failed for container {}", cId, e);
+      }
+      return 0;
+    }
+  }
+  /**
+   * Minimal plugin stub. Registers resource "abc", reports no devices,
+   * returns no runtime spec and ignores releases.
+   */
+  private static class MyTestPlugin implements DevicePlugin {
+    private static final String RESOURCE_NAME = "abc";
+    @Override
+    public DeviceRegisterRequest getRegisterRequestInfo() {
+      DeviceRegisterRequest.Builder builder =
+          DeviceRegisterRequest.Builder.newInstance();
+      return builder.setResourceName(RESOURCE_NAME).build();
+    }
+    @Override
+    public Set<Device> getDevices() {
+      return new TreeSet<>();
+    }
+    @Override
+    public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices,
+        YarnRuntimeType yarnRuntime) throws Exception {
+      return null;
+    }
+    @Override
+    public void onDevicesReleased(Set<Device> releasedDevices) {
+      // Nothing to release in the stub.
+    }
+  } // MyTestPlugin

diff --git 
index c938b83..2534a0a 100644
@@ -18,14 +18,35 @@
-import org.apache.hadoop.yarn.api.records.*;
+import org.apache.hadoop.service.ServiceOperations;
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
+import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.exceptions.YarnException;
-import org.apache.hadoop.yarn.server.nodemanager.*;
-import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.*;
+import org.apache.hadoop.yarn.server.nodemanager.Context;
+import org.apache.hadoop.yarn.server.nodemanager.NodeManager;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
+import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
+import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
 import org.apache.hadoop.yarn.util.resource.ResourceUtils;
 import org.apache.hadoop.yarn.util.resource.TestResourceUtils;
 import org.junit.After;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 import org.slf4j.Logger;
@@ -33,13 +54,23 @@ import org.slf4j.LoggerFactory;
 import java.io.File;
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Set;
 import java.util.TreeSet;
+import java.util.concurrent.ConcurrentHashMap;
+import static org.mockito.Matchers.isA;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
  * Unit tests for DevicePluginAdapter.
@@ -52,6 +83,9 @@ public class TestDevicePluginAdapter {
   private YarnConfiguration conf;
   private String tempResourceTypesFile;
+  private CGroupsHandler mockCGroupsHandler;
+  private PrivilegedOperationExecutor mockPrivilegedExecutor;
+  private NodeManager nm;
   public void setup() throws Exception {
@@ -61,6 +95,8 @@ public class TestDevicePluginAdapter {
     String resourceTypesFile = "resource-types-pluggable-devices.xml";
     this.tempResourceTypesFile =
         TestResourceUtils.setupResourceTypes(this.conf, resourceTypesFile);
+    mockCGroupsHandler = mock(CGroupsHandler.class);
+    mockPrivilegedExecutor = mock(PrivilegedOperationExecutor.class);
@@ -70,19 +106,132 @@ public class TestDevicePluginAdapter {
     if (dest.exists()) {
+    if (nm != null) {
+      try {
+        ServiceOperations.stop(nm);
+      } catch (Throwable t) {
+        // ignore
+      }
+    }
+  }
+  /**
+   * Exercises the basic allocate/release workflow of the adapter using
+   * {@code MyPlugin}, which implements {@code DevicePlugin}.
+   * Plugin initialization itself is tested in TestResourcePluginManager.
+   */
+  @Test
+  public void testBasicWorkflow()
+      throws YarnException, IOException {
+    NodeManager.NMContext context = mock(NodeManager.NMContext.class);
+    NMStateStoreService storeService = mock(NMStateStoreService.class);
+    when(context.getNMStateStore()).thenReturn(storeService);
+    doNothing().when(storeService).storeAssignedResources(isA(Container.class),
+        isA(String.class),
+        isA(ArrayList.class));
+    // Init scheduler manager
+    DeviceMappingManager dmm = new DeviceMappingManager(context);
+    // Init a plugin
+    MyPlugin plugin = new MyPlugin();
+    MyPlugin spyPlugin = spy(plugin);
+    String resourceName = MyPlugin.RESOURCE_NAME;
+    // Init an adapter for the plugin
+    DevicePluginAdapter adapter = new DevicePluginAdapter(
+        resourceName,
+        spyPlugin, dmm);
+    // Bootstrap, adding devices
+    adapter.initialize(context);
+    adapter.createResourceHandler(context,
+        mockCGroupsHandler, mockPrivilegedExecutor);
+    adapter.getDeviceResourceHandler().bootstrap(conf);
+    int size = dmm.getAvailableDevices(resourceName);
+    Assert.assertEquals(3, size);
+    // A container c1 requests 1 device
+    Container c1 = mockContainerWithDeviceRequest(0,
+        resourceName,
+        1, false);
+    // preStart
+    adapter.getDeviceResourceHandler().preStart(c1);
+    // check book keeping
+    Assert.assertEquals(2,
+        dmm.getAvailableDevices(resourceName));
+    Assert.assertEquals(1,
+        dmm.getAllUsedDevices().get(resourceName).size());
+    Assert.assertEquals(3,
+        dmm.getAllAllowedDevices().get(resourceName).size());
+    // postComplete
+    adapter.getDeviceResourceHandler().postComplete(getContainerId(0));
+    Assert.assertEquals(3,
+        dmm.getAvailableDevices(resourceName));
+    Assert.assertEquals(0,
+        dmm.getAllUsedDevices().get(resourceName).size());
+    Assert.assertEquals(3,
+        dmm.getAllAllowedDevices().get(resourceName).size());
+    // A container c2 requests 3 devices
+    Container c2 = mockContainerWithDeviceRequest(1,
+        resourceName,
+        3, false);
+    // preStart
+    adapter.getDeviceResourceHandler().preStart(c2);
+    // check book keeping
+    Assert.assertEquals(0,
+        dmm.getAvailableDevices(resourceName));
+    Assert.assertEquals(3,
+        dmm.getAllUsedDevices().get(resourceName).size());
+    Assert.assertEquals(3,
+        dmm.getAllAllowedDevices().get(resourceName).size());
+    // postComplete
+    adapter.getDeviceResourceHandler().postComplete(getContainerId(1));
+    Assert.assertEquals(3,
+        dmm.getAvailableDevices(resourceName));
+    Assert.assertEquals(0,
+        dmm.getAllUsedDevices().get(resourceName).size());
+    Assert.assertEquals(3,
+        dmm.getAllAllowedDevices().get(resourceName).size());
+    // A container c3 requests 0 devices; give it its own container id
+    // rather than reusing c2's.
+    Container c3 = mockContainerWithDeviceRequest(2,
+        resourceName,
+        0, false);
+    // preStart
+    adapter.getDeviceResourceHandler().preStart(c3);
+    // check book keeping
+    Assert.assertEquals(3,
+        dmm.getAvailableDevices(resourceName));
+    Assert.assertEquals(0,
+        dmm.getAllUsedDevices().get(resourceName).size());
+    Assert.assertEquals(3,
+        dmm.getAllAllowedDevices().get(resourceName).size());
+    // postComplete
+    adapter.getDeviceResourceHandler().postComplete(getContainerId(2));
+    Assert.assertEquals(3,
+        dmm.getAvailableDevices(resourceName));
+    Assert.assertEquals(0,
+        dmm.getAllUsedDevices().get(resourceName).size());
+    Assert.assertEquals(3,
+        dmm.getAllAllowedDevices().get(resourceName).size());
   public void testDeviceResourceUpdaterImpl() throws YarnException {
     Resource nodeResource = mock(Resource.class);
+    NodeManager.NMContext context = mock(NodeManager.NMContext.class);
     // Init an plugin
     MyPlugin plugin = new MyPlugin();
     MyPlugin spyPlugin = spy(plugin);
     String resourceName = MyPlugin.RESOURCE_NAME;
+    // Init scheduler manager
+    DeviceMappingManager dmm = new DeviceMappingManager(context);
     // Init an adapter for the plugin
     DevicePluginAdapter adapter = new DevicePluginAdapter(
-        resourceName,
-        spyPlugin);
+        resourceName, spyPlugin, dmm);
@@ -91,6 +240,235 @@ public class TestDevicePluginAdapter {
         resourceName, 3);
+  /**
+   * Verifies that the device assigned to a container in preStart is passed
+   * to the NM state store via storeAssignedResources.
+   */
+  @Test
+  public void testStoreDeviceSchedulerManagerState()
+      throws IOException, YarnException {
+    NodeManager.NMContext context = mock(NodeManager.NMContext.class);
+    NMStateStoreService realStoreService = new NMMemoryStateStoreService();
+    NMStateStoreService storeService = spy(realStoreService);
+    when(context.getNMStateStore()).thenReturn(storeService);
+    // Skip the real store write; the spy still records the call for verify()
+    doNothing().when(storeService).storeAssignedResources(isA(Container.class),
+        isA(String.class),
+        isA(ArrayList.class));
+    // Init scheduler manager
+    DeviceMappingManager dmm = new DeviceMappingManager(context);
+    // Init a plugin
+    MyPlugin plugin = new MyPlugin();
+    MyPlugin spyPlugin = spy(plugin);
+    String resourceName = MyPlugin.RESOURCE_NAME;
+    // Init an adapter for the plugin
+    DevicePluginAdapter adapter = new DevicePluginAdapter(
+        resourceName,
+        spyPlugin, dmm);
+    // Bootstrap, adding devices
+    adapter.initialize(context);
+    adapter.createResourceHandler(context,
+        mockCGroupsHandler, mockPrivilegedExecutor);
+    adapter.getDeviceResourceHandler().bootstrap(conf);
+    // A container c0 requests 1 device
+    Container c0 = mockContainerWithDeviceRequest(0,
+        resourceName,
+        1, false);
+    // preStart
+    adapter.getDeviceResourceHandler().preStart(c0);
+    // Ensure c0's assigned device is persisted
+    verify(storeService).storeAssignedResources(c0, resourceName,
+        Arrays.asList(Device.Builder.newInstance()
+            .setId(0)
+            .setDevPath("/dev/hdwA0")
+            .setMajorNumber(256)
+            .setMinorNumber(0)
+            .setBusID("0000:80:00.0")
+            .setHealthy(true)
+            .build()));
+  }
+  @Test
+  public void testRecoverDeviceSchedulerManagerState()
+      throws IOException, YarnException {
+    NodeManager.NMContext context = mock(NodeManager.NMContext.class);
+    NMStateStoreService realStoreService = new NMMemoryStateStoreService();
+    NMStateStoreService storeService = spy(realStoreService);
+    when(context.getNMStateStore()).thenReturn(storeService);
+    doNothing().when(storeService).storeAssignedResources(isA(Container.class),
+        isA(String.class),
+        isA(ArrayList.class));
+    // Init scheduler manager
+    DeviceMappingManager dmm = new DeviceMappingManager(context);
+    ResourcePluginManager rpm = mock(ResourcePluginManager.class);
+    when(rpm.getDeviceMappingManager()).thenReturn(dmm);
+    // Init an plugin
+    MyPlugin plugin = new MyPlugin();
+    MyPlugin spyPlugin = spy(plugin);
+    String resourceName = MyPlugin.RESOURCE_NAME;
+    // Init an adapter for the plugin
+    DevicePluginAdapter adapter = new DevicePluginAdapter(
+        resourceName,
+        spyPlugin, dmm);
+    // Bootstrap, adding device
+    adapter.initialize(context);
+    adapter.createResourceHandler(context,
+        mockCGroupsHandler, mockPrivilegedExecutor);
+    adapter.getDeviceResourceHandler().bootstrap(conf);
+    Assert.assertEquals(3,
+        dmm.getAllAllowedDevices().get(resourceName).size());
+    // mock NMStateStore
+    Device storedDevice = Device.Builder.newInstance()
+        .setId(0)
+        .setDevPath("/dev/hdwA0")
+        .setMajorNumber(256)
+        .setMinorNumber(0)
+        .setBusID("0000:80:00.0")
+        .setHealthy(true)
+        .build();
+    ConcurrentHashMap<ContainerId, Container> runningContainersMap
+        = new ConcurrentHashMap<>();
+    Container nmContainer = mock(Container.class);
+    ResourceMappings rmap = new ResourceMappings();
+    ResourceMappings.AssignedResources ar =
+        new ResourceMappings.AssignedResources();
+    ar.updateAssignedResources(
+        Arrays.asList(storedDevice));
+    rmap.addAssignedResources(resourceName, ar);
+    when(nmContainer.getResourceMappings()).thenReturn(rmap);
+    when(context.getContainers()).thenReturn(runningContainersMap);
+    // Test case 1. c0 get recovered. scheduler state restored
+    runningContainersMap.put(getContainerId(0), nmContainer);
+    adapter.getDeviceResourceHandler().reacquireContainer(
+        getContainerId(0));
+    Assert.assertEquals(3,
+        dmm.getAllAllowedDevices().get(resourceName).size());
+    Assert.assertEquals(1,
+        dmm.getAllUsedDevices().get(resourceName).size());
+    Assert.assertEquals(2,
+        dmm.getAvailableDevices(resourceName));
+    Map<Device, ContainerId> used = dmm.getAllUsedDevices().get(resourceName);
+    Assert.assertTrue(used.keySet().contains(storedDevice));
+    // Test case 2. c1 wants get recovered.
+    // But stored device is already allocated to c2
+    nmContainer = mock(Container.class);
+    rmap = new ResourceMappings();
+    ar = new ResourceMappings.AssignedResources();
+    ar.updateAssignedResources(
+        Arrays.asList(storedDevice));
+    rmap.addAssignedResources(resourceName, ar);
+    // already assigned to c1
+    runningContainersMap.put(getContainerId(2), nmContainer);
+    boolean caughtException = false;
+    try {
+      adapter.getDeviceResourceHandler().reacquireContainer(getContainerId(1));
+    } catch (ResourceHandlerException e) {
+      caughtException = true;
+    }
+    Assert.assertTrue(
+        "Should fail since requested device is assigned already",
+        caughtException);
+    // don't affect c0 allocation state
+    Assert.assertEquals(3,
+        dmm.getAllAllowedDevices().get(resourceName).size());
+    Assert.assertEquals(1,
+        dmm.getAllUsedDevices().get(resourceName).size());
+    Assert.assertEquals(2,
+        dmm.getAvailableDevices(resourceName));
+    used = dmm.getAllUsedDevices().get(resourceName);
+    Assert.assertTrue(used.keySet().contains(storedDevice));
+  }
+  /**
+   * Verifies that when persisting the assignment to the state store fails,
+   * preStart throws a ResourceHandlerException and no device remains
+   * assigned afterwards.
+   */
+  @Test
+  public void testAssignedDeviceCleanupWhenStoreOpFails()
+      throws IOException, YarnException {
+    NodeManager.NMContext context = mock(NodeManager.NMContext.class);
+    NMStateStoreService realStoreService = new NMMemoryStateStoreService();
+    NMStateStoreService storeService = spy(realStoreService);
+    when(context.getNMStateStore()).thenReturn(storeService);
+    // Make every attempt to persist assigned resources fail
+    doThrow(new IOException("Exception ...")).when(storeService)
+        .storeAssignedResources(isA(Container.class),
+            isA(String.class),
+            isA(ArrayList.class));
+    // Init scheduler manager
+    DeviceMappingManager dmm = new DeviceMappingManager(context);
+    // Init a plugin
+    MyPlugin plugin = new MyPlugin();
+    MyPlugin spyPlugin = spy(plugin);
+    String resourceName = MyPlugin.RESOURCE_NAME;
+    // Init an adapter for the plugin
+    DevicePluginAdapter adapter = new DevicePluginAdapter(
+        resourceName,
+        spyPlugin, dmm);
+    // Bootstrap, adding devices
+    adapter.initialize(context);
+    adapter.createResourceHandler(context,
+        mockCGroupsHandler, mockPrivilegedExecutor);
+    adapter.getDeviceResourceHandler().bootstrap(conf);
+    // A container c0 requests 1 device
+    Container c0 = mockContainerWithDeviceRequest(0,
+        resourceName,
+        1, false);
+    // preStart must fail because the state store write fails
+    try {
+      adapter.getDeviceResourceHandler().preStart(c0);
+      Assert.fail("Should throw exception in preStart");
+    } catch (ResourceHandlerException e) {
+      // expected
+    }
+    // no device assigned
+    Assert.assertEquals(3,
+        dmm.getAllAllowedDevices().get(resourceName).size());
+    Assert.assertEquals(0,
+        dmm.getAllUsedDevices().get(resourceName).size());
+    Assert.assertEquals(3,
+        dmm.getAvailableDevices(resourceName));
+  }
+  /**
+   * Mocks a container with id {@code id}, empty resource mappings and a
+   * resource of {@code Resource.newInstance(1024, 1)} that additionally asks
+   * for {@code numDeviceRequest} units of {@code resourceName}. When
+   * {@code dockerContainerEnabled} is true, the launch context environment
+   * sets ENV_CONTAINER_TYPE to CONTAINER_RUNTIME_DOCKER.
+   */
+  private static Container mockContainerWithDeviceRequest(int id,
+      String resourceName,
+      int numDeviceRequest,
+      boolean dockerContainerEnabled) {
+    Resource requested = Resource.newInstance(1024, 1);
+    requested.setResourceValue(resourceName, numDeviceRequest);
+    Map<String, String> environment = new HashMap<>();
+    if (dockerContainerEnabled) {
+      environment.put(ContainerRuntimeConstants.ENV_CONTAINER_TYPE,
+          ContainerRuntimeConstants.CONTAINER_RUNTIME_DOCKER);
+    }
+    ContainerLaunchContext launchContext = mock(ContainerLaunchContext.class);
+    when(launchContext.getEnvironment()).thenReturn(environment);
+    Container container = mock(Container.class);
+    when(container.getContainerId()).thenReturn(getContainerId(id));
+    when(container.getResource()).thenReturn(requested);
+    when(container.getResourceMappings()).thenReturn(new ResourceMappings());
+    when(container.getLaunchContext()).thenReturn(launchContext);
+    return container;
+  }
+  /**
+   * Builds a {@link ContainerId} with sequence number {@code id} under a
+   * fixed application attempt (cluster timestamp 1234, app 1, attempt 1).
+   */
+  private static ContainerId getContainerId(int id) {
+    ApplicationId appId = ApplicationId.newInstance(1234L, 1);
+    ApplicationAttemptId attemptId =
+        ApplicationAttemptId.newInstance(appId, 1);
+    return ContainerId.newContainerId(attemptId, id);
+  }
   private class MyPlugin implements DevicePlugin {
     private final static String RESOURCE_NAME = "cmpA.com/hdwA";

To unsubscribe, e-mail: common-commits-unsubscr...@hadoop.apache.org
For additional commands, e-mail: common-commits-h...@hadoop.apache.org

Reply via email to