Repository: systemml
Updated Branches:
  refs/heads/master 1fa8e126f -> bdf42c068


[SYSTEMML-445] Cleanup GPU memory management

- Simplified GPU memory management (allocate/deallocate/evict) to debug any 
OOM-related issues.
- Also, streamlined fragmentation-related code into malloc and prepared it for 
future memory management policies.
- Allow user to configure the GPU eviction policy
- Also added MRU as a cache eviction policy

Closes #733.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/bdf42c06
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/bdf42c06
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/bdf42c06

Branch: refs/heads/master
Commit: bdf42c068c6e0c4ddfcef7d827cfb65f11fdff89
Parents: 1fa8e12
Author: Niketan Pansare <[email protected]>
Authored: Thu Mar 1 15:16:45 2018 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Thu Mar 1 15:19:03 2018 -0800

----------------------------------------------------------------------
 conf/SystemML-config.xml.template               |   3 +
 .../java/org/apache/sysml/api/DMLScript.java    |  21 +
 .../apache/sysml/api/ScriptExecutorUtils.java   |   4 +-
 .../sysml/api/mlcontext/ScriptExecutor.java     |  21 +-
 .../java/org/apache/sysml/conf/DMLConfig.java   |   4 +-
 .../instructions/gpu/GPUInstruction.java        |   1 +
 .../instructions/gpu/context/CSRPointer.java    |   3 -
 .../instructions/gpu/context/GPUContext.java    | 538 +----------------
 .../gpu/context/GPUMemoryManager.java           | 576 +++++++++++++++++++
 .../instructions/gpu/context/GPUObject.java     |  55 +-
 .../DoublePrecisionCudaSupportFunctions.java    |  11 +
 .../LibMatrixCuDNNConvolutionAlgorithm.java     |   9 +-
 .../data/LibMatrixCuDNNInputRowFetcher.java     |   6 +-
 .../org/apache/sysml/utils/GPUStatistics.java   |  16 +-
 .../org/apache/sysml/utils/LRUCacheMap.java     |  91 ---
 .../org/apache/sysml/test/gpu/GPUTests.java     |   7 +-
 .../apache/sysml/test/unit/LRUCacheMapTest.java | 120 ----
 17 files changed, 712 insertions(+), 774 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/conf/SystemML-config.xml.template
----------------------------------------------------------------------
diff --git a/conf/SystemML-config.xml.template 
b/conf/SystemML-config.xml.template
index e5f0137..a9c73c8 100644
--- a/conf/SystemML-config.xml.template
+++ b/conf/SystemML-config.xml.template
@@ -93,6 +93,9 @@
     <!-- the floating point precision. supported values are double, single -->
     <sysml.floating.point.precision>double</sysml.floating.point.precision>
     
+    <!-- the eviction policy for the GPU bufferpool. supported values are lru, 
mru, lfu, min_evict -->
+    <sysml.gpu.eviction.policy>lru</sysml.gpu.eviction.policy>
+    
    <!-- maximum wrap length for instruction and miscellaneous timer column of 
statistics -->
    <sysml.stats.maxWrapLength>30</sysml.stats.maxWrapLength>
 </root>

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/api/DMLScript.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java 
b/src/main/java/org/apache/sysml/api/DMLScript.java
index 4da874e..7a44838 100644
--- a/src/main/java/org/apache/sysml/api/DMLScript.java
+++ b/src/main/java/org/apache/sysml/api/DMLScript.java
@@ -108,6 +108,19 @@ public class DMLScript
                HYBRID_SPARK,   // execute matrix operations in CP or Spark
                SPARK                   // execute matrix operations in Spark
        }
+       
+       /**
+        * Eviction policies for eviction of GPU objects.
+        */
+       public enum EvictionPolicy {
+               LRU,                            // Evict the least recently 
used GPUObject. 
+               LFU,                            // Evict the least frequently 
used GPUObject. 
+               MIN_EVICT,
+               MRU,                            // 
http://www.vldb.org/conf/1985/P127.PDF
+               // TODO:
+               // ARC, // https://dbs.uni-leipzig.de/file/ARC.pdf
+               // LOOP_AWARE           // different policies for operations in 
for/while/parfor loop vs outside the loop
+       }
 
        /**
         * Set of DMLOptions that can be set through the command line
@@ -164,6 +177,7 @@ public class DMLScript
        public static ExplainType       EXPLAIN             = 
DMLOptions.defaultOptions.explainType; // explain type
        public static String            DML_FILE_PATH_ANTLR_PARSER = 
DMLOptions.defaultOptions.filePath; // filename of dml/pydml script
        public static String            FLOATING_POINT_PRECISION = "double";    
                                                // data type to use internally
+       public static EvictionPolicy    GPU_EVICTION_POLICY = 
EvictionPolicy.LRU;                                               // currently 
employed GPU eviction policy
 
        /**
         * Global variable indicating the script type (DML or PYDML). Can be 
used
@@ -675,6 +689,13 @@ public class DMLScript
 
                // Sets the GPUs to use for this process (a range, all GPUs, 
comma separated list or a specific GPU)
                GPUContextPool.AVAILABLE_GPUS = 
dmlconf.getTextValue(DMLConfig.AVAILABLE_GPUS);
+               
+               String evictionPolicy = 
dmlconf.getTextValue(DMLConfig.GPU_EVICTION_POLICY).toUpperCase();
+               try {
+                       DMLScript.GPU_EVICTION_POLICY = 
EvictionPolicy.valueOf(evictionPolicy);
+               } catch(IllegalArgumentException e) {
+            throw new RuntimeException("Unsupported eviction policy:" + 
evictionPolicy);
+        }
 
                //Step 2: set local/remote memory if requested (for compile in 
AM context) 
                if( dmlconf.getBooleanValue(DMLConfig.YARN_APPMASTER) ){

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java 
b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
index a6c276f..d6258a5 100644
--- a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
+++ b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
@@ -105,7 +105,9 @@ public class ScriptExecutorUtils {
                        throw e;
                } finally { // ensure cleanup/shutdown
                        if (DMLScript.USE_ACCELERATOR && 
!ec.getGPUContexts().isEmpty()) {
-                               ec.getGPUContexts().forEach(gCtx -> 
gCtx.clearTemporaryMemory());
+                               for(GPUContext gCtx : ec.getGPUContexts()) {
+                                       gCtx.clearTemporaryMemory();
+                               }
                                GPUContextPool.freeAllGPUContexts();
                        }
                        if( ConfigurationManager.isCodegenEnabled() )

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java 
b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index b00b642..9b0dfc8 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -27,6 +27,7 @@ import java.util.Set;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.DMLOptions;
+import org.apache.sysml.api.DMLScript.EvictionPolicy;
 import org.apache.sysml.api.ScriptExecutorUtils;
 import org.apache.sysml.api.jmlc.JMLCUtils;
 import org.apache.sysml.api.mlcontext.MLContext.ExecutionType;
@@ -250,20 +251,30 @@ public class ScriptExecutor {
                oldGPU = DMLScript.USE_ACCELERATOR;
                DMLScript.USE_ACCELERATOR = gpu;
                DMLScript.STATISTICS_COUNT = statisticsMaxHeavyHitters;
-               
+
                // set the global compiler configuration
                try {
                        OptimizerUtils.resetStaticCompilerFlags();
                        CompilerConfig cconf = 
OptimizerUtils.constructCompilerConfig(
-                               ConfigurationManager.getCompilerConfig(), 
config);
+                                       
ConfigurationManager.getCompilerConfig(), config);
                        ConfigurationManager.setGlobalConfig(cconf);
-               } catch(DMLRuntimeException ex) {
+               } 
+               catch(DMLRuntimeException ex) {
                        throw new RuntimeException(ex);
                }
-               
+
                // set the GPUs to use for this process (a range, all GPUs, 
comma separated list or a specific GPU)
                GPUContextPool.AVAILABLE_GPUS = 
config.getTextValue(DMLConfig.AVAILABLE_GPUS);
+
+               String evictionPolicy = 
config.getTextValue(DMLConfig.GPU_EVICTION_POLICY).toUpperCase();
+               try {
+                       DMLScript.GPU_EVICTION_POLICY = 
EvictionPolicy.valueOf(evictionPolicy);
+               } 
+               catch(IllegalArgumentException e) {
+                       throw new RuntimeException("Unsupported eviction 
policy:" + evictionPolicy);
+               }
        }
+       
 
        /**
         * Reset the global flags (for example: statistics, gpu, etc)
@@ -399,7 +410,7 @@ public class ScriptExecutor {
                restoreInputsInSymbolTable();
                resetGlobalFlags();
        }
-
+       
        /**
         * Restore the input variables in the symbol table after script 
execution.
         */

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/conf/DMLConfig.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java 
b/src/main/java/org/apache/sysml/conf/DMLConfig.java
index f40c543..62e1bd0 100644
--- a/src/main/java/org/apache/sysml/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java
@@ -87,6 +87,7 @@ public class DMLConfig
        public static final String AVAILABLE_GPUS       = 
"sysml.gpu.availableGPUs"; // String to specify which GPUs to use (a range, all 
GPUs, comma separated list or a specific GPU)
        public static final String SYNCHRONIZE_GPU      = 
"sysml.gpu.sync.postProcess"; // boolean: whether to synchronize GPUs after 
every instruction 
        public static final String EAGER_CUDA_FREE              = 
"sysml.gpu.eager.cudaFree"; // boolean: whether to perform eager CUDA free on 
rmvar
+       public static final String GPU_EVICTION_POLICY  = 
"sysml.gpu.eviction.policy"; // string: can be lru, mru, lfu, min_evict
        // Fraction of available memory to use. The available memory is 
computed when the GPUContext is created
        // to handle the tradeoff on calling cudaMemGetInfo too often.
        public static final String GPU_MEMORY_UTILIZATION_FACTOR = 
"sysml.gpu.memory.util.factor";
@@ -135,6 +136,7 @@ public class DMLConfig
                _defaultVals.put(STATS_MAX_WRAP_LEN,     "30" );
                _defaultVals.put(GPU_MEMORY_UTILIZATION_FACTOR,      "0.9" );
                _defaultVals.put(AVAILABLE_GPUS,         "-1");
+               _defaultVals.put(GPU_EVICTION_POLICY,    "lru");
                _defaultVals.put(SYNCHRONIZE_GPU,        "true" );
                _defaultVals.put(EAGER_CUDA_FREE,        "false" );
                _defaultVals.put(FLOATING_POINT_PRECISION,               
"double" );
@@ -426,7 +428,7 @@ public class DMLConfig
                                COMPRESSED_LINALG, 
                                CODEGEN, CODEGEN_COMPILER, CODEGEN_OPTIMIZER, 
CODEGEN_PLANCACHE, CODEGEN_LITERALS,
                                EXTRA_FINEGRAINED_STATS, STATS_MAX_WRAP_LEN,
