SINGA-175 Add memory management APIs and implement a subclass using CNMeM

Add the base memory pool class. Implement two subclasses, CnMemPool and
CudaMemPool. Add tests for the memory pools.
TODO: replace Device* with std::shared_ptr<Device> to avoid memory errors,
because the destruction order of devices and tensors is dynamic (a device
may be freed before its tensors).

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/077d13e8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/077d13e8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/077d13e8

Branch: refs/heads/dev
Commit: 077d13e8052aa92679909b619966481a383a651f
Parents: ce3e6dc
Author: [email protected] <[email protected]>
Authored: Wed Jun 22 20:26:41 2016 +0800
Committer: [email protected] <[email protected]>
Committed: Wed Jun 22 20:26:41 2016 +0800

----------------------------------------------------------------------
 CMakeLists.txt                 |   6 +-
 include/singa/core/device.h    |   5 ++
 include/singa/core/memory.h    |  46 +++++++++++++++
 include/singa/model/loss.h     |   2 +-
 src/core/device/cuda_gpu.cc    |  59 +++++++++++++++++--
 src/core/memory/memory.cc     |  69 ++++++++++++++++++++++
 src/proto/core.proto           |  13 +++++
 test/singa/test_memory.cc      | 111 ++++++++++++++++++++++++++++++++++++
 test/singa/test_mse.cc         |  13 ++++-
 test/singa/test_tensor_math.cc |   4 ++
 10 files changed, 319 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/077d13e8/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/CMakeLists.txt b/CMakeLists.txt
index f6240d2..c34b6ce 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,14 +13,15 @@ ENDIF()
 #message(STATUS "${CMAKE_CXX_FLAGS}")
 
 SET(SINGA_INCLUDE_DIR
-    "${CMAKE_SOURCE_DIR}/include;${CMAKE_SOURCE_DIR}/lib/cnmem/lib;${PROJECT_BINARY_DIR}")
+    #"${CMAKE_SOURCE_DIR}/include;${CMAKE_SOURCE_DIR}/lib/cnmem/lib;${CMAKE_SOURCE_DIR}/lib/cnmen/include;${PROJECT_BINARY_DIR}")
+    "${CMAKE_SOURCE_DIR}/include;${CMAKE_SOURCE_DIR}/lib/cnmem/include;${PROJECT_BINARY_DIR}")
 #message(STATUS "include path: ${SINGA_INCLUDE_DIR}")
 INCLUDE_DIRECTORIES(${SINGA_INCLUDE_DIR})
 
 #OPTION(CPU_ONLY "use GPU libs" OFF)
 OPTION(USE_CBLAS "Use CBlas libs" ON)
 OPTION(USE_CUDA "Use Cuda libs" ON)
-OPTION(USE_CUDNN "Use Cudnn libs" ON)
+OPTION(USE_CUDNN "Use Cudnn libs" OFF)
 OPTION(USE_OPENCV "Use opencv" OFF)
 OPTION(USE_LMDB "Use LMDB libs" OFF)
 
@@ -38,5 +39,6 @@ SET(LIBRARY_OUTPUT_PATH ${PROJECT_BINARY_DIR}/lib)
 SET(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/bin)
 
 ADD_SUBDIRECTORY(lib/cnmem)
+LIST(APPEND SINGA_LINKER_LIBS cnmem)
 ADD_SUBDIRECTORY(src)
 ADD_SUBDIRECTORY(test)


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/077d13e8/include/singa/core/device.h
----------------------------------------------------------------------
diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index 8c95dc7..fc98a23 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -147,6 +147,8 @@ class CudaGPU : public Device {
   ~CudaGPU();
   CudaGPU(int id = 0, int num_executors = 1, string scheduler = "sync",
           string vm = "gc-only");
+  CudaGPU(const MemPoolConf& mem_conf,
+          int id = 0, int num_executors = 1, string scheduler = "sync");
 
   void SetRandSeed(unsigned seed) override;
   static void DeviceQuery();
@@ -180,6 +182,9 @@ class CudaGPU : public Device {
 
   /// Free cpu memory.
   void Free(void* ptr) override;
+
+ private:
+  DeviceMemPool* pool;
 };
 
 /// CudaCPU which uses cudaMallocHost to allocate pinned memory for host.
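[Editor's note] For orientation, the pool interface introduced in the memory.h
diff below is used roughly as follows. This is a minimal sketch, not part of
the patch; it assumes a CUDA-capable machine with the bundled cnmem library,
and the allocation size is illustrative:

```cpp
#include "singa/core/memory.h"

// Minimal usage sketch. InitPool() with no arguments applies the defaults
// from memory.cc: one device, a 1000000-byte arena, CNMEM_FLAGS_DEFAULT.
int main() {
  singa::CnMemPool pool;
  pool.InitPool();          // initializes the process-wide cnmem manager once

  void* ptr = nullptr;
  pool.Malloc(&ptr, 1024);  // device memory carved out of the cnmem arena
  pool.Free(ptr);           // returns the block to the pool, not to cudaFree
  return 0;                 // ~CnMemPool() calls cnmemFinalize()
}
```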
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/077d13e8/include/singa/core/memory.h
----------------------------------------------------------------------
diff --git a/include/singa/core/memory.h b/include/singa/core/memory.h
index db09043..e4e1e63 100644
--- a/include/singa/core/memory.h
+++ b/include/singa/core/memory.h
@@ -19,10 +19,56 @@
 #ifndef SINGA_CORE_MEMORY_H_
 #define SINGA_CORE_MEMORY_H_
 
+#include "cnmem.h"
+#include <mutex>
+
 namespace singa {
 
 /// Manage device memory pool including garbage collection, memory opt.
 class VirtualMemory {};
 
+class DeviceMemPool {
+ public:
+  virtual void InitPool() = 0;
+  virtual void Malloc(void** ptr, const size_t size) = 0;
+  virtual void Free(void* ptr) = 0;
+  virtual ~DeviceMemPool(){};
+};
+
+class CnMemPool : public DeviceMemPool {
+ public:
+  int status = 1;
+
+  void InitPool();
+
+  /// numDevices: total number of available GPU cards.
+  /// initSize: all devices will be allocated with this size
+  /// manager_flags: pool manager flag (one for all devices)
+  /// flag = 0: default flag
+  /// flag = 1: prevent the manager from growing its memory consumption
+  /// flag = 2: prevent the manager from stealing memory
+  void InitPool(int numDevices, size_t initSize, unsigned flag);
+
+  void Malloc(void** ptr, const size_t size);
+  void Free(void* ptr);
+
+  // release all memory and set the cnmem manager to uninitialized
+  ~CnMemPool();
+
+ private:
+  // whether the (global) memory pool has been initialized
+  static bool initialized;
+  // lock on the initialized variable
+  static std::mutex mtx;
+};
+
+class CudaMemPool : public DeviceMemPool {
+ public:
+  void InitPool(){};
+  void Malloc(void** ptr, const size_t size);
+  void Free(void* ptr);
+  ~CudaMemPool(){};
+};
+
 }  // namespace singa
 #endif  // SINGA_CORE_MEMORY_H_


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/077d13e8/include/singa/model/loss.h
----------------------------------------------------------------------
diff --git a/include/singa/model/loss.h b/include/singa/model/loss.h
index 6a23067..dcf0da4 100644
--- a/include/singa/model/loss.h
+++ b/include/singa/model/loss.h
@@ -35,7 +35,7 @@ class Loss {
     loss.ParseFromString(conf);
     Setup(loss);
   }
-
+  virtual ~Loss(){};
   /// Set meta fields from user configurations.
   virtual void Setup(const LossConf& conf) {}


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/077d13e8/src/core/device/cuda_gpu.cc
----------------------------------------------------------------------
diff --git a/src/core/device/cuda_gpu.cc b/src/core/device/cuda_gpu.cc
index a47f6fe..d9a0985 100644
--- a/src/core/device/cuda_gpu.cc
+++ b/src/core/device/cuda_gpu.cc
@@ -22,7 +22,7 @@
 #include <cuda_runtime.h>
 #include <curand.h>
 #include <chrono>
-
+#include <iostream>
 #include "singa/core/device.h"
 #include "singa/utils/cuda_utils.h"
 
 namespace singa {
@@ -42,6 +42,8 @@ CudaGPU::~CudaGPU() {
     CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
   }
 #endif
+  delete pool;
+  LOG(INFO) << "device has been deleted";
 }
 
 CudaGPU::CudaGPU(int id, int num_executors,
@@ -67,6 +69,48 @@ CudaGPU::CudaGPU(int id, int num_executors,
   auto status = cudnnCreate(&ctx_.cudnn_handle);
   CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
 #endif  // USE_CUDNN
+
+  // initialize cnmem memory management as default
+  pool = new CnMemPool();
+  ((CnMemPool*)pool)->InitPool();
+}
+
+CudaGPU::CudaGPU(const MemPoolConf& mem_conf, int id, int num_executors,
+                 string scheduler)
+    : Device(id, num_executors, scheduler, "gc-only") {
+  if (id == -1) id = FindDevice(0);
+  lang_ = kCuda;
+  ctx_.stream = NULL;  // use the default sync stream
+  // TODO(wangwei) create one handle for each stream?
+  CUDA_CHECK(cudaSetDevice(FindDevice(0)));
+  // use curandCreateGeneratorHost for CudaHost device
+  CURAND_CHECK(
+      curandCreateGenerator(&ctx_.curand_generator, CURAND_RNG_PSEUDO_DEFAULT));
+  auto seed = std::chrono::system_clock::now().time_since_epoch().count();
+  SetRandSeed(seed);
+  // TODO(wangwei) if one generator per stream, then need diff offset per gen?
+  CURAND_CHECK(curandSetGeneratorOffset(ctx_.curand_generator, 0));
+  CUBLAS_CHECK(cublasCreate(&(ctx_.cublas_handle)));
+
+#ifdef USE_CUDNN
+  // TODO(wangwei) create one handle for each stream?
+  auto status = cudnnCreate(&ctx_.cudnn_handle);
+  CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
+#endif  // USE_CUDNN
+
+  // initialize memory management for cuda devices
+  string memoryPoolType = mem_conf.type();
+  if (memoryPoolType.compare("cnmem") == 0) {
+    pool = new CnMemPool();
+    int num_devices = mem_conf.num_devices();
+    size_t alloc_size = mem_conf.alloc_size();
+    unsigned flag = mem_conf.cnmemflag();
+    ((CnMemPool*)pool)->InitPool(num_devices, alloc_size, flag);
+  } else {
+    pool = new CudaMemPool();
+  }
 }
 
 void CudaGPU::SetRandSeed(unsigned seed) {
@@ -90,7 +134,8 @@ void CudaGPU::CopyToFrom(void* dst, const void* src, size_t nBytes,
 void* CudaGPU::Malloc(int size) {
   void* ptr = nullptr;
   if (size > 0) {
-    CUDA_CHECK(cudaMalloc(&ptr, size));
+    // CUDA_CHECK(cudaMalloc((void**)&ptr, size));
+    pool->Malloc((void**)&ptr, size);
     CUDA_CHECK(cudaMemset(ptr, 0, size));
   }
   return ptr;
@@ -98,8 +143,14 @@ void* CudaGPU::Malloc(int size) {
 
 /// Free cpu memory.
 void CudaGPU::Free(void* ptr) {
-  if (ptr != nullptr)
-    CUDA_CHECK(cudaFree(ptr));
+  LOG(INFO) << "Cuda free is called";
+  LOG(INFO) << "pool pointer: " << pool;
+  LOG(INFO) << "pool status: " << ((CnMemPool*)pool)->status;
+  if (ptr != nullptr) {
+    // CUDA_CHECK(cudaFree(ptr));
+    pool->Free(ptr);
+  }
+  LOG(INFO) << "free memory succeeded";
 }


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/077d13e8/src/core/memory/memory.cc
----------------------------------------------------------------------
diff --git a/src/core/memory/memory.cc b/src/core/memory/memory.cc
index a1cf5db..c5878a6 100644
--- a/src/core/memory/memory.cc
+++ b/src/core/memory/memory.cc
@@ -18,3 +18,72 @@
 
 #include "singa/core/memory.h"
+#include "singa/utils/logging.h"
+#include <iostream>
+
+namespace singa {
+
+bool singa::CnMemPool::initialized = false;
+std::mutex singa::CnMemPool::mtx;
+
+void CnMemPool::InitPool(int numDevices, size_t initSize, unsigned flag) {
+  mtx.lock();
+  if (!initialized) {
+    CHECK_GE(numDevices, 1);
+    cnmemDevice_t* settingPtr = new cnmemDevice_t[numDevices];
+    for (int i = 0; i < numDevices; i++) {
+      settingPtr[i].device = i;
+      settingPtr[i].size = initSize;
+      settingPtr[i].numStreams = 0;
+      settingPtr[i].streams = NULL;
+      settingPtr[i].streamSizes = 0;
+    }
+    cnmemStatus_t status = cnmemInit(numDevices, settingPtr, flag);
+    CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
+        << " " << cnmemGetErrorString(status);
+    delete[] settingPtr;
+    initialized = true;
+  }
+  mtx.unlock();
+}
+
+void CnMemPool::InitPool() {
+  int defaultNumDevices = 1;
+  size_t defaultSize = 1000000U;
+  InitPool(defaultNumDevices, defaultSize,
+           cnmemManagerFlags_t::CNMEM_FLAGS_DEFAULT);
+}
+
+CnMemPool::~CnMemPool() {
+  mtx.lock();
+  if (initialized) {
+    cnmemStatus_t status = cnmemFinalize();
+    CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
+        << " " << cnmemGetErrorString(status);
+    initialized = false;
+  }
+  mtx.unlock();
+  LOG(INFO) << "cnmem has been freed";
+}
+
+void CnMemPool::Malloc(void** ptr, const size_t size) {
+  cnmemStatus_t status = cnmemMalloc(ptr, size, NULL);
+  CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
+      << " " << cnmemGetErrorString(status);
+}
+
+void CnMemPool::Free(void* ptr) {
+  LOG(INFO) << "cnmem free is called";
+  cnmemStatus_t status = cnmemFree(ptr, NULL);
+  CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
+      << " " << cnmemGetErrorString(status);
+  LOG(INFO) << "cnmem free is terminated";
+}
+
+void CudaMemPool::Malloc(void** ptr, const size_t size) {
+  cudaError_t status = cudaMalloc(ptr, size);
+  CHECK_EQ(status, cudaError_t::cudaSuccess);
+}
+
+void CudaMemPool::Free(void* ptr) {
+  cudaError_t status = cudaFree(ptr);
+  CHECK_EQ(status, cudaError_t::cudaSuccess);
+}
+
+}


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/077d13e8/src/proto/core.proto
----------------------------------------------------------------------
diff --git a/src/proto/core.proto b/src/proto/core.proto
index 88d7f12..cf6e193 100644
--- a/src/proto/core.proto
+++ b/src/proto/core.proto
@@ -44,3 +44,16 @@ enum CopyDirection {
   kDeviceToDevice = 3;
   kNumDirection = 4;
 }
+
+// configuration for device memory pool
+message MemPoolConf {
+  optional string type = 1 [default = "cnmem"];
+  optional uint32 num_devices = 2 [default = 1];
+  // allocation size for each device
+  optional uint32 alloc_size = 3 [default = 10000000];
+  // memory manager flag for cnmem
+  // cnmemflag = 0: default flag
+  // cnmemflag = 1: prevent the manager from growing its memory consumption
+  // cnmemflag = 2: prevent the manager from stealing memory
+  optional uint32 cnmemflag = 4 [default = 0];
+}


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/077d13e8/test/singa/test_memory.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_memory.cc b/test/singa/test_memory.cc
new file mode 100644
index 0000000..f5e464d
--- /dev/null
+++ b/test/singa/test_memory.cc
@@ -0,0 +1,111 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "gtest/gtest.h"
+#include "singa/utils/logging.h"
+#include "singa/core/memory.h"
+#include "singa/singa_config.h"
+#include <sys/time.h>
+
+#ifdef USE_CUDA
+TEST(CnmemPool, PoolInit) {
+  singa::CnMemPool pool;
+  pool.InitPool();
+}
+
+TEST(CnmemPool, PoolInitAll) {
+  singa::CnMemPool pool;
+  int nDevices;
+  cudaGetDeviceCount(&nDevices);
+  CHECK_GE(nDevices, 1);
+  pool.InitPool(nDevices, 1000000U, 0);
+}
+
+TEST(CnmemPool, UsePool) {
+  singa::CnMemPool pool;
+  pool.InitPool();
+  int numOfTests = 10;
+  int numOfWriteVsRead = 3;
+  int allocSize = 1000000U;
+  for (int i = 0; i < numOfTests; i++) {
+    int** memPtrs = new int*[numOfWriteVsRead];
+    for (int j = 0; j < numOfWriteVsRead; j++) {
+      pool.Malloc((void**)(&memPtrs[j]), allocSize);
+    }
+    pool.Free(memPtrs[0]);
+    delete[] memPtrs;
+  }
+}
+
+TEST(CudaMemPool, UsePool) {
+  singa::CudaMemPool pool;
+  int numOfTests = 10;
+  int numOfWriteVsRead = 3;
+  int allocSize = 1000000U;
+  for (int i = 0; i < numOfTests; i++) {
+    int** memPtrs = new int*[numOfWriteVsRead];
+    for (int j = 0; j < numOfWriteVsRead; j++) {
+      pool.Malloc((void**)(&memPtrs[j]), allocSize);
+    }
+    pool.Free(memPtrs[0]);
+    delete[] memPtrs;
+  }
+}
+
+TEST(MemPool, CompareCudaCnmem) {
+  singa::CudaMemPool cudaPool;
+  singa::CnMemPool cnPool;
+  cnPool.InitPool();
+
+  int numOfTests = 10000;
+  int allocSize = 1000000U;
+  struct timeval start, end;
+  double t1, t2;
+
+  singa::DeviceMemPool* pool = NULL;
+  pool = &cnPool;
+
+  gettimeofday(&start, NULL);
+  for (int i = 0; i < numOfTests; i++) {
+    int* memPtrs = NULL;
+    pool->Malloc((void**)&memPtrs, allocSize);
+    pool->Free(memPtrs);
+  }
+  gettimeofday(&end, NULL);
+
+  t1 = start.tv_sec * 1000 + start.tv_usec / 1000;
+  t2 = end.tv_sec * 1000 + end.tv_usec / 1000;
+  LOG(INFO) << "cnmem time: " << t2 - t1 << " ms" << std::endl;
+
+  pool = &cudaPool;
+  gettimeofday(&start, NULL);
+  for (int i = 0; i < numOfTests; i++) {
+    int* memPtrs = NULL;
+    pool->Malloc((void**)&memPtrs, allocSize);
+    pool->Free(memPtrs);
+  }
+  gettimeofday(&end, NULL);
+
+  t1 = start.tv_sec * 1000 + start.tv_usec / 1000;
+  t2 = end.tv_sec * 1000 + end.tv_usec / 1000;
+  LOG(INFO) << "cuda time: " << t2 - t1 << " ms" << std::endl;
+}
+#endif // USE_CUDA


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/077d13e8/test/singa/test_mse.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_mse.cc b/test/singa/test_mse.cc
index 2c02273..7c6066e 100644
--- a/test/singa/test_mse.cc
+++ b/test/singa/test_mse.cc
@@ -68,11 +68,11 @@ TEST_F(TestMSE, CppBackward) {
 #endif
 #ifdef USE_CUDA
 TEST_F(TestMSE, CudaForward) {
-  singa::MSE mse;
+  singa::MSE* mse = new singa::MSE();
   singa::CudaGPU dev;
   p.ToDevice(&dev);
   t.ToDevice(&dev);
-  Tensor loss = mse.Forward(p, t);
+  Tensor loss = mse->Forward(p, t);
 
   loss.ToHost();
   auto ldat = loss.data<const float*>();
@@ -85,6 +85,12 @@ TEST_F(TestMSE, CudaForward) {
     }
     EXPECT_FLOAT_EQ(ldat[i], 0.5 * l);
   }
+  LOG(INFO) << "before moving p back to host";
+  p.ToHost();
+  LOG(INFO) << "before moving t back to host";
+  t.ToHost();
+  LOG(INFO) << "CudaForward terminating";
+  delete mse;
 }
 TEST_F(TestMSE, CudaBackward) {
   singa::MSE mse;
@@ -98,5 +104,8 @@ TEST_F(TestMSE, CudaBackward) {
   for (size_t i = 0; i < grad.Size(); i++)
     EXPECT_FLOAT_EQ(gdat[i], (1.0f / p.shape().at(0)) * (pdat[i] - tdat[i]));
+  p.ToHost();
+  t.ToHost();
+
 }
 #endif


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/077d13e8/test/singa/test_tensor_math.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc
index 170b96c..b18e465 100644
--- a/test/singa/test_tensor_math.cc
+++ b/test/singa/test_tensor_math.cc
@@ -302,6 +302,8 @@ TEST_F(TestTensorMath, MultCuda) {
       EXPECT_FLOAT_EQ(oPtr[i * 4 + j], x[i]);
     }
   }
+  d.ToHost();
+  p.ToHost();
 }
 
 TEST_F(TestTensorMath, AddColumnCuda) {
@@ -479,6 +481,7 @@ TEST_F(TestTensorMath, SumRowsCuda) {
     }
     EXPECT_FLOAT_EQ(tptr[i], tmp);
   }
+  d.ToHost();
 }
 TEST_F(TestTensorMath, SumColumnCuda) {
   singa::CudaGPU dev;
@@ -495,5 +498,6 @@ TEST_F(TestTensorMath, SumColumnCuda) {
     }
     EXPECT_FLOAT_EQ(tptr[i], tmp);
  }
+  d.ToHost();
 }
 #endif
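[Editor's note] As a usage note, the new MemPoolConf message and the
MemPoolConf-taking CudaGPU constructor above combine roughly as follows.
This is a minimal sketch under stated assumptions: the include path of the
header generated from core.proto is hypothetical, Malloc/Free are assumed
accessible to the caller, and the configuration values are illustrative:

```cpp
#include "singa/core/device.h"
#include "core.pb.h"  // hypothetical path to the protobuf code generated from core.proto

int main() {
  // Configure the pool; any type other than "cnmem" selects CudaMemPool,
  // which forwards straight to cudaMalloc/cudaFree.
  singa::MemPoolConf conf;
  conf.set_type("cnmem");
  conf.set_num_devices(1);        // cnmem pools are created for devices 0..n-1
  conf.set_alloc_size(10000000);  // initial arena per device, in bytes
  conf.set_cnmemflag(0);          // 0 = CNMEM_FLAGS_DEFAULT

  // Device allocations now route through the configured pool.
  singa::CudaGPU dev(conf);
  void* ptr = dev.Malloc(1024);   // pool->Malloc(...) followed by cudaMemset
  dev.Free(ptr);
  return 0;
}
```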
