SINGA-176 - Add loss and metric base classes. Pass tests for MSE and Accuracy.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/a1c3437c Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/a1c3437c Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/a1c3437c Branch: refs/heads/dev Commit: a1c3437c34b6f613911d8b7ef9f11f483099fc63 Parents: 668ae16 Author: Wei Wang <[email protected]> Authored: Thu May 26 14:03:05 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Thu May 26 14:11:18 2016 +0800 ---------------------------------------------------------------------- include/singa/core/device.h | 7 +++--- src/core/device/cpp_cpu.cc | 4 +++- src/core/device/cuda_gpu.cc | 2 +- src/core/device/device.cc | 1 + src/core/tensor/math_kernel.cu | 34 +++++++++++++++++++++++++++- src/core/tensor/math_kernel.h | 6 +++++ src/core/tensor/tensor_math_cpp.h | 12 +++++++++- src/core/tensor/tensor_math_cuda.h | 40 ++++++++++++++++++++++++++------- src/model/loss/mse.h | 6 ++--- test/singa/test_mse.cc | 24 +++++++++++++++----- 10 files changed, 111 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a1c3437c/include/singa/core/device.h ---------------------------------------------------------------------- diff --git a/include/singa/core/device.h b/include/singa/core/device.h index 23c2431..a4b3f6d 100644 --- a/include/singa/core/device.h +++ b/include/singa/core/device.h @@ -33,7 +33,6 @@ using std::vector; using std::string; using std::function; namespace singa { - /// Allocate memory and execute Tensor operations. /// There are three types of devices distinguished by their programming /// languages, namely cpp, cuda and opencl. @@ -76,8 +75,7 @@ class Device { return lang_; } - /// TODO(wangwei) remove it? 
- Device* host() const { return host_; } + Device* host() const { return host_;} int id() const { return id_; } @@ -135,6 +133,7 @@ class CppCPU : public Device { /// a singleton CppDevice as the host for all devices. extern CppCPU defaultDevice; + // Implement Device using OpenCL libs. // class OpenclDevice : public Device { }; @@ -143,7 +142,7 @@ extern CppCPU defaultDevice; class CudaGPU : public Device { public: ~CudaGPU(); - CudaGPU(int id = -1, int num_executors = 1, string scheduler = "sync", + CudaGPU(int id = 0, int num_executors = 1, string scheduler = "sync", string vm = "gc-only"); void SetRandSeed(unsigned seed) override; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a1c3437c/src/core/device/cpp_cpu.cc ---------------------------------------------------------------------- diff --git a/src/core/device/cpp_cpu.cc b/src/core/device/cpp_cpu.cc index 3287911..28b0da4 100644 --- a/src/core/device/cpp_cpu.cc +++ b/src/core/device/cpp_cpu.cc @@ -33,7 +33,9 @@ void CppCPU::DoExec(function<void(Context*)>&& fn, int executor) { } void* CppCPU::Malloc(int size) { - return malloc(size); + void *ptr = malloc(size); + memset(ptr, 0, size); + return ptr; } void CppCPU::Free(void* ptr) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a1c3437c/src/core/device/cuda_gpu.cc ---------------------------------------------------------------------- diff --git a/src/core/device/cuda_gpu.cc b/src/core/device/cuda_gpu.cc index 59a5f45..0ba05fb 100644 --- a/src/core/device/cuda_gpu.cc +++ b/src/core/device/cuda_gpu.cc @@ -50,7 +50,6 @@ CudaGPU::CudaGPU(int id, int num_executors, if (id == -1) id = FindDevice(0); lang_ = kCuda; - host_ = nullptr; // TODO(wangwei) add host device ctx_.stream = NULL; // use the default sync stream // TODO(wangwei) create one handle for each steam? 
CUDA_CHECK(cudaSetDevice(FindDevice(0))); @@ -91,6 +90,7 @@ void CudaGPU::CopyToFrom(void* dst, const void* src, size_t nBytes, void* CudaGPU::Malloc(int size) { void* ptr = nullptr; CUDA_CHECK(cudaMalloc(&ptr, size)); + CUDA_CHECK(cudaMemset(ptr, 0, size)); return ptr; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a1c3437c/src/core/device/device.cc ---------------------------------------------------------------------- diff --git a/src/core/device/device.cc b/src/core/device/device.cc index cd860db..ede3fda 100644 --- a/src/core/device/device.cc +++ b/src/core/device/device.cc @@ -22,6 +22,7 @@ namespace singa { Device::Device(int id, int num_executors, string scheduler, string vm) : id_(id), num_executors_(num_executors) { // TODO(wangwei) create scheduler and vm. + host_ = &defaultDevice; } void Device::Exec(function<void(Context*)>&& fn, const vector<Blob*> read_blobs, http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a1c3437c/src/core/tensor/math_kernel.cu ---------------------------------------------------------------------- diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu index 30863a1..e67ea7b 100644 --- a/src/core/tensor/math_kernel.cu +++ b/src/core/tensor/math_kernel.cu @@ -147,7 +147,21 @@ __global__ void kernel_add_vec_row(const float *src_vec_data, des_mat_data[index] = src_mat_data[index] + src_vec_data[i]; } } +__global__ void kernel_add(const float *src1, const float *src2, float*out, int n) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + for (; index < n; index += num_threads) { + out[index] = src1[index] + src2[index]; + } +} +__global__ void kernel_sub(const float *src1, const float *src2, float*out, int n) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + for (; index < n; index += num_threads) { + out[index] = src1[index] - src2[index]; + } +} __global__ void kernel_exp(const float 
*src_data, float *des_data, int n) { int index = blockIdx.x * blockDim.x + threadIdx.x; int num_threads = blockDim.x * gridDim.x; @@ -275,6 +289,15 @@ __global__ void kernel_mult(const float *src_data_a, const float *src_data_b, } } +__global__ void kernel_mult(const float *src_data_a, const float x, + float *des_data, int n) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + for (; index < n; index += num_threads) { + des_data[index] = src_data_a[index] * x; + } +} + __global__ void kernel_div(const float *src_data_a, const float *src_data_b, float *des_data, int n) { int index = blockIdx.x * blockDim.x + threadIdx.x; @@ -346,7 +369,12 @@ void add_row(int rows, int cols, int stride, const float *in_row, kernel_add_vec_row<<<num_blocks, threads_per_block>>>(in_row, in_mat, out, rows, cols, stride); } - +void add(int n, const float *a, const float *b, float *out) { + kernel_add<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n); +} +void sub(int n, const float *a, const float *b, float *out) { + kernel_sub<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n); +} void exp(int n, const float *in, float *out) { kernel_exp<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); } @@ -407,6 +435,10 @@ void mult(int n, const float *a, const float *b, float *out) { kernel_mult<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n); } +void mult(int n, const float *a, const float x, float *out) { + kernel_mult<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, x, out, n); +} + void div(int n, const float *a, const float *b, float *out) { kernel_div<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a1c3437c/src/core/tensor/math_kernel.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h index f5da772..5367f4a 100644 --- a/src/core/tensor/math_kernel.h +++ b/src/core/tensor/math_kernel.h 
@@ -44,6 +44,10 @@ void sum_col(int rows, int cols, int stride, const float *in, float *out); void add_row(int rows, int cols, int stride, const float *in_row, const float *in_mat, float *out); +void add(int n, const float *a, const float *b, float *out); + +void sub(int n, const float *a, const float *b, float *out); + void exp(int n, const float *in, float *out); void log(int n, const float *in, float *out); @@ -74,6 +78,8 @@ void pow(int n, const float *a, const float *b, float *out); void mult(int n, const float *a, const float *b, float *out); +void mult(int n, const float *a, const float x, float *out); + void div(int n, const float *a, const float *b, float *out); void set_value(int n, float v, float *out); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a1c3437c/src/core/tensor/tensor_math_cpp.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h index c584b69..7dc35c9 100644 --- a/src/core/tensor/tensor_math_cpp.h +++ b/src/core/tensor/tensor_math_cpp.h @@ -47,7 +47,17 @@ void Add<float, lang::Cpp>(int count, const Blob* lhs, const Blob* rhs, } } - +template <> +void Sub<float, lang::Cpp>(int count, const Blob* lhs, const Blob* rhs, + Blob* ret, Context* ctx) { + // CHECK_EQ(ctx->stream, nullptr); + float* dptr = static_cast<float*>(ret->mutable_data()); + const float* lptr = static_cast<const float*>(lhs->data()); + const float* rptr = static_cast<const float*>(rhs->data()); + for (int i = 0; i < count; i++) { + dptr[i] = lptr[i] - rptr[i]; + } +} // sum all elements of input into ret // TODO(wangwei) optimize using omp template <> http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a1c3437c/src/core/tensor/tensor_math_cuda.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h index 2e497d2..12fc58e 100644 --- 
a/src/core/tensor/tensor_math_cuda.h +++ b/src/core/tensor/tensor_math_cuda.h @@ -26,17 +26,41 @@ namespace singa { +// TODO(wangwei) optimize using stream template<> void Add<float, lang::Cuda>(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) { - /* - cublasSetStream(ctx->cublas_handle, ctx->stream); - const float* lptr = static_cast<const float*>(lhs->data()); - const float* rptr = static_cast<const float*>(rhs->data()); - float* ptr = static_cast<float*>(ret->mutable_data()); - cublasScopy(ctx->cublas_handle, count, lptr, 1, ptr, 1); - cublasSaxpy(ctx->cublas_handle, 1.0f, rptr, 1, ptr, 1); - */ + const float* a = static_cast<const float*> (lhs->data()); + const float* b = static_cast<const float*> (rhs->data()); + float* c = static_cast<float*> (ret->mutable_data()); + cuda::add(count, a, b, c); +} + +// TODO(wangwei) optimize using stream +template<> +void Sub<float, lang::Cuda>(int count, const Blob* lhs, const Blob* rhs, + Blob* ret, Context* ctx) { + const float* a = static_cast<const float*> (lhs->data()); + const float* b = static_cast<const float*> (rhs->data()); + float* c = static_cast<float*> (ret->mutable_data()); + cuda::sub(count, a, b, c); +} + +template <> +void EltwiseMult<float, lang::Cuda>(int count, const Blob* input, float x, + Blob* ret, Context* ctx) +{ + float* dptr = static_cast<float*>(ret->mutable_data()); + const float* lptr = static_cast<const float*>(input->data()); + cuda::mult(count, lptr, x, dptr); +} +// TODO(wangwei) optimize using stream +template <> +void Square<float, lang::Cuda>(int count, const Blob* input, Blob* ret, + Context* ctx) { + const float* in = static_cast<const float*>(input->data()); + float* out = static_cast<float*>(ret->mutable_data()); + cuda::square(count, in, out); } // sum all elements of input into ret // TODO(wangwei) optimize using stream http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a1c3437c/src/model/loss/mse.h 
---------------------------------------------------------------------- diff --git a/src/model/loss/mse.h b/src/model/loss/mse.h index 5799f13..1a022f9 100644 --- a/src/model/loss/mse.h +++ b/src/model/loss/mse.h @@ -51,13 +51,13 @@ Tensor MSE::Forward(const Tensor& prediction, const Tensor& target) { t.Reshape(Shape{batchsize, dim}); buf_.push(t); // TODO(wangwei) use CastType for operator/ - return Sum(Square(t), 1); + return Sum(Square(t), 1) * 0.5f; } Tensor MSE::Backward() { - const Tensor& ret = buf_.top(); + Tensor ret = buf_.top(); buf_.pop(); - return ret / (1.0f * ret.shape().at(0)); + return ret * (1.0f / ret.shape().at(0)); } } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a1c3437c/test/singa/test_mse.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_mse.cc b/test/singa/test_mse.cc index 9056176..3ee6bf8 100644 --- a/test/singa/test_mse.cc +++ b/test/singa/test_mse.cc @@ -44,8 +44,14 @@ TEST_F(TestMSE, CppForward) { const Tensor& loss = mse.Forward(p, t); auto ldat = loss.data<const float*>(); - EXPECT_FLOAT_EQ(ldat[0], 0.005); - EXPECT_FLOAT_EQ(ldat[1], 0); + for (size_t i = 0, k = 0; i < loss.Size(); i++) { + float l = 0.f; + for (size_t j = 0; j < p.Size() / loss.Size(); j++) { + l += (pdat[k] - tdat[k]) * (pdat[k] - tdat[k]); + k++; + } + EXPECT_FLOAT_EQ(ldat[i], 0.5 * l); + } } TEST_F(TestMSE, CudaForward) { @@ -58,8 +64,14 @@ TEST_F(TestMSE, CudaForward) { loss.ToHost(); auto ldat = loss.data<const float*>(); - for (size_t i = 0; i < loss.Size(); i++) - EXPECT_FLOAT_EQ(ldat[i], 0.5 * (pdat[i] - tdat[i]) * (pdat[i] - tdat[i])); + for (size_t i = 0, k = 0; i < loss.Size(); i++) { + float l = 0.f; + for (size_t j = 0; j < p.Size() / loss.Size(); j++) { + l += (pdat[k] - tdat[k]) * (pdat[k] - tdat[k]); + k++; + } + EXPECT_FLOAT_EQ(ldat[i], 0.5 * l); + } } TEST_F(TestMSE, CppBackward) { @@ -70,7 +82,7 @@ TEST_F(TestMSE, CppBackward) { auto gdat = grad.data<const 
float*>(); for (size_t i = 0; i < grad.Size(); i++) - EXPECT_FLOAT_EQ(gdat[i], pdat[i] - tdat[i]); + EXPECT_FLOAT_EQ(gdat[i], (1.0f / p.shape().at(0)) * (pdat[i] - tdat[i])); } TEST_F(TestMSE, CudaBackward) { @@ -84,5 +96,5 @@ TEST_F(TestMSE, CudaBackward) { auto gdat = grad.data<const float*>(); for (size_t i = 0; i < grad.Size(); i++) - EXPECT_FLOAT_EQ(gdat[i], pdat[i] - tdat[i]); + EXPECT_FLOAT_EQ(gdat[i], (1.0f / p.shape().at(0)) * (pdat[i] - tdat[i])); }