-                               AVAILABLE_GPUS, SYNCHRONIZE_GPU, 
EAGER_CUDA_FREE, FLOATING_POINT_PRECISION
+                               AVAILABLE_GPUS, SYNCHRONIZE_GPU, 
EAGER_CUDA_FREE, FLOATING_POINT_PRECISION, GPU_EVICTION_POLICY
                }; 
                
                StringBuilder sb = new StringBuilder();

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
index fe3edd2..63eb34a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
@@ -65,6 +65,7 @@ public abstract class GPUInstruction extends Instruction {
        
        public final static String MISC_TIMER_CUDA_FREE =               "f";    
        // time spent in calling cudaFree
        public final static String MISC_TIMER_ALLOCATE =                "a";    
        // time spent to allocate memory on gpu
+       public final static String MISC_TIMER_EVICT =                   
"evict";        // time spent in eviction on gpu
        public final static String MISC_TIMER_ALLOCATE_DENSE_OUTPUT =   "ad";   
        // time spent to allocate dense output (recorded differently than 
MISC_TIMER_ALLOCATE)
        public final static String MISC_TIMER_ALLOCATE_SPARSE_OUTPUT =  "as";   
        // time spent to allocate sparse output (recorded differently than 
MISC_TIMER_ALLOCATE)
        public final static String MISC_TIMER_SET_ZERO =                "az";   
        // time spent to allocate

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
index d165970..a50e56a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
@@ -298,7 +298,6 @@ public class CSRPointer {
                        // with no memory allocated on the GPU.
                        return r;
                }
-               gCtx.ensureFreeSpace(getDataTypeSizeOf(nnz2) + 
getIntSizeOf(rows + 1) + getIntSizeOf(nnz2));
                // increment the cudaCount by 1 for the allocation of all 3 
arrays
                r.val = gCtx.allocate(null, getDataTypeSizeOf(nnz2));
                r.rowPtr = gCtx.allocate(null, getIntSizeOf(rows + 1));
