SINGA-171 - Create CppDevice and CudaDevice

Add the CppDevice and CudaDevice APIs. Implement CppDevice and add a test for it. There is a link error for cudnn.
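For readers skimming the diff, a minimal sketch of the new device API in use, condensed from test/singa/test_cpp_device.cc below; the main() wrapper and literal values are illustrative assumptions, not part of the commit:

#include "singa/core/device.h"
#include "singa/proto/core.pb.h"

int main() {
  // A CppDevice with id 0 and a single executor; it runs operations
  // synchronously and seeds its std::mt19937 generator via SetRandSeed().
  singa::CppDevice dev(0, 1);
  dev.SetRandSeed(42);

  // Allocate a 4-byte Blob on the device and fill it from a host pointer.
  singa::Blob* b = dev.NewBlob(4);
  char src[] = {'a', 'b', 'c', 'd'};
  dev.CopyDataFromHostPtr(b, src, 4 /*nBytes*/, 0 /*dst_offset*/);

  // Blob-to-Blob copies now take an explicit CopyDirection (see core.proto).
  singa::Blob* c = dev.NewBlob(4);
  dev.CopyDataToFrom(c, b, 4, singa::kHostToHost, 0 /*dst_offset*/, 0 /*src_offset*/);

  dev.FreeBlob(c);
  dev.FreeBlob(b);
  return 0;
}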
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/282712ca Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/282712ca Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/282712ca Branch: refs/heads/dev Commit: 282712caf1582bdc4e23d89fcc14d27eb0c7ad8e Parents: b491875 Author: Wei Wang <[email protected]> Authored: Tue May 17 17:24:40 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Thu May 19 14:01:24 2016 +0800 ---------------------------------------------------------------------- include/singa/core/common.h | 5 +- include/singa/core/device.h | 108 +++++++++++++++++----- include/singa/core/tensor.h | 35 +------ include/singa/utils/cuda.h | 94 +++++++++++++++++++ src/core/device/cpp_device.cc | 19 +++- src/core/device/cuda_device.cc | 132 +++++++++++++++++++++++++++ src/core/device/device.cc | 43 +++++---- src/core/tensor/tensor.cc | 176 ++++++++++++++---------------------- src/proto/core.proto | 13 ++- test/singa/test_cpp_device.cc | 71 +++++++++++++++ test/singa/test_tensor_math.cc | 16 +--- 11 files changed, 509 insertions(+), 203 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/include/singa/core/common.h ---------------------------------------------------------------------- diff --git a/include/singa/core/common.h b/include/singa/core/common.h index 2f5b167..0fa301a 100644 --- a/include/singa/core/common.h +++ b/include/singa/core/common.h @@ -25,6 +25,7 @@ #ifdef USE_CUDA #include <cuda_runtime.h> #include <cublas_v2.h> +#include <curand.h> #ifdef USE_CUDNN #include <cudnn.h> #endif @@ -36,8 +37,6 @@ namespace lib { typedef struct _Cpp { } Cpp; /// To implemente functions using cuda libraries typedef struct _Cuda { } Cuda; -/// To implement function using cudnn -typedef struct _Cudnn { } Cudnn; /// To implement function using opencl libraries typedef struct _Opencl { } Opencl; } // namespace lib @@ -69,10 +68,10 @@ class Blob { typedef struct _Context { std::mt19937 random_generator; - unsigned long long seed; #ifdef USE_CUDA cublasHandle_t cublas_handle; cudaStream_t stream; + curandGenerator_t curand_generator; #ifdef USE_CUDNN cudnnHandle_t cudnn_handle; #endif http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/include/singa/core/device.h ---------------------------------------------------------------------- diff --git a/include/singa/core/device.h b/include/singa/core/device.h index 9022041..29b7677 100644 --- a/include/singa/core/device.h +++ b/include/singa/core/device.h @@ -48,6 +48,8 @@ class CallbackArg { typedef function<void(CallbackArg*)> CallbackFn; /// Allocate memory and execute Tensor operations. +/// There are three types of devices distinguished by their programming +/// languages, namely cpp, cuda and opencl. class Device { public: /// Operation has a function, and read/write blobs. @@ -63,8 +65,8 @@ class Device { /// max mem size to use (in MB), identifier of scheduler type (default /// scheduler run operations synchronously) and virtual memory type (default /// vm only provides garbage collection). - Device(int id, int num_executors = 16, string scheduler = "sync", - string vm = "gc-only"); + Device(int id, int num_executors, string scheduler, string vm); + virtual void SetRandSeed(unsigned seed) = 0; /// Called by Tensor. 
Blob* NewBlob(int size); @@ -73,14 +75,16 @@ class Device { void FreeBlob(Blob* blob); /// Copy data within or across devices. - void CopyData(Blob* dst, const Blob& src, int len, int dst_offset, - int src_offset); + void CopyDataToFrom(Blob* dst, Blob* src, size_t nBytes, + CopyDirection direction, int dst_offset, int src_offset); - void CopyDataFromHostPtr(Blob* dst, const void* src, size_t size); + void CopyDataFromHostPtr(Blob* dst, const void* src, size_t nBytes, + size_t dst_offset = 0); /// Submit the operation to the device, which may execute it right now or /// delay it depending on the scheduler. - void Exec(function<void(Context*)> fn, const vector<Blob*> read_blobs, - const vector<Blob*> write_blobs, bool use_rand_generator = false); + void Exec(function<void(Context*)>&& fn, const vector<Blob*> read_blobs, + const vector<Blob*> write_blobs, + bool use_rand_generator = false); // Wait for one event. // void WaitFor(); @@ -88,14 +92,19 @@ class Device { /// wait for all operations submitted to this device. void Sync(); - LibType device_lib() const { return device_lib_; } - LibType nn_lib() const { return nn_lib_; } + DeviceType type() const { + return device_type_; + } Device* host() const { return host_; } + int id() const { return id_; } protected: /// Execute one operation on one executor. - virtual void Exec(int operation, int executor) = 0; + virtual void DoExec(function<void(Context*)>&& fn, int executor) = 0; + + virtual void CopyToFrom(void* dst, const void* src, size_t nBytes, + CopyDirection direction, Context* ctx) = 0; /// Allocate device memory. virtual void* Malloc(int size) = 0; @@ -105,31 +114,39 @@ class Device { protected: int id_ = 0; - Scheduler* scheduler_ = nullptr; - VirtualMemory* vm_ = nullptr; - /// could be kCudnn - LibType nn_lib_; + int num_executors_ = 0; + unsigned seed_ = 0; + // Scheduler* scheduler_ = nullptr; + // VirtualMemory* vm_ = nullptr; /// could be kCpp, kCuda, kOpencl - LibType device_lib_; + DeviceType device_type_; // SafeQueue<Operation> op_queue_; // SafeQueue<Operation> op_log_; /// The host device - Context ctx_; Device* host_; }; -// Implement Device using Cpp libs. + +// Implement Device functions using cpp. class CppDevice : public Device { public: - CppDevice(int id, int num_executors); - - void Exec(int operation, int executor) override; + CppDevice(int id, int num_executors = 1, + string scheduler = "sync", string vm = "gc-only"); + void SetRandSeed(unsigned seed) override; protected: + void DoExec(function<void(Context*)>&& fn, int executor) override; + + void CopyToFrom(void* dst, const void* src, size_t nBytes, + CopyDirection direction, Context* ctx) override; + /// Allocate cpu memory. void* Malloc(int size) override; /// Free cpu memory. void Free(void* ptr) override; + + protected: + Context ctx_; }; /// a singleton CppDevice as the host for all devices. @@ -138,9 +155,56 @@ extern CppDevice hostDeviceSingleton; // Implement Device using OpenCL libs. // class OpenclDevice : public Device { }; -// Implement Device using Cuda libs for Nvidia GPUs. -// class CudaDevice : public Device { }; +#ifdef USE_CUDA +// Implement Device using cuda. +class CudaDevice : public Device { + public: + ~CudaDevice(); + CudaDevice(int id, int num_executors = 1, string scheduler = "sync", + string vm = "gc-only"); + + void SetRandSeed(unsigned seed) override; + static void DeviceQuery(); + /// This function checks the availability of GPU #device_id. + /// It attempts to create a context on the device by calling cudaFree(0). 
+ /// cudaSetDevice() alone is not sufficient to check the availability. + /// It lazily records device_id, however, does not initialize a + /// context. So it does not know if the host thread has the permission to use + /// the device or not. + /// + /// In a shared environment where the devices are set to EXCLUSIVE_PROCESS + /// or EXCLUSIVE_THREAD mode, cudaSetDevice() returns cudaSuccess + /// even if the device is exclusively occupied by another process or thread. + /// Cuda operations that initialize the context are needed to check + /// the permission. cudaFree(0) is one of those with no side effect, + /// except the context initialization. + static bool CheckDevice(const int device_id); + /// This function finds the first available device by checking devices with + /// ordinal from start_id to the highest available value. In the + /// EXCLUSIVE_PROCESS or EXCLUSIVE_THREAD mode, if it succeeds, it also + /// claims the device due to the initialization of the context. + static int FindDevice(const int start_id); + protected: + void DoExec(function<void(Context*)>&& fn, int executor) override; + + void CopyToFrom(void* dst, const void* src, size_t nBytes, + CopyDirection direction, Context* ctx) override; + + /// Allocate cpu memory. + void* Malloc(int size) override; + + /// Free cpu memory. + void Free(void* ptr) override; + + protected: + Context ctx_; +}; + +#endif // USE_CUDA +// Implement a CudaHost device, which used cuda functions for memory +// malloc/free. +// class CudaHost : public Device {} } // namespace singa #endif // SINGA_CORE_DEVICE_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/include/singa/core/tensor.h ---------------------------------------------------------------------- diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index 88a895b..03bf443 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -208,20 +208,12 @@ class Tensor { /// Copy 'num' elements of src to dst. /// The first 'src_offset' ('dst_offset') elements will be skipped. -void CopyData(Tensor* dst, +void CopyDataToFrom(Tensor* dst, const Tensor& src, size_t num, size_t src_offset = 0, size_t dst_offset = 0); -/// Copy 'nBytes' bytes of src data to dst. -/// The first 'src_offset' ('dst_offset') bytes will be skipped. -void CopyRawData(Tensor* dst, - const Tensor& src, - size_t nBytes, - size_t src_offset = 0, - size_t dst_offset = 0); - // ==================Simple Linear Algebra Operations========================= Tensor Abs(const Tensor& t); Tensor Exp(const Tensor& t); @@ -279,6 +271,8 @@ template <typename DType> void Div(const Tensor& t, DType x, Tensor* ret); // ================Blas operations============================================ +// We fix the scalar argument type to be float. + // ===== Level 1 // TODO(wangwei) make amax/amin/asum a member function of tensor // void Amax(Tensor, Context* ctx); Get the index of the max value in a vector @@ -289,25 +283,19 @@ void Div(const Tensor& t, DType x, Tensor* ret); /// Do matrix vector multipication or matrix matrix multiplication depdending /// on the Tensor shape. ret = lhs * rhs -template <typename DType> Tensor Mult(const Tensor& lhs, const Tensor& rhs); /// Do matrix vector multipication or matrix matrix multiplication depdending /// on the Tensor shape. ret = lhs * rhs -template <typename DType> void Mult(const Tensor& lhs, const Tensor& rhs, Tensor* ret); /// Do matrix vector multipication or matrix matrix multiplication depdending /// on the Tensor shape. 
ret = alpha lhs * rhs + beta * ret -template <typename DType> -Tensor Mult(DType alpha, const Tensor& lhs, DType beta, const Tensor& rhs); +Tensor Mult(float alpha, const Tensor& lhs, float beta, const Tensor& rhs); /// Do matrix vector multipication or matrix matrix multiplication depdending /// on the Tensor shape. ret = alpha lhs * rhs + beta * ret -template <typename DType> -void Mult(DType alpha, const Tensor& lhs, DType beta, const Tensor& rhs, +void Mult(float alpha, const Tensor& lhs, float beta, const Tensor& rhs, Tensor* C); -// tempalte<typename DType> T Dot(const Tensor& lhs, const Tensor& rhs); - // ================Random operations========================================== /// For each element x set x = 1 if random() < p; otherwise x = 1. void Bernoulli(float p, Tensor* t); @@ -316,19 +304,6 @@ void Uniform(float low, float high, Tensor* t); /// Fill in Tensor 't' following Gaussian distribution. void Gaussian(float mean, float std, Tensor* t); -// ================Neural Net operations====================================== -/* following API of cudnn, e.g., conv, pool, lrn, batchnorm, softmax -void ConvFwd(const ConvConf& conf, const Tensor& x, const Tensor& w, Tensor* y); -void ConvBwdBias(const ConvConf& conf, const Tensor& dy, Tensor* db); -void ConvBwdFilter(const ConvConf& conf, const Tensor& dy, const Tensor& x, - Tensor* dw); -void ConvBwdData(const ConvConf& conf, const Tensor& dy, const Tensor& w, - Tensor* db); -void PoolFwd(const PoolConf& conf, const Tensor& x, Tensor* y, - Tensor* mask = nullptr); -void PoolBwd(const PoolConf& conf, const Tensor& y, const Tensor& dy, - const Tensor& x, Tensor* dx); -*/ } // namespace singa #endif // SINGA_CORE_TENSOR_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/include/singa/utils/cuda.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/cuda.h b/include/singa/utils/cuda.h new file mode 100644 index 0000000..b2bb5c5 --- /dev/null +++ b/include/singa/utils/cuda.h @@ -0,0 +1,94 @@ +// from caffe include/caffe/util/device_alternative.hpp + +#include <cublas_v2.h> +#include <cuda.h> +#include <cuda_runtime.h> + +// +// CUDA macros +// + +// CUDA: various checks for different function calls. 
+#define CUDA_CHECK(condition) \ + /* Code block avoids redefinition of cudaError_t error */ \ + do { \ + cudaError_t error = condition; \ + CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ + } while (0) + +#define CUBLAS_CHECK(condition) \ + do { \ + cublasStatus_t status = condition; \ + CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << " " \ + << cublasGetErrorString(status); \ + } while (0) + +#define CURAND_CHECK(condition) \ + do { \ + curandStatus_t status = condition; \ + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << " " \ + << curandGetErrorString(status); \ + } while (0) + +const char* cublasGetErrorString(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; +#if CUDA_VERSION >= 6000 + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; +#endif +#if CUDA_VERSION >= 6050 + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; +#endif + } + return "Unknown cublas status"; +} + +const char* curandGetErrorString(curandStatus_t error) { + switch (error) { + case CURAND_STATUS_SUCCESS: + return "CURAND_STATUS_SUCCESS"; + case CURAND_STATUS_VERSION_MISMATCH: + return "CURAND_STATUS_VERSION_MISMATCH"; + case CURAND_STATUS_NOT_INITIALIZED: + return "CURAND_STATUS_NOT_INITIALIZED"; + case CURAND_STATUS_ALLOCATION_FAILED: + return "CURAND_STATUS_ALLOCATION_FAILED"; + case CURAND_STATUS_TYPE_ERROR: + return "CURAND_STATUS_TYPE_ERROR"; + case CURAND_STATUS_OUT_OF_RANGE: + return "CURAND_STATUS_OUT_OF_RANGE"; + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + case CURAND_STATUS_LAUNCH_FAILURE: + return "CURAND_STATUS_LAUNCH_FAILURE"; + case CURAND_STATUS_PREEXISTING_FAILURE: + return "CURAND_STATUS_PREEXISTING_FAILURE"; + case CURAND_STATUS_INITIALIZATION_FAILED: + return "CURAND_STATUS_INITIALIZATION_FAILED"; + case CURAND_STATUS_ARCH_MISMATCH: + return "CURAND_STATUS_ARCH_MISMATCH"; + case CURAND_STATUS_INTERNAL_ERROR: + return "CURAND_STATUS_INTERNAL_ERROR"; + } + return "Unknown curand status"; +} + http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/src/core/device/cpp_device.cc ---------------------------------------------------------------------- diff --git a/src/core/device/cpp_device.cc b/src/core/device/cpp_device.cc index 42f95c3..d0e051e 100644 --- a/src/core/device/cpp_device.cc +++ b/src/core/device/cpp_device.cc @@ -18,13 +18,18 @@ #include "singa/core/device.h" namespace singa { CppDevice hostDeviceSingleton(-1, 1); -CppDevice::CppDevice(int id, int num_executors) { - nn_lib_ = kCpp; - device_lib_ = kCpp; - host_ = &hostDeviceSingleton; +CppDevice::CppDevice(int id, int num_executors, string scheduler, + string vm) : Device(id, num_executors, scheduler, vm) { + device_type_ = kCpp; + host_ = nullptr; } -void CppDevice::Exec(int operation, int executor) { +void 
CppDevice::SetRandSeed(unsigned seed) { + ctx_.random_generator.seed(seed); +} +void CppDevice::DoExec(function<void(Context*)>&& fn, int executor) { + CHECK_EQ(executor, 0); + fn(&ctx_); } void* CppDevice::Malloc(int size) { @@ -35,4 +40,8 @@ void CppDevice::Free(void* ptr) { free(ptr); } +void CppDevice::CopyToFrom(void* dst, const void* src, size_t nBytes, + CopyDirection direction, Context* ctx) { + memcpy(dst, src, nBytes); +} } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/src/core/device/cuda_device.cc ---------------------------------------------------------------------- diff --git a/src/core/device/cuda_device.cc b/src/core/device/cuda_device.cc index 76c646e..1f6de60 100644 --- a/src/core/device/cuda_device.cc +++ b/src/core/device/cuda_device.cc @@ -15,10 +15,142 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifdef USE_CUDA +#include <chrono> +#include <cublas_v2.h> +#include <cuda.h> +#include <cuda_runtime.h> +#include <curand.h> + #include "singa/core/device.h" +#include "singa/utils/cuda.h" namespace singa { +const cudaMemcpyKind copyKind[] = {cudaMemcpyHostToHost, cudaMemcpyHostToDevice, + cudaMemcpyDeviceToHost, + cudaMemcpyDeviceToDevice}; + +CudaDevice::~CudaDevice() { + if (ctx_.cublas_handle) + CUBLAS_CHECK(cublasDestroy(ctx_.cublas_handle)); + if (ctx_.curand_generator) + CURAND_CHECK(curandDestroyGenerator(ctx_.curand_generator)); +#ifdef USE_CUDNN + if (ctx_.cudnn_handle) { + auto status = cudnnDestroy(ctx_.cudnn_handle); + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status); + } +#endif +} + +CudaDevice::CudaDevice(int id, int num_executors, + string scheduler, string vm) + : Device(id, num_executors, scheduler, vm) { + device_type_ = kCuda; + host_ = nullptr; // TODO(wangwei) add host device + ctx_.stream = NULL; // use the default sync stream + // TODO(wangwei) create one handle for each steam? + CUBLAS_CHECK(cublasCreate(&ctx_.cublas_handle)); + // use curandCreateGeneratorHost for CudaHost device + CURAND_CHECK( + curandCreateGenerator(&ctx_.curand_generator, CURAND_RNG_PSEUDO_DEFAULT)); + auto seed = std::chrono::system_clock::now().time_since_epoch().count(); + SetRandSeed(seed); + // TODO(wangwei) if one generator per stream, then need diff offset per gen? + CURAND_CHECK(curandSetGeneratorOffset(ctx_.curand_generator, 0)); + +#ifdef USE_CUDNN + // TODO(wangwei) create one handle for each stream? + auto status = cudnnCreate(&ctx_.cudnn_handle); + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status); +#endif // USE_CUDNN +} + +void CudaDevice::SetRandSeed(unsigned seed) { + CHECK(ctx_.curand_generator); + CURAND_CHECK( + curandSetPseudoRandomGeneratorSeed(ctx_.curand_generator, seed)); +} + +void CudaDevice::DoExec(function<void(Context*)>&& fn, int executor) { + fn(&ctx_); +} + +void CudaDevice::CopyToFrom(void* dst, const void* src, size_t nBytes, + CopyDirection direction, Context* ctx) { + cudaMemcpy(dst, src, nBytes, copyKind[direction]); + // TODO(wangwei) use async copy + // cudaMemcpyAsync(dst, src, nBytes,cudaMemcpyDefault, ctx_.stream); +} +/// Allocate cpu memory. +void* CudaDevice::Malloc(int size) { + void* ptr = nullptr; + cudaMalloc(&ptr, size); + return ptr; +} + + /// Free cpu memory. 
+void CudaDevice::Free(void* ptr) { + CHECK_NE(ptr, nullptr); + cudaFree(ptr); +} + + +// ==========Following code is from Caffe src/caffe/common.cpp================= + +void CudaDevice::DeviceQuery() { + cudaDeviceProp prop; + int device; + if (cudaSuccess != cudaGetDevice(&device)) { + printf("No cuda device present.\n"); + return; + } + CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); + LOG(INFO) << "Device id: " << device; + LOG(INFO) << "Major revision number: " << prop.major; + LOG(INFO) << "Minor revision number: " << prop.minor; + LOG(INFO) << "Name: " << prop.name; + LOG(INFO) << "Total global memory: " << prop.totalGlobalMem; + LOG(INFO) << "Total shared memory per block: " << prop.sharedMemPerBlock; + LOG(INFO) << "Total registers per block: " << prop.regsPerBlock; + LOG(INFO) << "Warp size: " << prop.warpSize; + LOG(INFO) << "Maximum memory pitch: " << prop.memPitch; + LOG(INFO) << "Maximum threads per block: " << prop.maxThreadsPerBlock; + LOG(INFO) << "Maximum dimension of block: " + << prop.maxThreadsDim[0] << ", " << prop.maxThreadsDim[1] << ", " + << prop.maxThreadsDim[2]; + LOG(INFO) << "Maximum dimension of grid: " + << prop.maxGridSize[0] << ", " << prop.maxGridSize[1] << ", " + << prop.maxGridSize[2]; + LOG(INFO) << "Clock rate: " << prop.clockRate; + LOG(INFO) << "Total constant memory: " << prop.totalConstMem; + LOG(INFO) << "Texture alignment: " << prop.textureAlignment; + LOG(INFO) << "Concurrent copy and execution: " + << (prop.deviceOverlap ? "Yes" : "No"); + LOG(INFO) << "Number of multiprocessors: " << prop.multiProcessorCount; + LOG(INFO) << "Kernel execution timeout: " + << (prop.kernelExecTimeoutEnabled ? "Yes" : "No"); + return; +} + +bool CudaDevice::CheckDevice(const int device_id) { + bool r = ((cudaSuccess == cudaSetDevice(device_id)) && + (cudaSuccess == cudaFree(0))); + // reset any error that may have occurred. + cudaGetLastError(); + return r; +} + +int CudaDevice::FindDevice(const int start_id) { + int count = 0; + CUDA_CHECK(cudaGetDeviceCount(&count)); + for (int i = start_id; i < count; i++) { + if (CheckDevice(i)) return i; + } + return -1; +} } +#endif // USE_CUDA http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/src/core/device/device.cc ---------------------------------------------------------------------- diff --git a/src/core/device/device.cc b/src/core/device/device.cc index 33f5bd8..153637c 100644 --- a/src/core/device/device.cc +++ b/src/core/device/device.cc @@ -20,44 +20,53 @@ namespace singa { Device::Device(int id, int num_executors, string scheduler, string vm) - : id_(id) { - scheduler_ = nullptr; - vm_ = nullptr; - ctx_.seed = 0; - ctx_.random_generator = std::mt19937(ctx_.seed); + : id_(id), num_executors_(num_executors) { + // TODO(wangwei) create scheduler and vm. } -void Device::Exec(function<void(Context*)> fn, const vector<Blob*> read_blobs, +void Device::Exec(function<void(Context*)>&& fn, const vector<Blob*> read_blobs, const vector<Blob*> write_blobs, bool use_rand_generator) { - fn(&ctx_); + // TODO(wangwei) execute operations scheduled by the scheduler. 
+ DoExec(std::move(fn), 0); } +// TODO(wangwei) get Blob from the memory manager Blob* Device::NewBlob(int size) { if (size > 0) { - void* ptr = malloc(size); - memset(ptr, 0, size); + void* ptr = Malloc(size); + // memset(ptr, 0, size); return new Blob(ptr, size); } else { return nullptr; } } +// TODO(wangwei) return Blob to the memory manager void Device::FreeBlob(Blob* blob) { if (blob != nullptr) { - free(blob->mutable_data()); + Free(blob->mutable_data()); delete blob; } } -void Device::CopyData(Blob* dst, const Blob& src, int len, int dst_offset, - int src_offset) { - - memcpy(reinterpret_cast<Byte*>(dst->mutable_data()) + dst_offset, - (const Byte*)src.data() + src_offset, len); +void Device::CopyDataToFrom(Blob* dst, Blob* src, size_t nBytes, + CopyDirection direct, int dst_offset, + int src_offset) { + this->Exec( + [this, dst, src, nBytes, direct, dst_offset, src_offset](Context* ctx) { + this->CopyToFrom((Byte*)dst->mutable_data() + dst_offset, + (Byte*)src->data() + src_offset, nBytes, direct, ctx); + }, + {src}, {dst}); } -void Device::CopyDataFromHostPtr(Blob* dst, const void* src, size_t size) { - memcpy(dst->mutable_data(), src, size); +void Device::CopyDataFromHostPtr(Blob* dst, const void* src, size_t nBytes, + size_t dst_offset) { + auto direct = device_type_ == kCpp ? kHostToHost : kHostToDevice; + void* dstptr = (Byte*)dst->mutable_data() + dst_offset; + Exec([this, dstptr, src, nBytes, + direct](Context* ctx) { CopyToFrom(dstptr, src, nBytes, direct, ctx); }, + {}, {dst}); } void Device::Sync() {} } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index 0e5570d..339262e 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -115,16 +115,17 @@ void Tensor::ToHost() { ToDevice(device_->host()); } -template<typename DType> +template <typename DType> void Tensor::CopyDataFromHostPtr(const DType* src, size_t num) { - CHECK_EQ(sizeof(DType), SizeOf(data_type_)) << "data_type is " - << DataType_Name(data_type_) - << " user given type is of size " - << sizeof(DType); - if (src != nullptr) - device_->CopyDataFromHostPtr(blob(), src, sizeof(DType) * num); - else + CHECK_EQ(sizeof(DType), SizeOf(data_type_)) + << "data_type is " << DataType_Name(data_type_) + << " user given type is of size " << sizeof(DType); + if (src != nullptr) { + auto direction = device_->type() == kCpp ? kHostToHost : kHostToDevice; + device_->CopyDataFromHostPtr(blob(), src, sizeof(DType) * num, direction); + } else { LOG(WARNING) << "Copy data from null host ptr"; + } } template void Tensor::CopyDataFromHostPtr(const float* src, size_t num); @@ -133,7 +134,7 @@ void Tensor::CopyData(const Tensor& src) { CHECK(blob_ != nullptr); // Do copy only if the src's blob is already initialized. 
if (src.blob_ != nullptr) { - singa::CopyData(this, src, Size(), 0, 0); + singa::CopyDataToFrom(this, src, Size(), 0, 0); } } @@ -197,38 +198,32 @@ GenUnaryScalarArgMemberFunction(operator*=, EltwiseMult); GenUnaryScalarArgMemberFunction(operator/=, Div); // ====================Tensor Operations======================================= -void CopyData(Tensor* dst, - const Tensor& src, - size_t num, - size_t dst_offset, - size_t src_offset) { - CHECK_GE(src.Size(), src_offset + num); - CHECK_GE(dst->Size(), dst_offset + num); +void CopyDataToFrom(Tensor* dst, const Tensor& src, size_t num, + size_t dst_offset, size_t src_offset) { auto width = SizeOf(src.data_type()); CHECK_EQ(width, SizeOf(dst->data_type())); - CopyRawData(dst, src, num * width, dst_offset * width, src_offset * width); -} - -void CopyRawData(Tensor* dst, - const Tensor& src, - size_t nBytes, - size_t dst_offset, - size_t src_offset) { + size_t nBytes = num * width; + dst_offset *= width; + src_offset *= width; CHECK_GE(src.MemSize(), src_offset + nBytes); CHECK_GE(dst->MemSize(), dst_offset + nBytes); - Device* src_dev = src.device(), *dst_dev = dst->device(); - Blob* src_blob = src.blob(), *dst_blob = dst->blob(); - if (dst_dev->device_lib() != src_dev->device_lib()) { + + Device *src_dev = src.device(), *dst_dev = dst->device(); + Blob *from = src.blob(), *to = dst->blob(); + if (dst_dev->type() != src_dev->type()) { // let the none cpp device conduct copy op - if (dst_dev->device_lib() == kCpp) { - src_dev->CopyData(dst_blob, *src_blob, nBytes, dst_offset, src_offset); - } else if (src_dev->device_lib() == kCpp) { - dst_dev->CopyData(dst_blob, *src_blob, nBytes, dst_offset, src_offset); + if (dst_dev->type() == kCpp) { + src_dev->CopyDataToFrom(to, from, nBytes, kDeviceToHost, dst_offset, + src_offset); + } else if (src_dev->type() == kCpp) { + dst_dev->CopyDataToFrom(to, from, nBytes, kHostToDevice, dst_offset, + src_offset); } else { LOG(FATAL) << "Not support mem copy betwee Cuda and OpenCL device"; } } else { - src_dev->CopyData(dst_blob, *src_blob, nBytes, dst_offset, src_offset); + auto direct = src_dev->type() == kCpp ? kHostToHost : kDeviceToDevice; + src_dev->CopyDataToFrom(to, from, nBytes, direct, dst_offset, src_offset); } } //============================================================================ @@ -257,52 +252,46 @@ void CopyRawData(Tensor* dst, } \ } while (0) -/// typedef DType and Lib according to values of type and lib respectively. -/// type is from DataType, and lib is from LibType. -/// DType and Lib would be used in __VA_ARGS__. -#define TYPE_LIB_SWITCH(dtype, DType, ltype, Lib, ...) \ +/// typedef DType and Dev according to values of type and lib respectively. +/// type is from DataType, and lib is from DevType. +/// DType and Dev would be used in __VA_ARGS__. +#define TYPE_LIB_SWITCH(dtype, DType, dev, Dev, ...) 
\ do { \ const int _SwitchShift = 3; \ - int _SwitchHash = ((dtype) << _SwitchShift) + (ltype); \ + int _SwitchHash = ((dtype) << _SwitchShift) + (dev); \ switch (_SwitchHash) { \ case ((kFloat32 << _SwitchShift) + kCuda): { \ typedef float DType; \ - typedef lib::Cuda Lib; \ - { __VA_ARGS__ } \ - break; \ - } \ - case ((kFloat32 << _SwitchShift) + kCudnn): { \ - typedef float DType; \ - typedef lib::Cudnn Lib; \ + typedef lib::Cuda Dev; \ { __VA_ARGS__ } \ break; \ } \ case ((kFloat32 << _SwitchShift) + kCpp): { \ typedef float DType; \ - typedef lib::Cpp Lib; \ + typedef lib::Cpp Dev; \ { __VA_ARGS__ } \ break; \ } \ case ((kFloat32 << _SwitchShift) + kOpencl): { \ typedef float DType; \ - typedef lib::Opencl Lib; \ + typedef lib::Opencl Dev; \ { __VA_ARGS__ } \ break; \ } \ default: \ LOG(FATAL) << "Unknown combination of data type " \ << DataType_Name(dtype) << " and library " \ - << LibType_Name(ltype); \ + << DeviceType_Name(dev); \ } \ } while (0) #define EltwiseUnaryTensorFn(fn, t, ret) \ do { \ - TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { \ + TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->type(), Dev, { \ ret->device()->Exec( \ [t, ret](Context* ctx) { \ - fn<DType, Lib>(t.Size(), t.blob(), ret->blob(), ctx); \ + fn<DType, Dev>(t.Size(), t.blob(), ret->blob(), ctx); \ }, \ {t.blob()}, {ret->blob()}); \ }); \ @@ -340,10 +329,10 @@ void Softmax(const Tensor& t, Tensor* ret, int axis) { CHECK_EQ(size % nrow, 0) << "Size = " << size << " nrow = " << nrow; ncol = size / nrow; } - TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { + TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->type(), Dev, { ret->device()->Exec( [nrow, ncol, t, ret](Context* ctx) { - Softmax<DType, Lib>(nrow, ncol, t.blob(), ret->blob(), ctx); + Softmax<DType, Dev>(nrow, ncol, t.blob(), ret->blob(), ctx); }, {t.blob()}, {ret->blob()}); }); @@ -351,11 +340,11 @@ void Softmax(const Tensor& t, Tensor* ret, int axis) { #define EltwiseBinaryTensorFn(fn, lhs, rhs, ret) \ do { \ - TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->device_lib(), Lib, { \ + TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->type(), Dev, { \ CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type())); \ ret->device()->Exec( \ [lhs, rhs, ret](Context* ctx) { \ - fn<DType, Lib>(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(), \ + fn<DType, Dev>(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(), \ ctx); \ }, \ {lhs.blob(), rhs.blob()}, {ret->blob()}); \ @@ -378,17 +367,17 @@ GenBinaryTensorFunction(operator*, EltwiseMult); GenBinaryTensorFunction(operator/, Div); GenBinaryTensorFunction(Pow, Pow); -#define EltwiseTensorScalarFn(fn, t, x, ret) \ - do { \ - TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { \ - static_assert(std::is_same<SType, DType>::value, \ - "The Scalar type must match the Tensor data type"); \ - ret->device()->Exec( \ - [t, x, ret](Context* ctx) { \ - fn<DType, Lib>(t.Size(), t.blob(), x, ret->blob(), ctx); \ - }, \ - {t.blob()}, {ret->blob()}); \ - }); \ +#define EltwiseTensorScalarFn(fn, t, x, ret) \ + do { \ + TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->type(), Dev, { \ + static_assert(std::is_same<SType, DType>::value, \ + "The Scalar type must match the Tensor data type"); \ + ret->device()->Exec( \ + [t, x, ret](Context* ctx) { \ + fn<DType, Dev>(t.Size(), t.blob(), x, ret->blob(), ctx); \ + }, \ + {t.blob()}, {ret->blob()}); \ + }); \ } while (0) #define GenTensorScalarFunction(op, fn) \ @@ -412,43 +401,33 @@ GenTensorScalarFunction(operator/, 
Div); GenTensorScalarFunction(Pow, Pow); // ================Blas operations============================================ -template <typename DType> Tensor Mult(const Tensor& lhs, const Tensor& rhs) { Tensor ret(lhs.shape(), lhs.device(), lhs.data_type()); - Mult<DType>(lhs, rhs, &ret); + Mult(lhs, rhs, &ret); return ret; } -template Tensor Mult<float>(const Tensor& lhs, const Tensor& rhs); -template <typename DType> void Mult(const Tensor& lhs, const Tensor& rhs, Tensor* ret) { - Mult(DType(1), lhs, DType(1), rhs, ret); + Mult(1, lhs, 1, rhs, ret); } -template void Mult<float>(const Tensor& lhs, const Tensor& rhs, Tensor* ret); -template <typename DType> -Tensor Mult(DType alpha, const Tensor& A, DType beta, const Tensor& B) { +Tensor Mult(float alpha, const Tensor& A, float beta, const Tensor& B) { Tensor ret(A.shape(), A.device(), A.data_type()); - Mult<DType>(alpha, A, beta, B, &ret); + Mult(alpha, A, beta, B, &ret); return ret; } -template Tensor Mult<float>(float alpha, const Tensor& lhs, float beta, - const Tensor& rhs); -template <typename SType> -void Mult(SType alpha, const Tensor& A, SType beta, const Tensor& B, +void Mult(float alpha, const Tensor& A, float beta, const Tensor& B, Tensor* C) { CHECK_EQ(A.shape().size(), 2u); bool transA = A.transpose(); size_t m = transA ? A.shape()[1] : A.shape()[0], n = 0; if (B.shape().size() == 1u) { n = C->Size(); - TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->device_lib(), Lib, { - static_assert(std::is_same<SType, DType>::value, - "The scalar type must be the same as the tensor data type"); + TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->type(), Dev, { C->device()->Exec( [transA, m, n, alpha, A, beta, B, C](Context* ctx) { - GEMV<DType, Lib>(transA, m, n, alpha, A.blob(), B.blob(), beta, + GEMV<DType, Dev>(transA, m, n, alpha, A.blob(), B.blob(), beta, C->blob(), ctx); }, {A.blob(), B.blob()}, {C->blob()}); @@ -461,61 +440,42 @@ void Mult(SType alpha, const Tensor& A, SType beta, const Tensor& B, CHECK_EQ(C->shape()[0], m); CHECK_EQ(A.Size(), m * k); CHECK_EQ(B.Size(), n * k); - TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->device_lib(), Lib, { - static_assert(std::is_same<SType, DType>::value, - "The scalar type must be the same as the tensor data type"); + TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->type(), Dev, { C->device()->Exec( [transA, transB, m, n, k, alpha, A, beta, B, C](Context* ctx) { - GEMM<DType, Lib>(transA, transB, m, n, k, alpha, A.blob(), B.blob(), + GEMM<DType, Dev>(transA, transB, m, n, k, alpha, A.blob(), B.blob(), beta, C->blob(), ctx); }, {A.blob(), B.blob()}, {C->blob()}); }); } } -template void Mult<float>(float alpha, const Tensor& lhs, float beta, - const Tensor& rhs, Tensor* ret); - -// ================Neural Net operations====================================== -/* -void Conv(const OpConf* conf, const Tensor& input, const Tensor& W, - const Tensor& b, Tensor* ret) { - TYPE_LIB_SWITCH(input.data_type(), DType, input.device()->nn_lib(), Lib, { - ret->device()->Exec( - [conf, input, W, b, ret](Context* ctx) { - Conv<DType, Lib>(conf, input.blob(), W.blob(), b.blob(), ret->blob(), - ctx); - }, - {input.blob(), W.blob(), b.blob()}, {ret->blob()}); - }); -} -*/ void Bernoulli(float p, Tensor* t) { - TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, { + TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->type(), Dev, { t->device()->Exec( [p, t](Context* ctx) { - Bernoulli<DType, Lib>(t->Size(), p, t->blob(), ctx); + Bernoulli<DType, Dev>(t->Size(), p, t->blob(), ctx); }, {}, 
{t->blob()}, true); }); } void Uniform(float low, float high, Tensor* t) { - TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, { + TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->type(), Dev, { t->device()->Exec( [low, high, t](Context* ctx) { - Uniform<DType, Lib>(t->Size(), low, high, t->blob(), ctx); + Uniform<DType, Dev>(t->Size(), low, high, t->blob(), ctx); }, {}, {t->blob()}, true); }); } void Gaussian(float mean, float std, Tensor* t) { - TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, { + TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->type(), Dev, { t->device()->Exec( [mean, std, t](Context* ctx) { - Gaussian<DType, Lib>(t->Size(), mean, std, t->blob(), ctx); + Gaussian<DType, Dev>(t->Size(), mean, std, t->blob(), ctx); }, {}, {t->blob()}, true); }); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/src/proto/core.proto ---------------------------------------------------------------------- diff --git a/src/proto/core.proto b/src/proto/core.proto index f366ed0..f99aba4 100644 --- a/src/proto/core.proto +++ b/src/proto/core.proto @@ -30,10 +30,17 @@ enum DataType { kNumDataType = 5; } -enum LibType { +enum DeviceType { kCpp = 0; kCuda = 1; kOpencl = 2; - kCudnn = 3; - kNumLibType = 4; + kNumDeviceType = 4; +} + +enum CopyDirection { + kHostToHost = 0; + kHostToDevice = 1; + kDeviceToHost = 2; + kDeviceToDevice = 3; + kNumDirection = 4; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/test/singa/test_cpp_device.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_cpp_device.cc b/test/singa/test_cpp_device.cc new file mode 100644 index 0000000..d2c0149 --- /dev/null +++ b/test/singa/test_cpp_device.cc @@ -0,0 +1,71 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ + +#include "gtest/gtest.h" +#include "singa/core/device.h" +#include "singa/proto/core.pb.h" + +using singa::CppDevice; +using singa::Blob; +TEST(CppDevice, Constructor) { + CppDevice dev(0, 1); + EXPECT_EQ(0, dev.id()); +} + +TEST(CppDevice, MemoryMallocFree) { + CppDevice dev(0, 1); + Blob* b = dev.NewBlob(4); + EXPECT_NE(nullptr, b); + EXPECT_EQ(4, b->size()); + dev.FreeBlob(b); +} + +TEST(CppDevice, Exec) { + CppDevice dev(0, 1); + Blob* b = dev.NewBlob(4); + int x = 1, y =3, z = 0; + dev.Exec([x, y, &z](singa::Context *ctx) { + z = x + y; + }, {b}, {b}, false); + EXPECT_EQ(x + y, z); +} + +TEST(CppDevice, CopyData) { + CppDevice dev(0, 1); + Blob* b = dev.NewBlob(4); + char s[] = {'a', 'b', 'c', 'x'}; + dev.CopyDataFromHostPtr(b, s, 4); + const char* bstr = static_cast<const char*>(b->data()); + EXPECT_EQ('a', bstr[0]); + EXPECT_EQ('b', bstr[1]); + EXPECT_EQ('x', bstr[3]); + + Blob* c = dev.NewBlob(4); + dev.CopyDataToFrom(c, b, 4, singa::kHostToHost, 0, 0); + const char* cstr = static_cast<const char*>(c->data()); + + EXPECT_EQ('a', cstr[0]); + EXPECT_EQ('b', cstr[1]); + EXPECT_EQ('x', cstr[3]); + dev.FreeBlob(b); + dev.FreeBlob(c); +} + http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/282712ca/test/singa/test_tensor_math.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc index 51e7cfb..ccd91a0 100644 --- a/test/singa/test_tensor_math.cc +++ b/test/singa/test_tensor_math.cc @@ -43,21 +43,7 @@ TEST_F(TestTensorMath, MemberAddTensor) { EXPECT_FLOAT_EQ(6.1f, dptr2[2]); EXPECT_FLOAT_EQ(12.1f, dptr2[5]); } -/* -TEST(TensorClass, SubTensor) { - Tensor a(Shape{2,3}), b(Shape{6}); - float x[]={1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - float y[]={1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f}; - a.CopyDataFromHostPtr(x, 6); - b.CopyDataFromHostPtr(y, 6); - b -= a; - const float* dptr = b.data<float>(); - EXPECT_FLOAT_EQ(0.1f, dptr[0]); - EXPECT_FLOAT_EQ(0.1f, dptr[1]); - EXPECT_FLOAT_EQ(0.1f, dptr[2]); - EXPECT_FLOAT_EQ(0.1f, dptr[5]); -} -*/ + TEST_F(TestTensorMath, AddTensors) { Tensor ret(a.shape(), a.device(), a.data_type());
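As a usage note for the new CudaDevice helpers, a hedged sketch condensed from the device.h comments above; the PickGpu() wrapper name is an illustrative assumption, and the code applies only when USE_CUDA is defined:

#ifdef USE_CUDA
#include "singa/core/device.h"

// Returns the ordinal of the first usable GPU, or -1 if none is available.
int PickGpu() {
  // FindDevice() scans ordinals from start_id upward and calls CheckDevice(),
  // which runs cudaSetDevice() followed by cudaFree(0) so that a context is
  // actually created; in EXCLUSIVE_PROCESS/EXCLUSIVE_THREAD mode this also
  // claims the device.
  int id = singa::CudaDevice::FindDevice(0);
  if (id < 0) return -1;

  singa::CudaDevice dev(id);           // default: one executor, "sync" scheduler
  dev.SetRandSeed(1234);               // seeds the curand generator
  singa::CudaDevice::DeviceQuery();    // logs properties of the current device
  return id;
}
#endif  // USE_CUDA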