@@ -430,8 +429,6 @@ public class CSRPointer {
                CSRPointer that = new CSRPointer(me.getGPUContext());
 
                that.allocateMatDescrPointer();
-               long totalSize = estimateSize(me.nnz, rows);
-               that.gpuContext.ensureFreeSpace(totalSize);
 
                that.nnz = me.nnz;
                that.val = allocate(that.nnz * LibMatrixCUDA.sizeOfDataType);

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
index 2741697..9f41e04 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
@@ -27,36 +27,16 @@ import static jcuda.jcusolver.JCusolverDn.cusolverDnDestroy;
 import static jcuda.jcusparse.JCusparse.cusparseCreate;
 import static jcuda.jcusparse.JCusparse.cusparseDestroy;
 import static jcuda.runtime.JCuda.cudaDeviceScheduleBlockingSync;
-import static jcuda.runtime.JCuda.cudaFree;
 import static jcuda.runtime.JCuda.cudaGetDeviceCount;
-import static jcuda.runtime.JCuda.cudaMalloc;
-import static jcuda.runtime.JCuda.cudaMemGetInfo;
-import static jcuda.runtime.JCuda.cudaMemset;
 import static jcuda.runtime.JCuda.cudaSetDevice;
 import static jcuda.runtime.JCuda.cudaSetDeviceFlags;
 
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.Map.Entry;
-import java.util.Set;
-
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.conf.ConfigurationManager;
-import org.apache.sysml.conf.DMLConfig;
-import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
 import org.apache.sysml.utils.GPUStatistics;
-import org.apache.sysml.utils.LRUCacheMap;
-
 import jcuda.Pointer;
 import jcuda.jcublas.cublasHandle;
 import jcuda.jcudnn.cudnnHandle;
@@ -73,10 +53,6 @@ public class GPUContext {
 
        protected static final Log LOG = 
LogFactory.getLog(GPUContext.class.getName());
        /**
-        * currently employed eviction policy
-        */
-       public final EvictionPolicy evictionPolicy = EvictionPolicy.LRU;
-       /**
         * The minimum CUDA Compute capability needed for SystemML.
         * After compute capability 3.0, 2^31 - 1 blocks and 1024 threads per 
block are supported.
         * If SystemML needs to run on an older card, this logic can be 
revisited.
@@ -107,26 +83,12 @@ public class GPUContext {
         * to launch custom CUDA kernel, specific to the active GPU for this 
GPUContext
         */
        private JCudaKernels kernels;
-
-       // Invoke cudaMemGetInfo to get available memory information. Useful if 
GPU is shared among multiple application.
-       public double GPU_MEMORY_UTILIZATION_FACTOR = 
ConfigurationManager.getDMLConfig()
-                       
.getDoubleValue(DMLConfig.GPU_MEMORY_UTILIZATION_FACTOR);
-       /**
-        * Map of free blocks allocate on GPU. maps size_of_block -> pointer on 
GPU
-        */
-       private LRUCacheMap<Long, Set<Pointer>> freeCUDASpaceMap = new 
LRUCacheMap<>();
-       /**
-        * To record size of allocated blocks
-        */
-       private HashMap<Pointer, Long> cudaBlockSizeMap = new HashMap<>();
-       /**
-        * list of allocated {@link GPUObject} instances allocated on {@link 
GPUContext#deviceNum} GPU
-        * These are matrices allocated on the GPU on which rmvar hasn't been 
called yet.
-        * If a {@link GPUObject} has more than one lock on it, it cannot be 
freed
-        * If it has zero locks on it, it can be freed, but it is preferrable 
to keep it around
-        * so that an extraneous host to dev transfer can be avoided
-        */
-       private ArrayList<GPUObject> allocatedGPUObjects = new ArrayList<>();
+       
+       private GPUMemoryManager memoryManager;
+       
+       public GPUMemoryManager getMemoryManager() {
+               return memoryManager;
+       }
 
        protected GPUContext(int deviceNum) throws DMLRuntimeException {
                this.deviceNum = deviceNum;
@@ -134,26 +96,16 @@ public class GPUContext {
 
                cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync);
 
-               long free[] = { 0 };
-               long total[] = { 0 };
-               cudaMemGetInfo(free, total);
-
                long start = -1;
                if (DMLScript.STATISTICS)
                        start = System.nanoTime();
                initializeCudaLibraryHandles();
+               
 
                if (DMLScript.STATISTICS)
                        GPUStatistics.cudaLibrariesInitTime = System.nanoTime() 
- start;
 
-               LOG.info(" GPU memory - Total: " + (total[0] * (1e-6)) + " MB, 
Available: " + (free[0] * (1e-6)) + " MB on "
-                               + this);
-
-               if (GPUContextPool.initialGPUMemBudget() > 
OptimizerUtils.getLocalMemBudget()) {
-                       LOG.warn("Potential under-utilization: GPU memory (" + 
GPUContextPool.initialGPUMemBudget()
-                                       + ") > driver memory budget (" + 
OptimizerUtils.getLocalMemBudget() + "). "
-                                       + "Consider increasing the driver 
memory budget.");
-               }
+               memoryManager = new GPUMemoryManager(this);
        }
 
        /**
@@ -175,29 +127,7 @@ public class GPUContext {
         */
        public void printMemoryInfo(String opcode) throws DMLRuntimeException {
                if (LOG.isDebugEnabled()) {
-                       long totalFreeCUDASpace = 0;
-                       for (Entry<Long, Set<Pointer>> kv : 
freeCUDASpaceMap.entrySet()) {
-                               totalFreeCUDASpace += kv.getKey() * 
kv.getValue().size();
-                       }
-                       long readLockedAllocatedMemory = 0;
-                       long writeLockedAllocatedMemory = 0;
-                       long unlockedAllocatedMemory = 0;
-                       for (GPUObject gpuObj : allocatedGPUObjects) {
-                               if (gpuObj.readLocks.longValue() > 0)
-                                       readLockedAllocatedMemory += 
gpuObj.getSizeOnDevice();
-                               else if (gpuObj.writeLock)
-                                       writeLockedAllocatedMemory += 
gpuObj.getSizeOnDevice();
-                               else
-                                       unlockedAllocatedMemory += 
gpuObj.getSizeOnDevice();
-                       }
-                       long free[] = { 0 };
-                       long total[] = { 0 };
-                       cudaMemGetInfo(free, total);
-                       long gpuFreeMemory = (long) (free[0] * 
GPU_MEMORY_UTILIZATION_FACTOR);
-                       LOG.debug(opcode + ": Total memory: " + total[0] + ", 
Free memory: " + free[0] + " (with util factor: "
-                                       + gpuFreeMemory + "), " + "Lazy unfreed 
memory: " + totalFreeCUDASpace
-                                       + ", Locked allocated memory 
(read/write): " + readLockedAllocatedMemory + "/"
-                                       + writeLockedAllocatedMemory + ", " + " 
Unlocked allocated memory: " + unlockedAllocatedMemory);
+                       LOG.debug(opcode + ": " + memoryManager.toString());
                }
        }
 
@@ -256,18 +186,18 @@ public class GPUContext {
        }
 
        /**
-        * Convenience method for {@link #allocate(String, long, int)}, 
defaults statsCount to 1.
+        * Convenience method for {@link #allocate(String, long)}.
         *
         * @param size size of data (in bytes) to allocate
         * @return jcuda pointer
         * @throws DMLRuntimeException if DMLRuntimeException occurs
         */
        public Pointer allocate(long size) throws DMLRuntimeException {
-               return allocate(null, size, 1);
+               return memoryManager.malloc(null, size);
        }
 
        /**
-        * Convenience method for {@link #allocate(String, long, int)}, 
defaults statsCount to 1.
+        * Invokes memory manager's malloc method
         *
         * @param instructionName name of instruction for which to record per 
instruction performance statistics, null if don't want to record
         * @param size            size of data (in bytes) to allocate
@@ -275,133 +205,17 @@ public class GPUContext {
         * @throws DMLRuntimeException if DMLRuntimeException occurs
         */
        public Pointer allocate(String instructionName, long size) throws 
DMLRuntimeException {
-               return allocate(instructionName, size, 1);
+               return memoryManager.malloc(instructionName, size);
        }
 
-       /**
-        * Allocates temporary space on the device.
-        * Does not update bookkeeping.
-        * The caller is responsible for freeing up after usage.
-        *
-        * @param instructionName name of instruction for which to record per 
instruction performance statistics, null if don't want to record
-        * @param size            Size of data (in bytes) to allocate
-        * @param statsCount      amount to increment the cudaAllocCount by
-        * @return jcuda Pointer
-        * @throws DMLRuntimeException if DMLRuntimeException occurs
-        */
-       public Pointer allocate(String instructionName, long size, int 
statsCount) throws DMLRuntimeException {
-               long t0 = 0, t1 = 0, end = 0;
-               Pointer A;
-               if(size < 0) {
-                       throw new DMLRuntimeException("Cannot allocate memory 
of size " + size);
-               }
-               if (freeCUDASpaceMap.containsKey(size)) {
-                       if (LOG.isTraceEnabled()) {
-                               LOG.trace(
-                                               "GPU : in allocate from 
instruction " + instructionName + ", found free block of size " + (size
-                                                               / 1024.0) + " 
Kbytes from previously allocated block on " + this);
-                       }
-                       if (instructionName != null && 
DMLScript.FINEGRAINED_STATISTICS)
-                               t0 = System.nanoTime();
-                       Set<Pointer> freeList = freeCUDASpaceMap.get(size);
-
-                       Iterator<Pointer> it = freeList.iterator(); // at this 
point, freeList should have at least one element
-                       A = it.next();
-                       it.remove();
-
-                       if (freeList.isEmpty())
-                               freeCUDASpaceMap.remove(size);
-                       if (instructionName != null && 
DMLScript.FINEGRAINED_STATISTICS)
-                               GPUStatistics
-                                               
.maintainCPMiscTimes(instructionName, GPUInstruction.MISC_TIMER_REUSE, 
System.nanoTime() - t0);
-               } else {
-                       if (LOG.isTraceEnabled()) {
-                               LOG.trace(
-                                               "GPU : in allocate from 
instruction " + instructionName + ", allocating new block of size " + (
-                                                               size / 1024.0) 
+ " Kbytes on " + this);
-                       }
-                       if (DMLScript.STATISTICS)
-                               t0 = System.nanoTime();
-                       ensureFreeSpace(instructionName, size);
-                       A = new Pointer();
-                       try {
-                               cudaMalloc(A, size);
-                       } catch(jcuda.CudaException e) {
-                               if(!DMLScript.EAGER_CUDA_FREE) {
-                                       // Strategy to avoid memory allocation 
due to potential fragmentation (a rare event):
-                                       // Step 1. First clear up lazy matrices 
and try cudaMalloc again.
-                                       // Step 2. Even if the issue persists, 
then evict all the allocated GPU objects and and try cudaMalloc again.
-                                       // After Step 2, SystemML will hold no 
pointers on GPU and the hope is that cudaMalloc will start afresh 
-                                       // by allocating objects sequentially 
with no holes.
-                                       
-                                       // Step 1:
-                                       LOG.debug("Eagerly deallocating 
rmvar-ed matrices to avoid memory allocation error due to potential 
fragmentation.");
-                                       long forcedEvictStartTime = 
DMLScript.STATISTICS ? System.nanoTime() : 0;
-                                       clearFreeCUDASpaceMap(instructionName, 
-1);
-                                       if(DMLScript.STATISTICS) {
-                                               
GPUStatistics.cudaForcedClearLazyFreedEvictTime.add(System.nanoTime()-forcedEvictStartTime);
-                                       }
-                                       try {
-                                               cudaMalloc(A, size);
-                                       } catch(jcuda.CudaException e1) {
-                                               forcedEvictStartTime = 
DMLScript.STATISTICS ? System.nanoTime() : 0;
-                                               // Step 2:
-                                               
GPUStatistics.cudaForcedClearUnpinnedMatCount.add(1);
-                                               LOG.warn("Eagerly deallocating 
unpinned matrices to avoid memory allocation error due to potential 
fragmentation. "
-                                                               + "If you see 
this warning often, we recommend that you set systemml.gpu.eager.cudaFree 
configuration property to true");
-                                               for(GPUObject toBeRemoved : 
allocatedGPUObjects) {
-                                                       if 
(!toBeRemoved.isLocked()) {
-                                                               if 
(toBeRemoved.dirty) {
-                                                                       
toBeRemoved.copyFromDeviceToHost(instructionName, true);
-                                                               }
-                                                               
toBeRemoved.clearData(true);
-                                                       }
-                                               }
-                                               if(DMLScript.STATISTICS) {
-                                                       
GPUStatistics.cudaForcedClearUnpinnedEvictTime.add(System.nanoTime()-forcedEvictStartTime);
-                                               }
-                                               cudaMalloc(A, size);
-                                       }
-                               }
-                               else {
-                                       throw new DMLRuntimeException("Unable 
to allocate memory of size " + size + " using cudaMalloc", e);
-                               }
-                       }
-                       if (DMLScript.STATISTICS)
-                               
GPUStatistics.cudaAllocTime.add(System.nanoTime() - t0);
-                       if (DMLScript.STATISTICS)
-                               GPUStatistics.cudaAllocCount.add(statsCount);
-                       if (instructionName != null && 
DMLScript.FINEGRAINED_STATISTICS)
-                               
GPUStatistics.maintainCPMiscTimes(instructionName, 
GPUInstruction.MISC_TIMER_ALLOCATE,
-                                               System.nanoTime() - t0);
-               }
-               // Set all elements to 0 since newly allocated space will 
contain garbage
-               if (DMLScript.STATISTICS)
-                       t1 = System.nanoTime();
-               if (LOG.isTraceEnabled()) {
-                       LOG.trace("GPU : in allocate from instruction " + 
instructionName + ", setting block of size " + (size
-                                       / 1024.0) + " Kbytes to zero on " + 
this);
-               }
-               cudaMemset(A, 0, size);
-               if (DMLScript.STATISTICS)
-                       end = System.nanoTime();
-               if (instructionName != null && DMLScript.FINEGRAINED_STATISTICS)
-                       GPUStatistics.maintainCPMiscTimes(instructionName, 
GPUInstruction.MISC_TIMER_SET_ZERO, end - t1);
-               if (DMLScript.STATISTICS)
-                       GPUStatistics.cudaMemSet0Time.add(end - t1);
-               if (DMLScript.STATISTICS)
-                       GPUStatistics.cudaMemSet0Count.add(1);
-               cudaBlockSizeMap.put(A, size);
-               return A;
-
-       }
 
        /**
         * Does lazy cudaFree calls.
         *
         * @param toFree {@link Pointer} instance to be freed
+        * @throws DMLRuntimeException if error
         */
-       public void cudaFreeHelper(final Pointer toFree) {
+       public void cudaFreeHelper(final Pointer toFree) throws 
DMLRuntimeException {
                cudaFreeHelper(null, toFree, DMLScript.EAGER_CUDA_FREE);
        }
 
@@ -410,8 +224,9 @@ public class GPUContext {
         *
         * @param toFree {@link Pointer} instance to be freed
         * @param eager  true if to be done eagerly
+        * @throws DMLRuntimeException if error
         */
-       public void cudaFreeHelper(final Pointer toFree, boolean eager) {
+       public void cudaFreeHelper(final Pointer toFree, boolean eager) throws 
DMLRuntimeException {
                cudaFreeHelper(null, toFree, eager);
        }
 
@@ -420,8 +235,9 @@ public class GPUContext {
         *
         * @param instructionName name of the instruction for which to record 
per instruction free time, null if do not want to record
         * @param toFree          {@link Pointer} instance to be freed
+        * @throws DMLRuntimeException if error
         */
-       public void cudaFreeHelper(String instructionName, final Pointer 
toFree) {
+       public void cudaFreeHelper(String instructionName, final Pointer 
toFree) throws DMLRuntimeException {
                cudaFreeHelper(instructionName, toFree, 
DMLScript.EAGER_CUDA_FREE);
        }
 
@@ -431,248 +247,12 @@ public class GPUContext {
         * @param instructionName name of the instruction for which to record 
per instruction free time, null if do not want to record
         * @param toFree          {@link Pointer} instance to be freed
         * @param eager           true if to be done eagerly
+        * @throws DMLRuntimeException if error
         */
-       public void cudaFreeHelper(String instructionName, final Pointer 
toFree, boolean eager) {
-               Pointer dummy = new Pointer();
-               if (toFree == dummy) { // trying to free a null pointer
-                       if (LOG.isTraceEnabled()) {
-                               LOG.trace("GPU : trying to free an empty 
pointer");
-                       }
-                       return;
-               }
-               long t0 = 0;
-               if (!cudaBlockSizeMap.containsKey(toFree))
-                       throw new RuntimeException(
-                                       "ERROR : Internal state corrupted, 
cache block size map is not aware of a block it trying to free up");
-               long size = cudaBlockSizeMap.get(toFree);
-               if (eager) {
-                       if (LOG.isTraceEnabled()) {
-                               LOG.trace("GPU : eagerly freeing cuda memory [ 
" + toFree + " ] of size " + size + " for instruction " + instructionName
-                                               + " on " + this);
-                       }
-                       if (DMLScript.STATISTICS)
-                               t0 = System.nanoTime();
-                       cudaFree(toFree);
-                       cudaBlockSizeMap.remove(toFree);
-                       if (DMLScript.STATISTICS)
-                               
GPUStatistics.cudaDeAllocTime.add(System.nanoTime() - t0);
-                       if (DMLScript.STATISTICS)
-                               GPUStatistics.cudaDeAllocCount.add(1);
-                       if (instructionName != null && 
DMLScript.FINEGRAINED_STATISTICS)
-                               
GPUStatistics.maintainCPMiscTimes(instructionName, 
GPUInstruction.MISC_TIMER_CUDA_FREE,
-                                               System.nanoTime() - t0);
-               } else {
-                       if (LOG.isTraceEnabled()) {
-                               LOG.trace("GPU : lazily freeing cuda memory of 
size " + size + " for instruction " + instructionName + " on " + this);
-                       }
-                       Set<Pointer> freeList = freeCUDASpaceMap.get(size);
-                       if (freeList == null) {
-                               freeList = new HashSet<>();
-                               freeCUDASpaceMap.put(size, freeList);
-                       }
-                       if (freeList.contains(toFree))
-                               throw new RuntimeException("GPU : Internal 
state corrupted, double free");
-                       freeList.add(toFree);
-               }
-       }
-
-       /**
-        * Thin wrapper over {@link GPUContext#evict(long)}.
-        *
-        * @param size size to check
-        * @throws DMLRuntimeException if DMLRuntimeException occurs
-        */
-       void ensureFreeSpace(long size) throws DMLRuntimeException {
-               ensureFreeSpace(null, size);
-       }
-
-       /**
-        * Thin wrapper over {@link GPUContext#evict(long)}.
-        *
-        * @param instructionName instructionName name of the instruction for 
which performance measurements are made
-        * @param size            size to check
-        * @throws DMLRuntimeException if DMLRuntimeException occurs
-        */
-       void ensureFreeSpace(String instructionName, long size) throws 
DMLRuntimeException {
-               if (size < 0)
-                       throw new DMLRuntimeException("The size cannot be 
negative:" + size);
-               else if (size >= getAvailableMemory())
-                       evict(instructionName, size);
-       }
-
-       /**
-        * Convenience wrapper over {@link GPUContext#evict(String, long)}.
-        *
-        * @param GPUSize Desired size to be freed up on the GPU
-        * @throws DMLRuntimeException If no blocks to free up or if not enough 
blocks with zero locks on them.
-        */
-       protected void evict(final long GPUSize) throws DMLRuntimeException {
-               evict(null, GPUSize);
-       }
-       
-       /**
-        * Release the set of free blocks maintained in a 
GPUObject.freeCUDASpaceMap to free up space
-        * 
-        * @param instructionName name of the instruction for which performance 
measurements are made
-        * @param neededSize      desired size to be freed up on the GPU (-1 if 
we want to eagerly free up all the blocks)
-        * @throws DMLRuntimeException If no reusable memory blocks to free up 
or if not enough matrix blocks with zero locks on them.
-        */
-       protected void clearFreeCUDASpaceMap(String instructionName,  final 
long neededSize) throws DMLRuntimeException {
-               if(neededSize < 0) {
-                       GPUStatistics.cudaForcedClearLazyFreedMatCount.add(1);
-                       while(freeCUDASpaceMap.size() > 0) {
-                               Entry<Long, Set<Pointer>> toFreeListPair = 
freeCUDASpaceMap.removeAndGetLRUEntry();
-                               
freeCUDASpaceMap.remove(toFreeListPair.getKey());
-                               for(Pointer toFree : toFreeListPair.getValue()) 
{
-                                       cudaFreeHelper(instructionName, toFree, 
true);
-                               }
-                       }
-               }
-               else {
-                       LRUCacheMap<Long, Set<Pointer>> lruCacheMap = 
freeCUDASpaceMap;
-                       while (lruCacheMap.size() > 0) {
-                               if (neededSize <= getAvailableMemory())
-                                       break;
-                               Map.Entry<Long, Set<Pointer>> toFreeListPair = 
lruCacheMap.removeAndGetLRUEntry();
-                               Set<Pointer> toFreeList = 
toFreeListPair.getValue();
-                               Long size = toFreeListPair.getKey();
-       
-                               Iterator<Pointer> it = toFreeList.iterator(); 
// at this point, freeList should have at least one element
-                               Pointer toFree = it.next();
-                               it.remove();
-       
-                               if (toFreeList.isEmpty())
-                                       lruCacheMap.remove(size);
-                               cudaFreeHelper(instructionName, toFree, true);
-                       }
-               }
-       }
-
-       /**
-        * Memory on the GPU is tried to be freed up until either a chunk of 
needed size is freed up
-        * or it fails.
-        * First the set of reusable blocks is freed up. If that isn't enough, 
the set of allocated matrix
-        * blocks with zero locks on them is freed up.
-        * The process cycles through the sorted list of allocated {@link 
GPUObject} instances. Sorting is based on
-        * number of (read) locks that have been obtained on it (reverse 
order). It repeatedly frees up
-        * blocks on which there are zero locks until the required size has 
been freed up.
-        * // TODO: update it with hybrid policy
-        *
-        * @param instructionName name of the instruction for which performance 
measurements are made
-        * @param neededSize      desired size to be freed up on the GPU
-        * @throws DMLRuntimeException If no reusable memory blocks to free up 
or if not enough matrix blocks with zero locks on them.
-        */
-       protected void evict(String instructionName, final long neededSize) 
throws DMLRuntimeException {
-               long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
-               if (LOG.isTraceEnabled()) {
-                       LOG.trace("GPU : evict called from " + instructionName 
+ " for size " + neededSize + " on " + this);
-               }
-               GPUStatistics.cudaEvictionCount.add(1);
-               if (LOG.isDebugEnabled()) {
-                       printMemoryInfo("EVICTION_CUDA_FREE_SPACE");
-               }
-
-               clearFreeCUDASpaceMap(instructionName,  neededSize);
-
-               if (neededSize <= getAvailableMemory()) {
-                       if(DMLScript.STATISTICS) {
-                               
GPUStatistics.cudaEvictTime.add(System.nanoTime() - t0);
-                       }
-                       return;
-               }
-
-               if (allocatedGPUObjects.size() == 0) {
-                       throw new DMLRuntimeException(
-                                       "There is not enough memory on device 
for this matrix, request (" + neededSize + ")");
-               }
-
-               Collections.sort(allocatedGPUObjects, new 
Comparator<GPUObject>() {
-                       @Override
-                       public int compare(GPUObject p1, GPUObject p2) {
-                               if (p1.isLocked() && p2.isLocked()) {
-                                       // Both are locked, so don't sort
-                                       return 0;
-                               } else if (p1.isLocked()) {
-                                       // Put the unlocked one to RHS
-                                       // a value less than 0 if x < y; and a 
value greater than 0 if x > y
-                                       return -1;
-                               } else if (p2.isLocked()) {
-                                       // Put the unlocked one to RHS
-                                       // a value less than 0 if x < y; and a 
value greater than 0 if x > y
-                                       return 1;
-                               } else {
-                                       // Both are unlocked
-                                       if (evictionPolicy == 
EvictionPolicy.MIN_EVICT) {
-                                               long p1Size = 0;
-                                               long p2Size = 0;
-                                               try {
-                                                       p1Size = 
p1.getSizeOnDevice() - neededSize;
-                                                       p2Size = 
p2.getSizeOnDevice() - neededSize;
-                                               } catch (DMLRuntimeException e) 
{
-                                                       throw new 
RuntimeException(e);
-                                               }
-
-                                               if (p1Size >= 0 && p2Size >= 0) 
{
-                                                       return 
Long.compare(p2Size, p1Size);
-                                               } else {
-                                                       return 
Long.compare(p1Size, p2Size);
-                                               }
-                                       } else if (evictionPolicy == 
EvictionPolicy.LRU || evictionPolicy == EvictionPolicy.LFU) {
-                                               return 
Long.compare(p2.timestamp.get(), p1.timestamp.get());
-                                       } else {
-                                               throw new 
RuntimeException("Unsupported eviction policy:" + evictionPolicy.name());
-                                       }
-                               }
-                       }
-               });
-
-               while (neededSize > getAvailableMemory() && 
allocatedGPUObjects.size() > 0) {
-                       if (LOG.isDebugEnabled()) {
-                               printMemoryInfo("EVICTION_UNLOCKED");
-                       }
-                       GPUObject toBeRemoved = 
allocatedGPUObjects.get(allocatedGPUObjects.size() - 1);
-                       if (toBeRemoved.isLocked()) {
-                               throw new DMLRuntimeException(
-                                               "There is not enough memory on 
device for this matrix, request (" + neededSize
-                                                               + "). Allocated 
GPU objects:" + allocatedGPUObjects.toString());
-                       }
-                       if (toBeRemoved.dirty) {
-                               
toBeRemoved.copyFromDeviceToHost(instructionName, true);
-                       }
-                       toBeRemoved.clearData(true);
-               }
-               if(DMLScript.STATISTICS) {
-                       GPUStatistics.cudaEvictTime.add(System.nanoTime() - t0);
-               }
-       }
-
-       /**
-        * Whether the GPU associated with this {@link GPUContext} has recorded 
the usage of a certain block.
-        *
-        * @param o the block
-        * @return true if present, false otherwise
-        */
-       public boolean isBlockRecorded(GPUObject o) {
-               return allocatedGPUObjects.contains(o);
-       }
-
-       /**
-        * @param o {@link GPUObject} instance to record
-        * @see GPUContext#allocatedGPUObjects
-        * Records the usage of a matrix block
-        */
-       public void recordBlockUsage(GPUObject o) {
-               allocatedGPUObjects.add(o);
	public void cudaFreeHelper(String instructionName, final Pointer toFree, boolean eager) throws DMLRuntimeException {
		// All frees are routed through the central memory manager so that
		// rmvar-ed pointers can be cached for reuse and pointer sizes stay tracked.
		memoryManager.free(instructionName, toFree, eager);
	}
 
-       /**
-        * @param o {@link GPUObject} instance to remove from the list of 
allocated GPU objects
-        * @see GPUContext#allocatedGPUObjects
-        * Records that a block is not used anymore
-        */
-       public void removeRecordedUsage(GPUObject o) {
-               allocatedGPUObjects.removeIf(a -> a.equals(o));
-       }
 
	/**
	 * Gets the available memory on GPU that SystemML can use.
	 * Delegated to the memory manager, which applies the configured
	 * GPU memory utilization factor to the raw device-free value.
	 *
	 * @return the available memory in bytes
	 */
	public long getAvailableMemory() {
		return memoryManager.getAvailableMemory();
	}
 
	/**
	 * Instantiates a new {@link GPUObject} for the given {@link MatrixObject}
	 * and registers it with the memory manager.
	 *
	 * @param mo the matrix object to back the new GPU object
	 * @return a new {@link GPUObject} instance
	 */
	public GPUObject createGPUObject(MatrixObject mo) {
		GPUObject ret = new GPUObject(this, mo);
		// Registration is required so the memory manager can consider this
		// object when applying its eviction policy.
		getMemoryManager().addGPUObject(ret);
		return ret;
	}
 
        /**
@@ -870,61 +449,11 @@ public class GPUContext {
	/**
	 * Clears all GPU memory managed for this context: allocated GPU objects
	 * as well as temporary (cached) allocations. Delegated to the memory manager.
	 *
	 * @throws DMLRuntimeException if error
	 */
	public void clearMemory() throws DMLRuntimeException {
		memoryManager.clearMemory();
	}
-
-       /**
-        * Clears up the memory used to optimize cudaMalloc/cudaFree calls.
-        */
-       public void clearTemporaryMemory() {
-               // To record the cuda block sizes needed by 
allocatedGPUObjects, others are cleared up.
-               HashMap<Pointer, Long> tmpCudaBlockSizeMap = new HashMap<>();
-               for (GPUObject o : allocatedGPUObjects) {
-                       if (o.isDirty()) {
-                               if (o.isSparse()) {
-                                       CSRPointer p = 
o.getSparseMatrixCudaPointer();
-                                       if (p == null)
-                                               throw new 
RuntimeException("CSRPointer is null in clearTemporaryMemory");
-                                       if (p.rowPtr != null && 
cudaBlockSizeMap.containsKey(p.rowPtr)) {
-                                               
tmpCudaBlockSizeMap.put(p.rowPtr, cudaBlockSizeMap.get(p.rowPtr));
-                                       }
-                                       if (p.colInd != null && 
cudaBlockSizeMap.containsKey(p.colInd)) {
-                                               
tmpCudaBlockSizeMap.put(p.colInd, cudaBlockSizeMap.get(p.colInd));
-                                       }
-                                       if (p.val != null && 
cudaBlockSizeMap.containsKey(p.val)) {
-                                               tmpCudaBlockSizeMap.put(p.val, 
cudaBlockSizeMap.get(p.val));
-                                       }
-
-                               } else {
-                                       Pointer p = o.getJcudaDenseMatrixPtr();
-                                       if (p == null)
-                                               throw new 
RuntimeException("Pointer is null in clearTemporaryMemory");
-                                       tmpCudaBlockSizeMap.put(p, 
cudaBlockSizeMap.get(p));
-                               }
-                       }
-               }
-
-               // garbage collect all temporarily allocated spaces
-               for (Set<Pointer> l : freeCUDASpaceMap.values()) {
-                       for (Pointer p : l) {
-                               cudaFreeHelper(p, true);
-                       }
-               }
-               cudaBlockSizeMap.clear();
-               freeCUDASpaceMap.clear();
-
-               // Restore only those entries for which there are still blocks 
on the GPU
-               cudaBlockSizeMap.putAll(tmpCudaBlockSizeMap);
+       
	/**
	 * Clears temporary allocations cached to optimize cudaMalloc/cudaFree calls.
	 * Delegated to the memory manager.
	 *
	 * @throws DMLRuntimeException if error
	 */
	public void clearTemporaryMemory() throws DMLRuntimeException {
		memoryManager.clearTemporaryMemory();
	}
 
        @Override
@@ -932,11 +461,4 @@ public class GPUContext {
                return "GPUContext{" + "deviceNum=" + deviceNum + '}';
        }
 
-       /**
-        * Eviction policies for {@link GPUContext#evict(long)}.
-        */
-       public enum EvictionPolicy {
-               LRU, LFU, MIN_EVICT
-       }
-
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
new file mode 100644
index 0000000..c0df38b
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -0,0 +1,576 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.runtime.instructions.gpu.context;
+
+import static jcuda.runtime.JCuda.cudaFree;
+import static jcuda.runtime.JCuda.cudaMalloc;
+import static jcuda.runtime.JCuda.cudaMemGetInfo;
+import static jcuda.runtime.JCuda.cudaMemset;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.atomic.LongAdder;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.conf.DMLConfig;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
+import org.apache.sysml.utils.GPUStatistics;
+
+import jcuda.Pointer;
+
+/**
+ * - All cudaFree and cudaMalloc in SystemML should go through this class to 
avoid OOM or incorrect results.
+ * - This class can be refactored in future to accept a chunk of memory ahead 
of time rather than during execution. This will only throw memory-related errors 
during startup.  
+ */
+public class GPUMemoryManager {
+       protected static final Log LOG = 
LogFactory.getLog(GPUMemoryManager.class.getName());
+       
+       // If the available free size is less than this factor, 
GPUMemoryManager will warn users of multiple programs grabbing onto GPU memory.
+       // This often happens if user tries to use both TF and SystemML, and TF 
grabs onto 90% of the memory ahead of time.
+       private static final double WARN_UTILIZATION_FACTOR = 0.7;
+       
+       // Invoke cudaMemGetInfo to get available memory information. Useful if 
GPU is shared among multiple applications.
+       public double GPU_MEMORY_UTILIZATION_FACTOR = 
ConfigurationManager.getDMLConfig()
+                       
.getDoubleValue(DMLConfig.GPU_MEMORY_UTILIZATION_FACTOR);
+       
+       /**
+        * Map of free blocks allocated on the GPU. Maps size_of_block -> pointer on 
GPU
+        */
+       private HashMap<Long, Set<Pointer>> rmvarGPUPointers = new 
HashMap<Long, Set<Pointer>>();
+       
+       /**
+        * list of allocated {@link GPUObject} instances allocated on {@link 
GPUContext#deviceNum} GPU
+        * These are matrices allocated on the GPU on which rmvar hasn't been 
called yet.
+        * If a {@link GPUObject} has more than one lock on it, it cannot be 
freed
+        * If it has zero locks on it, it can be freed, but it is preferable 
to keep it around
+        * so that an extraneous host to dev transfer can be avoided
+        */
+       private ArrayList<GPUObject> allocatedGPUObjects = new ArrayList<>();
+       
+       /**
+        * To record size of allocated blocks
+        */
+       private HashMap<Pointer, Long> allocatedGPUPointers = new HashMap<>();
+       
	/**
	 * Adds the GPU object to the memory manager.
	 * Registered objects become candidates for eviction when memory runs low.
	 *
	 * @param gpuObj the handle to the GPU object
	 */
	public void addGPUObject(GPUObject gpuObj) {
		allocatedGPUObjects.add(gpuObj);
	}
+       
	/**
	 * Removes the GPU object from the memory manager.
	 *
	 * @param gpuObj the handle to the GPU object
	 */
	public void removeGPUObject(GPUObject gpuObj) {
		if(LOG.isDebugEnabled())
			LOG.debug("Removing the GPU object: " + gpuObj);
		// removeIf (rather than remove) drops every entry equal to gpuObj,
		// guarding against accidental duplicate registration.
		allocatedGPUObjects.removeIf(a -> a.equals(gpuObj));
	}
+       
+       
+       /**
+        * Get size of allocated GPU Pointer
+        * @param ptr pointer to get size of
+        * @return either the size or -1 if no such pointer exists
+        */
+       public long getSizeAllocatedGPUPointer(Pointer ptr) {
+               if(allocatedGPUPointers.containsKey(ptr)) {
+                       return allocatedGPUPointers.get(ptr);
+               }
+               return -1;
+       }
+       
	/**
	 * Queries the device for total/free memory and logs the state.
	 * Warns about potential under-utilization when other processes already hold
	 * a large share of the GPU memory, or when the GPU memory budget exceeds
	 * the driver memory budget.
	 *
	 * @param gpuCtx the GPU context this manager serves (used for logging)
	 */
	public GPUMemoryManager(GPUContext gpuCtx) {
		long free[] = { 0 };
		long total[] = { 0 };
		cudaMemGetInfo(free, total);
		// Warn if less than WARN_UTILIZATION_FACTOR (70%) of device memory is
		// free at startup; typically another process (e.g. TensorFlow) grabbed
		// GPU memory ahead of SystemML.
		if(free[0] < WARN_UTILIZATION_FACTOR*total[0]) {
			LOG.warn("Potential under-utilization: GPU memory - Total: " + (total[0] * (1e-6)) + " MB, Available: " + (free[0] * (1e-6)) + " MB on " + gpuCtx 
					+ ". This can happen if there are other processes running on the GPU at the same time.");
		}
		else {
			LOG.info("GPU memory - Total: " + (total[0] * (1e-6)) + " MB, Available: " + (free[0] * (1e-6)) + " MB on " + gpuCtx);
		}
		if (GPUContextPool.initialGPUMemBudget() > OptimizerUtils.getLocalMemBudget()) {
			LOG.warn("Potential under-utilization: GPU memory (" + GPUContextPool.initialGPUMemBudget()
					+ ") > driver memory budget (" + OptimizerUtils.getLocalMemBudget() + "). "
					+ "Consider increasing the driver memory budget.");
		}
	}
+       
+       /**
+        * Invoke cudaMalloc
+        * 
+        * @param A pointer
+        * @param size size in bytes
+        * @return allocated pointer
+        */
+       private Pointer cudaMallocWarnIfFails(Pointer A, long size) {
+               try {
+                       cudaMalloc(A, size);
+                       allocatedGPUPointers.put(A, size);
+                       return A;
+               } catch(jcuda.CudaException e) {
+                       LOG.warn("cudaMalloc failed immediately after 
cudaMemGetInfo reported that memory of size " + size + " is available. "
+                                       + "This usually happens if there are 
external programs trying to grab on to memory in parallel.");
+                       return null;
+               }
+       }
+       
+       /**
+        * Allocate pointer of the given size in bytes.
+        * 
+        * @param opcode instruction name
+        * @param size size in bytes
+        * @return allocated pointer
+        * @throws DMLRuntimeException if error
+        */
+       public Pointer malloc(String opcode, long size) throws 
DMLRuntimeException {
+               if(size < 0) {
+                       throw new DMLRuntimeException("Cannot allocate memory 
of size " + size);
+               }
+               long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+               // Step 1: First try reusing exact match in rmvarGPUPointers to 
avoid holes in the GPU memory
+               Pointer A = getRmvarPointer(opcode, size);
+               
+               // Step 2: Allocate a new pointer in the GPU memory (since 
memory is available)
+               if(A == null && size <= getAvailableMemory()) {
+                       A = cudaMallocWarnIfFails(new Pointer(), size);
+                       if(LOG.isTraceEnabled()) {
+                               if(A == null)
+                                       LOG.trace("Couldnot allocate a new 
pointer in the GPU memory:" + size);
+                               else
+                                       LOG.trace("Allocated a new pointer in 
the GPU memory:" + size);
+                       }
+               }
+               
+               // Reusing one rmvar-ed pointer (Step 3) is preferred to 
reusing multiple pointers as the latter may not be contiguously allocated.
+               // (Step 4 or using any other policy that does not take memory 
into account).
+               
+               // Step 3: Try reusing non-exact match entry of rmvarGPUPointers
+               if(A == null) { 
+                       // Find minimum key that is greater than size
+                       long key = Long.MAX_VALUE;
+                       for(Long k : rmvarGPUPointers.keySet()) {
+                               key = k > size ? Math.min(key, k) : key;
+                       }
+                       if(key != Long.MAX_VALUE) {
+                               A = getRmvarPointer(opcode, key);
+                               // To avoid potential for holes in the GPU 
memory
+                               guardedCudaFree(A);
+                               A = cudaMallocWarnIfFails(new Pointer(), size);
+                               if(LOG.isTraceEnabled()) {
+                                       if(A == null)
+                                               LOG.trace("Couldnot reuse 
non-exact match of rmvarGPUPointers:" + size);
+                                       else
+                                               LOG.trace("Reuses a non-exact 
match from rmvarGPUPointers:" + size);
+                               }
+                       }
+               }
+               
+               // Step 3.b: An optimization deliberately omitted so as not to over-engineer 
malloc:
+               // Try to find minimal number of contiguously allocated pointer.
+               
+               // Evictions of matrix blocks are more expensive (as they might lead 
them to be written to disk in case of a smaller CPU budget) 
+               // than doing cuda free/malloc/memset. So, rmvar-ing all 
blocks (step 4) is preferred to eviction (step 5).
+               
+               // Step 4: Eagerly free-up rmvarGPUPointers and check if memory 
is available on GPU
+               if(A == null) {
+                       Set<Pointer> toFree = new HashSet<Pointer>();
+                       for(Set<Pointer> ptrs : rmvarGPUPointers.values()) {
+                               toFree.addAll(ptrs);
+                       }
+                       for(Pointer ptr : toFree) {
+                               guardedCudaFree(ptr);
+                       }
+                       if(size <= getAvailableMemory()) {
+                               A = cudaMallocWarnIfFails(new Pointer(), size);
+                               if(LOG.isTraceEnabled()) {
+                                       if(A == null)
+                                               LOG.trace("Couldnot allocate a 
new pointer in the GPU memory after eager free:" + size);
+                                       else
+                                               LOG.trace("Allocated a new 
pointer in the GPU memory after eager free:" + size);
+                               }
+                       }
+               }
+               
+               addMiscTime(opcode, GPUStatistics.cudaAllocTime, 
GPUStatistics.cudaAllocCount, GPUInstruction.MISC_TIMER_ALLOCATE, t0);
+               
+               // Step 5: Try eviction based on the given policy
+               if(A == null) {
+                       t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+                       // Sort based on the eviction policy
+                       Collections.sort(allocatedGPUObjects, new 
GPUComparator(size));
+                       while (size > getAvailableMemory() && 
allocatedGPUObjects.size() > 0) {
+                               GPUObject toBeRemoved = 
allocatedGPUObjects.get(allocatedGPUObjects.size() - 1);
+                               if (toBeRemoved.isLocked()) {
+                                       // All remaining blocks will also be 
locked
+                                       break;
+                               }
+                               else {
+                                       // Perform eviction
+                                       if (toBeRemoved.dirty) {
+                                               
toBeRemoved.copyFromDeviceToHost(opcode, true);
+                                       }
+                                       toBeRemoved.clearData(true);
+                               }
+                       }
+                       addMiscTime(opcode, GPUStatistics.cudaEvictionCount, 
GPUStatistics.cudaEvictTime, GPUInstruction.MISC_TIMER_EVICT, t0);
+                       if(size <= getAvailableMemory()) {
+                               A = cudaMallocWarnIfFails(new Pointer(), size);
+                               if(LOG.isTraceEnabled()) {
+                                       if(A == null)
+                                               LOG.trace("Couldnot allocate a 
new pointer in the GPU memory after eviction:" + size);
+                                       else
+                                               LOG.trace("Allocated a new 
pointer in the GPU memory after eviction:" + size);
+                               }
+                       }
+               }
+               
+               if(A == null) {
+                       throw new DMLRuntimeException("There is not enough 
memory on device for this matrix, request (" + size + "). "
+                                       + toString());
+               }
+               
+               t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+               cudaMemset(A, 0, size);
+               addMiscTime(opcode, GPUStatistics.cudaMemSet0Time, 
GPUStatistics.cudaMemSet0Count, GPUInstruction.MISC_TIMER_SET_ZERO, t0);
+               return A;
+       }
+       
+       /**
+        * Note: This method should not be called from an iterator as it 
removes entries from allocatedGPUPointers and rmvarGPUPointers
+        * 
+        * @param toFree pointer to call cudaFree method on
+        */
+       private void guardedCudaFree(Pointer toFree) {
+               if (toFree != new Pointer()) {
+                       if(allocatedGPUPointers.containsKey(toFree)) {
+                               Long size = allocatedGPUPointers.remove(toFree);
+                               if(rmvarGPUPointers.containsKey(size) && 
rmvarGPUPointers.get(size).contains(toFree)) {
+                                       remove(rmvarGPUPointers, size, toFree);
+                               }
+                               if(LOG.isDebugEnabled())
+                                       LOG.debug("Free-ing up the pointer: " + 
toFree);
+                               cudaFree(toFree);
+                       }
+                       else {
+                               throw new RuntimeException("Attempting to free 
an unaccounted pointer:" + toFree);
+                       }
+               }
+       }
+       
+       /**
+        * Deallocate the pointer
+        * 
+        * @param opcode instruction name
+        * @param toFree pointer to free
+        * @param eager whether to deallocate eagerly
+        * @throws DMLRuntimeException if error
+        */
+       public void free(String opcode, Pointer toFree, boolean eager) throws 
DMLRuntimeException {
+               Pointer dummy = new Pointer();
+               if (toFree == dummy) { // trying to free a null pointer
+                       return;
+               }
+               if (eager) {
+                       long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+                       guardedCudaFree(toFree);
+                       addMiscTime(opcode, GPUStatistics.cudaDeAllocTime, 
GPUStatistics.cudaDeAllocCount, GPUInstruction.MISC_TIMER_CUDA_FREE, t0);
+               }
+               else {
+                       if (!allocatedGPUPointers.containsKey(toFree))
+                               throw new RuntimeException("ERROR : Internal 
state corrupted, cache block size map is not aware of a block it trying to free 
up");
+                       long size = allocatedGPUPointers.get(toFree);
+                       Set<Pointer> freeList = rmvarGPUPointers.get(size);
+                       if (freeList == null) {
+                               freeList = new HashSet<Pointer>();
+                               rmvarGPUPointers.put(size, freeList);
+                       }
+                       if (freeList.contains(toFree))
+                               throw new RuntimeException("GPU : Internal 
state corrupted, double free");
+                       freeList.add(toFree);
+               }
+       }
+       
+       /**
+        * Clear the allocated GPU objects
+        * 
+        * @throws DMLRuntimeException if error
+        */
+       public void clearMemory() throws DMLRuntimeException {
+               // First deallocate all the GPU objects
+               for(GPUObject gpuObj : allocatedGPUObjects) {
+                       if(gpuObj.isDirty()) {
+                               LOG.debug("Attempted to free GPU Memory when a 
block[" + gpuObj + "] is still on GPU memory, copying it back to host.");
+                               gpuObj.acquireHostRead(null);
+                       }
+                       gpuObj.clearData(true);
+               }
+               allocatedGPUObjects.clear();
+               
+               // Then clean up remaining allocated GPU pointers 
+               Set<Pointer> remainingPtr = new 
HashSet<>(allocatedGPUPointers.keySet());
+               for(Pointer toFree : remainingPtr) {
+                       guardedCudaFree(toFree); // cleans up 
allocatedGPUPointers and rmvarGPUPointers as well
+               }
+       }
+       
+       /**
+        * Get all pointers withing allocatedGPUObjects such that GPUObject is 
in dirty state
+        * 
+        * @return set of pointers
+        */
+       private HashSet<Pointer> getDirtyPointers() {
+               HashSet<Pointer> nonTemporaryPointers = new HashSet<Pointer>();
+               for (GPUObject o : allocatedGPUObjects) {
+                       if(o.isDirty()) {
+                               if (o.isSparse()) {
+                                       CSRPointer p = 
o.getSparseMatrixCudaPointer();
+                                       if (p == null)
+                                               throw new 
RuntimeException("CSRPointer is null in clearTemporaryMemory");
+                                       if (p.rowPtr != null) {
+                                               
nonTemporaryPointers.add(p.rowPtr);
+                                       }
+                                       if (p.colInd != null) {
+                                               
nonTemporaryPointers.add(p.colInd);
+                                       }
+                                       if (p.val != null) {
+                                               nonTemporaryPointers.add(p.val);
+                                       }
+
+                               } else {
+                                       Pointer p = o.getJcudaDenseMatrixPtr();
+                                       if (p == null)
+                                               throw new 
RuntimeException("Pointer is null in clearTemporaryMemory");
+                                       nonTemporaryPointers.add(p);
+                               }
+                       }
+               }
+               
+               return nonTemporaryPointers;
+       }
+       
+       /**
+        * Performs a non-in operation
+        * 
+        * @param superset superset of pointer
+        * @param subset subset of pointer
+        * @return pointers such that: superset - subset
+        */
+       private Set<Pointer> nonIn(Set<Pointer> superset, Set<Pointer> subset) {
+               Set<Pointer> ret = new HashSet<Pointer>();
+               for(Pointer superPtr : superset) {
+                       if(!subset.contains(superPtr)) {
+                               ret.add(superPtr);
+                       }
+               }
+               return ret;
+       }
+       
+       /**
+        * Clears up the memory used by non-dirty pointers.
+        */
+       public void clearTemporaryMemory() {
+               // To record the cuda block sizes needed by 
allocatedGPUObjects, others are cleared up.
+               Set<Pointer> temporaryPointers = 
nonIn(allocatedGPUPointers.keySet(), getDirtyPointers());
+               for(Pointer tmpPtr : temporaryPointers) {
+                       guardedCudaFree(tmpPtr);
+               }
+       }
+       
+       /**
+        * Convenient method to add misc timers
+        * 
+        * @param opcode opcode
+        * @param globalGPUTimer member of GPUStatistics
+        * @param globalGPUCounter member of GPUStatistics
+        * @param instructionLevelTimer member of GPUInstruction
+        * @param startTime start time
+        */
+       private void addMiscTime(String opcode, LongAdder globalGPUTimer, 
LongAdder globalGPUCounter, String instructionLevelTimer, long startTime) {
+               if(DMLScript.STATISTICS) {
+                       long totalTime = System.nanoTime() - startTime;
+                       globalGPUTimer.add(totalTime);
+                       globalGPUCounter.add(1);
+                       if (opcode != null && DMLScript.FINEGRAINED_STATISTICS)
+                               GPUStatistics.maintainCPMiscTimes(opcode, 
instructionLevelTimer, totalTime);
+               }
+       }
+       
+       /**
+        * Convenient method to add misc timers
+        * 
+        * @param opcode opcode
+        * @param instructionLevelTimer member of GPUInstruction
+        * @param startTime start time
+        */
+       private void addMiscTime(String opcode, String instructionLevelTimer, 
long startTime) {
+               if (opcode != null && DMLScript.FINEGRAINED_STATISTICS)
+                       GPUStatistics.maintainCPMiscTimes(opcode, 
instructionLevelTimer, System.nanoTime() - startTime);
+       }
+       
+       /**
+        * Get any pointer of the given size from rmvar-ed pointers (applicable 
if eager cudaFree is set to false)
+        * 
+        * @param opcode opcode
+        * @param size size in bytes
+        * @return pointer
+        */
+       private Pointer getRmvarPointer(String opcode, long size) {
+               if (rmvarGPUPointers.containsKey(size)) {
+                       if(LOG.isTraceEnabled())
+                               LOG.trace("Getting rmvar-ed pointers for size:" 
+ size);
+                       long t0 = opcode != null && 
DMLScript.FINEGRAINED_STATISTICS ?  System.nanoTime() : 0;
+                       Pointer A = remove(rmvarGPUPointers, size); // remove 
from rmvarGPUPointers as you are not calling cudaFree
+                       addMiscTime(opcode, GPUInstruction.MISC_TIMER_REUSE, 
t0);
+                       return A;
+               }
+               else {
+                       return null;
+               }
+       }
+       
+       /**
+        * Remove any pointer in the given hashmap
+        * 
+        * @param hm hashmap of size, pointers
+        * @param size size in bytes
+        * @return the pointer that was removed
+        */
+       private Pointer remove(HashMap<Long, Set<Pointer>> hm, long size) {
+               Pointer A = hm.get(size).iterator().next();
+               remove(hm, size, A);
+               return A;
+       }
+       
+       /**
+        * Remove a specific pointer in the given hashmap
+        * 
+        * @param hm hashmap of size, pointers
+        * @param size size in bytes
+        * @param ptr pointer to be removed
+        */
+       private void remove(HashMap<Long, Set<Pointer>> hm, long size, Pointer 
ptr) {
+               hm.get(size).remove(ptr);
+               if (hm.get(size).isEmpty())
+                       hm.remove(size);
+       }
+       
+       
+       /**
+        * Print debugging information
+        */
+       public String toString() {
+               long sizeOfLockedGPUObjects = 0; long numLockedGPUObjects = 0;
+               long sizeOfUnlockedGPUObjects = 0; long numUnlockedGPUObjects = 
0;
+               for(GPUObject gpuObj : allocatedGPUObjects) {
+                       try {
+                               if(gpuObj.isLocked()) {
+                                       numLockedGPUObjects++;
+                                       sizeOfLockedGPUObjects += 
gpuObj.getSizeOnDevice();
+                               }
+                               else {
+                                       numUnlockedGPUObjects++;
+                                       sizeOfUnlockedGPUObjects += 
gpuObj.getSizeOnDevice();
+                               }
+                       } catch (DMLRuntimeException e) {
+                               throw new RuntimeException(e);
+                       }
+               }
+               long totalMemoryAllocated = 0;
+               for(Long numBytes : allocatedGPUPointers.values()) {
+                       totalMemoryAllocated += numBytes;
+               }
+               return "Num of GPU objects: [unlocked:" + numUnlockedGPUObjects 
+ ", locked:" + numLockedGPUObjects + "]. "
+                               + "Size of GPU objects in bytes: [unlocked:" + 
sizeOfUnlockedGPUObjects + ", locked:" + sizeOfLockedGPUObjects + "]. "
+                               + "Total memory allocated by the current GPU 
context in bytes:" + totalMemoryAllocated;
+       }
+       
+       /**
+        * Gets the available memory on GPU that SystemML can use.
+        *
+        * @return the available memory in bytes
+        */
+       public long getAvailableMemory() {
+               long free[] = { 0 };
+               long total[] = { 0 };
+               cudaMemGetInfo(free, total);
+               return (long) (free[0] * GPU_MEMORY_UTILIZATION_FACTOR);
+       }
+       
	/**
	 * Class that governs the eviction policy.
	 * 
	 * The eviction loop sorts allocatedGPUObjects with this comparator and then
	 * evicts from the END of the list, stopping at the first locked object.
	 * Hence this comparator places locked objects first (never evicted) and
	 * orders unlocked objects so that the most desirable eviction victim under
	 * the configured policy ends up last.
	 */
	public static class GPUComparator implements Comparator<GPUObject> {
		// Requested allocation size in bytes; only used by the MIN_EVICT policy
		// to prefer evicting objects that are just large enough for the request.
		private long neededSize;
		public GPUComparator(long neededSize) {
			this.neededSize = neededSize;
		}
		@Override
		public int compare(GPUObject p1, GPUObject p2) {
			if (p1.isLocked() && p2.isLocked()) {
				// Both are locked, so don't sort
				return 0;
			} else if (p1.isLocked()) {
				// Put the unlocked one to RHS
				// a value less than 0 if x < y; and a value greater than 0 if x > y
				return -1;
			} else if (p2.isLocked()) {
				// Put the unlocked one to RHS
				// a value less than 0 if x < y; and a value greater than 0 if x > y
				return 1;
			} else {
				// Both are unlocked
				if (DMLScript.GPU_EVICTION_POLICY == DMLScript.EvictionPolicy.MIN_EVICT) {
					// MIN_EVICT: compare sizes relative to the requested size.
					// Objects at least as large as the request sort descending
					// (smallest sufficient object ends up last, i.e. evicted first);
					// too-small objects sort ascending so less-negative (closer to
					// sufficient) objects come later than very small ones.
					long p1Size = 0;
					long p2Size = 0;
					try {
						p1Size = p1.getSizeOnDevice() - neededSize;
						p2Size = p2.getSizeOnDevice() - neededSize;
					} catch (DMLRuntimeException e) {
						throw new RuntimeException(e);
					}

					if (p1Size >= 0 && p2Size >= 0) {
						return Long.compare(p2Size, p1Size);
					} else {
						return Long.compare(p1Size, p2Size);
					}
				} else {
					// LRU/LFU/MRU: timestamp is maintained by GPUObject per policy
					// (nanoTime for LRU, access count for LFU, negated nanoTime for MRU);
					// descending order puts the smallest timestamp (eviction victim) last.
					return Long.compare(p2.timestamp.get(), p1.timestamp.get());
				}
			}
		}
	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index c8e70bf..538103f 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -120,7 +120,6 @@ public class GPUObject {
                                long rows = me.mat.getNumRows();
                                long cols = me.mat.getNumColumns();
                                long size = rows * cols * 
LibMatrixCUDA.sizeOfDataType;
-                               me.gpuContext.ensureFreeSpace((int) size);
                                that.jcudaDenseMatrixPtr = allocate(size);
                                cudaMemcpy(that.jcudaDenseMatrixPtr, 
me.jcudaDenseMatrixPtr, size, cudaMemcpyDeviceToDevice);
                        }
@@ -205,7 +204,6 @@ public class GPUObject {
                Pointer nnzPerRowPtr = null;
                Pointer nnzTotalDevHostPtr = null;
 
-               gCtx.ensureFreeSpace(getIntSizeOf(rows + 1));
                nnzPerRowPtr = gCtx.allocate(getIntSizeOf(rows));
                nnzTotalDevHostPtr = gCtx.allocate(getIntSizeOf(1));
 
@@ -270,7 +268,6 @@ public class GPUObject {
                        cudaFreeHelper(getJcudaDenseMatrixPtr());
                        jcudaDenseMatrixPtr = null;
                }
-               getGPUContext().recordBlockUsage(this);
        }
 
        /**
@@ -285,11 +282,13 @@ public class GPUObject {
                }
                this.jcudaDenseMatrixPtr = densePtr;
                this.isSparse = false;
+               if(LOG.isDebugEnabled()) {
+                       LOG.debug("Setting dense pointer of size " + 
getGPUContext().getMemoryManager().getSizeAllocatedGPUPointer(densePtr));
+               }
                if (getJcudaSparseMatrixPtr() != null) {
                        getJcudaSparseMatrixPtr().deallocate();
                        jcudaSparseMatrixPtr = null;
                }
-               getGPUContext().recordBlockUsage(this);
        }
 
        /**
@@ -456,15 +455,6 @@ public class GPUObject {
                return eitherAllocated;
        }
 
-       public boolean isInputAllocated() {
-               boolean eitherAllocated = (getJcudaDenseMatrixPtr() != null || 
getJcudaSparseMatrixPtr() != null);
-               boolean isAllocatedOnThisGPUContext = 
getGPUContext().isBlockRecorded(this);
-               if (eitherAllocated && !isAllocatedOnThisGPUContext) {
-                       LOG.warn("GPU : A block was allocated but was not on 
this GPUContext, GPUContext=" + getGPUContext());
-               }
-               return eitherAllocated && isAllocatedOnThisGPUContext;
-       }
-
        /**
         * Allocates a sparse and empty {@link GPUObject}
         * This is the result of operations that are both non zero matrices.
@@ -543,7 +533,6 @@ public class GPUObject {
                                int cols = toIntExact(mat.getNumColumns());
                                Pointer nnzPerRowPtr = null;
                                Pointer nnzTotalDevHostPtr = null;
-                               gCtx.ensureFreeSpace(getIntSizeOf(rows + 1));
                                nnzPerRowPtr = 
gCtx.allocate(getIntSizeOf(rows));
                                nnzTotalDevHostPtr = 
gCtx.allocate(getIntSizeOf(1));
                                
LibMatrixCUDA.cudaSupportFunctions.cusparsennz(cusparseHandle, 
cusparseDirection.CUSPARSE_DIRECTION_ROW, rows, cols, matDescr, 
getJcudaDenseMatrixPtr(), rows,
@@ -696,18 +685,21 @@ public class GPUObject {
         * @throws DMLRuntimeException if there is no locked GPU Object or if 
could not obtain a {@link GPUContext}
         */
        private void updateReleaseLocks() throws DMLRuntimeException {
-               GPUContext.EvictionPolicy evictionPolicy = 
getGPUContext().evictionPolicy;
+               DMLScript.EvictionPolicy evictionPolicy = 
DMLScript.GPU_EVICTION_POLICY;
                switch (evictionPolicy) {
-               case LRU:
-                       timestamp.set(System.nanoTime());
-                       break;
-               case LFU:
-                       timestamp.addAndGet(1);
-                       break;
-               case MIN_EVICT: /* Do Nothing */
-                       break;
-               default:
-                       throw new CacheException("The eviction policy is not 
supported:" + evictionPolicy.name());
+                       case LRU:
+                               timestamp.set(System.nanoTime());
+                               break;
+                       case LFU:
+                               timestamp.addAndGet(1);
+                               break;
+                       case MIN_EVICT: /* Do Nothing */
+                               break;
+                       case MRU:
+                               timestamp.set(-System.nanoTime());
+                               break;
+                       default:
+                               throw new CacheException("The eviction policy 
is not supported:" + evictionPolicy.name());
                }
        }
 
@@ -731,6 +723,9 @@ public class GPUObject {
        public void releaseOutput() throws DMLRuntimeException {
                releaseWriteLock();
                updateReleaseLocks();
+               // Currently, there is no convenient way to acquireDeviceModify 
independently of dense/sparse format. 
+               // Hence, allowing resetting releaseOutput again. 
+               // Ideally, we would want to throw CacheException("Attempting 
to release an output that was not acquired via acquireDeviceModify") if 
!isDirty()
                dirty = true;
                if (!isAllocated())
                        throw new CacheException("Attempting to release an 
output before allocating it");
@@ -782,7 +777,6 @@ public class GPUObject {
                jcudaDenseMatrixPtr = null;
                jcudaSparseMatrixPtr = null;
                resetReadWriteLock();
-               getGPUContext().removeRecordedUsage(this);
        }
 
        protected long getSizeOnDevice() throws DMLRuntimeException {
@@ -938,7 +932,8 @@ public class GPUObject {
                        MatrixBlock tmp = new 
MatrixBlock(toIntExact(mat.getNumRows()), toIntExact(mat.getNumColumns()), 
false);
                        tmp.allocateDenseBlock();
                        
LibMatrixCUDA.cudaSupportFunctions.deviceToHost(getGPUContext(),
-                               getJcudaDenseMatrixPtr(), 
tmp.getDenseBlockValues(), instName, isEviction);
+                                               getJcudaDenseMatrixPtr(), 
tmp.getDenseBlockValues(), instName, isEviction);
+                       
                        tmp.recomputeNonZeros();
                        mat.acquireModify(tmp);
                        mat.release();
@@ -1008,7 +1003,7 @@ public class GPUObject {
         */
        public void clearData(boolean eager) throws DMLRuntimeException {
                deallocateMemoryOnDevice(eager);
-               getGPUContext().removeRecordedUsage(this);
+               getGPUContext().getMemoryManager().removeGPUObject(this);
        }
 
        /**
@@ -1046,6 +1041,10 @@ public class GPUObject {
                sb.append(", writeLock=").append(writeLock);
                sb.append(", sparse? ").append(isSparse);
                sb.append(", 
dims=[").append(mat.getNumRows()).append(",").append(mat.getNumColumns()).append("]");
+               if(jcudaDenseMatrixPtr != null)
+                       sb.append(", densePtr=").append(jcudaDenseMatrixPtr);
+               if(jcudaSparseMatrixPtr != null)
+                       sb.append(", sparsePtr=").append(jcudaSparseMatrixPtr);
                sb.append('}');
                return sb.toString();
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/runtime/matrix/data/DoublePrecisionCudaSupportFunctions.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/DoublePrecisionCudaSupportFunctions.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/DoublePrecisionCudaSupportFunctions.java
index fb70c13..f31806d 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/DoublePrecisionCudaSupportFunctions.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/DoublePrecisionCudaSupportFunctions.java
@@ -22,6 +22,8 @@ import static jcuda.runtime.JCuda.cudaMemcpy;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
@@ -40,6 +42,8 @@ import jcuda.jcusparse.cusparseMatDescr;
 
 public class DoublePrecisionCudaSupportFunctions implements 
CudaSupportFunctions {
 
+       private static final Log LOG = 
LogFactory.getLog(DoublePrecisionCudaSupportFunctions.class.getName());
+       
        @Override
        public int cusparsecsrgemm(cusparseHandle handle, int transA, int 
transB, int m, int n, int k,
                        cusparseMatDescr descrA, int nnzA, Pointer csrValA, 
Pointer csrRowPtrA, Pointer csrColIndA,
@@ -161,6 +165,13 @@ public class DoublePrecisionCudaSupportFunctions 
implements CudaSupportFunctions
        @Override
        public void deviceToHost(GPUContext gCtx, Pointer src, double[] dest, 
String instName, boolean isEviction) throws DMLRuntimeException {
                long t1 = DMLScript.FINEGRAINED_STATISTICS  && instName != 
null? System.nanoTime() : 0;
+               if(src == null)
+                       throw new DMLRuntimeException("The source pointer in 
deviceToHost is null");
+               if(dest == null)
+                       throw new DMLRuntimeException("The destination array in 
deviceToHost is null");
+               if(LOG.isDebugEnabled()) {
+                       LOG.debug("deviceToHost: src of size " + 
gCtx.getMemoryManager().getSizeAllocatedGPUPointer(src) + " (in bytes) -> dest 
of size " + (dest.length*Double.BYTES)  + " (in bytes).");
+               }
                cudaMemcpy(Pointer.to(dest), src, 
((long)dest.length)*Sizeof.DOUBLE, cudaMemcpyDeviceToHost);
                if(DMLScript.FINEGRAINED_STATISTICS && instName != null) 
                        GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_DEVICE_TO_HOST, System.nanoTime() - t1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
index 8050d1e..835cb15 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
@@ -95,8 +95,13 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements 
java.lang.AutoCloseab
                        cudnnDestroyFilterDescriptor(filterDesc);
                if(convDesc != null)
                        cudnnDestroyConvolutionDescriptor(convDesc);
-               if(sizeInBytes != 0)
-                       gCtx.cudaFreeHelper(instName, workSpace);
+               if(sizeInBytes != 0) {
+                       try {
+                               gCtx.cudaFreeHelper(instName, workSpace);
+                       } catch (DMLRuntimeException e) {
+                               throw new RuntimeException(e);
+                       }
+               }
                if(DMLScript.FINEGRAINED_STATISTICS)
                        GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
index 5a7cad3..33f2cb5 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.java
@@ -77,6 +77,10 @@ public class LibMatrixCuDNNInputRowFetcher extends 
LibMatrixCUDA implements java
         */
        @Override
        public void close() {
-               gCtx.cudaFreeHelper(outPointer, true);
+               try {
+                       gCtx.cudaFreeHelper(outPointer, true);
+               } catch (DMLRuntimeException e) {
+                       throw new RuntimeException(e);
+               }
        }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/utils/GPUStatistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/GPUStatistics.java 
b/src/main/java/org/apache/sysml/utils/GPUStatistics.java
index b7eb00c..d12f4dd 100644
--- a/src/main/java/org/apache/sysml/utils/GPUStatistics.java
+++ b/src/main/java/org/apache/sysml/utils/GPUStatistics.java
@@ -60,8 +60,6 @@ public class GPUStatistics {
        public static LongAdder cudaToDevCount = new LongAdder();
        public static LongAdder cudaFromDevCount = new LongAdder();
        public static LongAdder cudaEvictionCount = new LongAdder();
-       public static LongAdder cudaForcedClearLazyFreedMatCount = new 
LongAdder();
-       public static LongAdder cudaForcedClearUnpinnedMatCount = new 
LongAdder();
 
        // Per instruction miscellaneous timers.
        // Used to record events in a CP Heavy Hitter instruction and
@@ -97,8 +95,6 @@ public class GPUStatistics {
                cudaToDevCount.reset();
                cudaFromDevCount.reset();
                cudaEvictionCount.reset();
-               cudaForcedClearLazyFreedMatCount.reset();
-               cudaForcedClearUnpinnedMatCount.reset();
                resetMiscTimers();
        }
 
@@ -197,25 +193,21 @@ public class GPUStatistics {
                sb.append("CUDA/CuLibraries init time:\t" + 
String.format("%.3f", cudaInitTime*1e-9) + "/"
                                + String.format("%.3f", 
cudaLibrariesInitTime*1e-9) + " sec.\n");
                sb.append("Number of executed GPU inst:\t" + 
getNoOfExecutedGPUInst() + ".\n");
-               sb.append("GPU mem tx time  
(alloc/dealloc/set0/toDev/fromDev/evict/forcedEvict(lazy/unpinned)):\t"
+               sb.append("GPU mem tx time  
(alloc/dealloc/set0/toDev/fromDev/evict):\t"
                                + String.format("%.3f", 
cudaAllocTime.longValue()*1e-9) + "/"
                                + String.format("%.3f", 
cudaDeAllocTime.longValue()*1e-9) + "/"
                                + String.format("%.3f", 
cudaMemSet0Time.longValue()*1e-9) + "/"
                                + String.format("%.3f", 
cudaToDevTime.longValue()*1e-9) + "/"
                                + String.format("%.3f", 
cudaFromDevTime.longValue()*1e-9) + "/"
-                               + String.format("%.3f", 
cudaEvictTime.longValue()*1e-9) + "/("
-                               + String.format("%.3f", 
cudaForcedClearLazyFreedEvictTime.longValue()*1e-9) + "/" 
-                               + String.format("%.3f", 
cudaForcedClearUnpinnedEvictTime.longValue()*1e-9) + ") sec.\n");
-               sb.append("GPU mem tx count 
(alloc/dealloc/set0/toDev/fromDev/evict/forcedEvict(lazy/unpinned)):\t"
+                               + String.format("%.3f", 
cudaEvictTime.longValue()*1e-9) + " sec.\n");
+               sb.append("GPU mem tx count 
(alloc/dealloc/set0/toDev/fromDev/evict):\t"
                                + cudaAllocCount.longValue() + "/"
                                + cudaDeAllocCount.longValue() + "/"
                                + cudaMemSet0Count.longValue() + "/"
                                + cudaSparseConversionCount.longValue() + "/"
                                + cudaToDevCount.longValue() + "/"
                                + cudaFromDevCount.longValue() + "/"
-                               + cudaEvictionCount.longValue() + "/("
-                               + cudaForcedClearLazyFreedMatCount.longValue() 
+ "/"
-                               + cudaForcedClearUnpinnedMatCount.longValue() + 
").\n");
+                               + cudaEvictionCount.longValue() + ".\n");
                sb.append("GPU conversion time  
(sparseConv/sp2dense/dense2sp):\t"
                                + String.format("%.3f", 
cudaSparseConversionTime.longValue()*1e-9) + "/"
                                + String.format("%.3f", 
cudaSparseToDenseTime.longValue()*1e-9) + "/"

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/main/java/org/apache/sysml/utils/LRUCacheMap.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/LRUCacheMap.java 
b/src/main/java/org/apache/sysml/utils/LRUCacheMap.java
deleted file mode 100644
index 830b102..0000000
--- a/src/main/java/org/apache/sysml/utils/LRUCacheMap.java
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.sysml.utils;
-
-
-import org.apache.sysml.runtime.DMLRuntimeException;
-
-import java.util.LinkedHashMap;
-import java.util.Map;
-
-/**
- * An access ordered LRU Cache Map which conforms to the {@link Map} interface
- * while also providing the ability to get the least recently used entry
- * @param <K> the key type
- * @param <V> the value type
- */
-public class LRUCacheMap<K,V> extends LinkedHashMap<K,V> {
-
-       private static final long serialVersionUID = 7078404374799241418L;
-
-/**
-   * Creates an access-ordered {@link LRUCacheMap}
-   */
-  public LRUCacheMap() {
-    // An access-ordered LinkedHashMap is instantiated with the default 
initial capacity and load factors
-    super(16, 0.75f, true);
-  }
-
-  // Private variables to assist in capturing the lease recently used entry
-  private boolean evictLRU = false;
-  private Map.Entry<K,V> lastEvictedEntry = null;
-
-  /**
-   * Removes and gets the least recently used entry
-   * @return  the lease recently used entry
-   * @throws DMLRuntimeException if the internal state is somehow corrupted
-   */
-  public Map.Entry<K,V> removeAndGetLRUEntry() throws DMLRuntimeException {
-    lastEvictedEntry = null;
-    if (size() <= 0){
-      return null;
-    }
-
-    // The idea is to set removing the eldest entry to true and then putting 
in a dummy
-    // entry (null, null). the removeEldestEntry will capture the eldest entry 
and is available
-    // to return via a class member variable.
-    evictLRU = true;
-    V previous = super.put(null, null);
-    remove(null);
-    if (previous != null){
-      throw new DMLRuntimeException("ERROR : Internal state of LRUCacheMap 
invalid - a value for the key 'null' is already present");
-    }
-    evictLRU = false;
-    Map.Entry<K,V> toRet = lastEvictedEntry;
-    return toRet;
-  }
-
-  @Override
-  protected boolean removeEldestEntry(Map.Entry<K,V> eldest) {
-    if (evictLRU) {
-      lastEvictedEntry = eldest;
-      return true;
-    }
-    return false;
-  }
-
-  @Override
-  public V put (K k, V v){
-    if (k == null)
-      throw new IllegalArgumentException("ERROR: an entry with a null key was 
tried to be inserted in to the LRUCacheMap");
-    return super.put (k, v);
-  }
-
-
-}

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java 
b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
index 501e545..e006fd2 100644
--- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
@@ -29,6 +29,7 @@ import java.util.concurrent.locks.ReentrantLock;
 
 import org.apache.spark.sql.SparkSession;
 import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.api.mlcontext.ScriptFactory;
@@ -324,8 +325,9 @@ public abstract class GPUTests extends AutomatedTestBase {
                MLContext cpuMLC = new MLContext(spark);
                List<Object> outputs = new ArrayList<>();
                Script script = 
ScriptFactory.dmlFromString(scriptStr).in(inputs).out(outStrs);
+               MLResults res = cpuMLC.execute(script);
                for (String outStr : outStrs) {
-                       Object output = cpuMLC.execute(script).get(outStr);
+                       Object output = res.get(outStr);
                        outputs.add(output);
                }
                cpuMLC.close();
@@ -355,8 +357,9 @@ public abstract class GPUTests extends AutomatedTestBase {
                        gpuMLC.setStatistics(true);
                        List<Object> outputs = new ArrayList<>();
                        Script script = 
ScriptFactory.dmlFromString(scriptStr).in(inputs).out(outStrs);
+                       MLResults res = gpuMLC.execute(script);
                        for (String outStr : outStrs) {
-                               Object output = 
gpuMLC.execute(script).get(outStr);
+                               Object output = res.get(outStr);
                                outputs.add(output);
                        }
                        gpuMLC.close();

http://git-wip-us.apache.org/repos/asf/systemml/blob/bdf42c06/src/test/java/org/apache/sysml/test/unit/LRUCacheMapTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/unit/LRUCacheMapTest.java 
b/src/test/java/org/apache/sysml/test/unit/LRUCacheMapTest.java
deleted file mode 100644
index 09df5a0..0000000
--- a/src/test/java/org/apache/sysml/test/unit/LRUCacheMapTest.java
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.sysml.test.unit;
-
-import org.apache.sysml.utils.LRUCacheMap;
-import org.junit.Assert;
-import org.junit.Test;
-
-import java.util.Map;
-
-public class LRUCacheMapTest {
-
-  @Test
-  public void test1() throws Exception {
-    LRUCacheMap<String, Long> m = new LRUCacheMap<String, Long>();
-    m.put("k1", 10l);
-    m.put("k2", 20l);
-    m.put("k3", 30l);
-    m.put("k4", 40l);
-
-    Map.Entry<String, Long> e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k1", e.getKey());
-  }
-
-  @Test
-  public void test2() throws Exception {
-    LRUCacheMap<String, Long> m = new LRUCacheMap<String, Long>();
-    m.put("k1", 10l);
-    m.put("k2", 20l);
-    m.put("k3", 30l);
-    m.put("k4", 40l);
-    m.get("k1");
-
-    Map.Entry<String, Long> e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k2", e.getKey());
-  }
-
-  @Test(expected = IllegalArgumentException.class)
-  public void test3() {
-    LRUCacheMap<String, Long> m = new LRUCacheMap<String, Long>();
-    m.put(null, 10l);
-  }
-
-  @Test
-  public void test4() throws Exception {
-    LRUCacheMap<String, Long> m = new LRUCacheMap<String, Long>();
-    m.put("k1", 10l);
-    m.put("k2", 20l);
-    m.put("k3", 30l);
-    m.put("k4", 40l);
-    m.remove("k1");
-    m.remove("k2");
-
-    Map.Entry<String, Long> e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k3", e.getKey());
-  }
-
-  @Test
-  public void test5() throws Exception {
-    LRUCacheMap<String, Long> m = new LRUCacheMap<String, Long>();
-    m.put("k1", 10l);
-    m.put("k2", 20l);
-    m.put("k1", 30l);
-
-    Map.Entry<String, Long> e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k2", e.getKey());
-  }
-
-  @Test
-  public void test6() throws Exception {
-    LRUCacheMap<String, Long> m = new LRUCacheMap<String, Long>();
-    m.put("k1", 10l);
-    m.put("k2", 20l);
-    m.put("k3", 30l);
-    m.put("k4", 40l);
-    m.put("k5", 50l);
-    m.put("k6", 60l);
-    m.put("k7", 70l);
-    m.put("k8", 80l);
-    m.get("k4");
-
-
-    Map.Entry<String, Long> e;
-    e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k1", e.getKey());
-    e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k2", e.getKey());
-    e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k3", e.getKey());
-    e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k5", e.getKey());
-    e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k6", e.getKey());
-    e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k7", e.getKey());
-    e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k8", e.getKey());
-    e = m.removeAndGetLRUEntry();
-    Assert.assertEquals("k4", e.getKey());
-
-  }
-
-
-}
\ No newline at end of file

Reply via email to