SINGA-180 Add Activation layer and Softmax layer. Add CPU and cuDNN implementations for the activation and softmax layers.
Note: the activation layer currently supports the sigmoid/tanh functions and relu forward computation. Remove the tensor softmax function; instead, use the tensor op (*) and function (Sum) to implement the softmax function. Add test files for the activation and softmax layers. Add element-wise implementations of the activation functions (relu/tanh/sigmoid). Add tensor-scalar comparison functions (<, <=, >, >=), i.e., to compare a tensor with a constant. Add implementations of the tensor math functions (exp, log, pow). Add functions for matrix op vector, where op is multiply and div. All tests pass.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3e2507b7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3e2507b7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3e2507b7

Branch: refs/heads/dev
Commit: 3e2507b7af8c4fe3746f3156f29eba99a30e546f
Parents: 2dac380
Author: jixin <[email protected]>
Authored: Fri May 27 22:03:35 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Tue May 31 22:08:31 2016 +0800

----------------------------------------------------------------------
 include/singa/core/tensor.h         | 107 +++++++++++++++++-----
 src/core/tensor/math_kernel.cu      | 132 +++++++++++++++++----------
 src/core/tensor/math_kernel.h       |   6 +-
 src/core/tensor/tensor.cc           | 152 ++++++++++++++++---------------
 src/core/tensor/tensor_math.h       |  47 +++++++++-
 src/core/tensor/tensor_math_cpp.h   | 148 ++++++++++++++++++++++------
 src/core/tensor/tensor_math_cuda.h  |  54 +++++++----
 src/model/layer/activation.cc       |  67 ++++++++++++++
 src/model/layer/activation.h        |  51 +++++++++++
 src/model/layer/cudnn_activation.cc | 115 +++++++++++++++++++++++
 src/model/layer/cudnn_activation.h  |  58 ++++++++++++
 src/model/layer/cudnn_softmax.cc    |  77 ++++++++++++++++
 src/model/layer/cudnn_softmax.h     |  54 +++++++++++
 src/model/layer/softmax.cc          |  64 +++++++++++++
 src/model/layer/softmax.h           |  45 +++++++++
 test/singa/test_activation.cc       | 133 +++++++++++++++++++++++++++
 test/singa/test_cudnn_activation.cc | 136 +++++++++++++++++++++++++++
 test/singa/test_cudnn_dropout.cc    |   2 +-
 test/singa/test_cudnn_softmax.cc    | 107 ++++++++++++++++++++++
 test/singa/test_softmax.cc          | 110 ++++++++++++++++++++++
 20 files changed, 1468 insertions(+), 197 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/include/singa/core/tensor.h
----------------------------------------------------------------------

diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index 8682bca..bb8d7f8 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -62,7 +62,7 @@ inline size_t SizeOf(DataType t) { /// then it must be set up correctly (shape, device). Otherwise, runtime error /// like SegmentFault would happen. Simply type/device check would be conducted. class Tensor { -public: + public: ~Tensor(); Tensor(); explicit Tensor(Shape &&shape, DataType dtype = kFloat32); @@ -83,7 +83,8 @@ public: Device *device() const { return device_; } /// Return immutable Tensor values with given type.
- template <typename DType> DType data() const { + template <typename DType> + DType data() const { return static_cast<DType>(blob()->data()); } @@ -130,7 +131,8 @@ public: void ToHost(); /// Set each element of the tensor to be x - template <typename SType> void SetValue(const SType x); + template <typename SType> + void SetValue(const SType x); /// For init the tensor values, copy 'num' elements. template <typename DType> @@ -141,7 +143,7 @@ public: void CopyData(const Tensor &other); /// Return an exactly the same Tensor with data been deep copied. - Tensor Clone(); + Tensor Clone() const; // Tensor operations @@ -167,23 +169,27 @@ public: // Scalar operations. /// T is a scalar type - template <typename DType> Tensor &operator+=(DType x); + template <typename DType> + Tensor &operator+=(DType x); /// T is a scalar type - template <typename DType> Tensor &operator-=(const DType x); + template <typename DType> + Tensor &operator-=(const DType x); /// T is a scalar type - template <typename DType> Tensor &operator*=(const DType x); + template <typename DType> + Tensor &operator*=(const DType x); /// T is a scalar type - template <typename DType> Tensor &operator/=(const DType x); + template <typename DType> + Tensor &operator/=(const DType x); /// save Tensor into a proto msg // void ToProto(TensorProto* t); /// load Tensor from proto msg // void FromProto(const TensorProto& t); -protected: + protected: bool transpose_ = false; DataType data_type_ = kFloat32; Device *device_ = nullptr; @@ -220,7 +226,8 @@ Tensor Sqrt(const Tensor &t); Tensor Square(const Tensor &t); Tensor Tanh(const Tensor &t); -template <typename SType> SType Sum(const Tensor &t); +template <typename SType> +SType Sum(const Tensor &t); /// Sum elements in the Tensor, currently only support vector and matrix. /// if 'axis' is 0, sum all rows into a single row /// if 'axis' is 1, sum all columns into a single column @@ -232,16 +239,48 @@ Tensor Sum(const Tensor &t, int axis); /// if 'axis' is 1, average all columns into a single column /// TODO(wangwei) support arbitrary Tensor like numpy.average Tensor Average(const Tensor &t, int axis); +/// Regarding the internal data as 2d, with shape_[0]*...*shape_[axis-1] rows, +/// and shape_[axis]*...*shape_[nDim()] columns. +/// and do softmax along each row. +Tensor SoftMax(const Tensor &t, int axis = 0); +void SoftMax(const Tensor &t, int axis, Tensor *ret); + /// Regarding the internal data as 2d, with shape_[0]*...*shape_[axis] rows, /// and shape_[axis+1]*...*shape_[nDim()] columns. /// and do softmax along each row. -Tensor Softmax(const Tensor &t, int axis = -1); -void Softmax(const Tensor &t, Tensor *ret, int axis = -1); +// Tensor Softmax(const Tensor& t, int axis = -1); +// void Softmax(const Tensor& t, Tensor* ret, int axis = -1); + +/// Element-wise operation, ret[i]= (t[i] < x) ? 1.f : 0.f +template <typename DType> +Tensor operator<(const Tensor &t, const DType x); +template <typename DType> +void LT(const Tensor &t, DType x, Tensor *ret); + +/// Element-wise operation, ret[i]= (t[i] <= x) ? 1.f : 0.f +template <typename DType> +Tensor operator<=(const Tensor &t, const DType x); +template <typename DType> +void LE(const Tensor &t, DType x, Tensor *ret); + +/// Element-wise operation, ret[i]= (t[i] > x) ? 1.f : 0.f +template <typename DType> +Tensor operator>(const Tensor &t, const DType x); +template <typename DType> +void GT(const Tensor &t, DType x, Tensor *ret); + +/// Element-wise operation, ret[i]= (t[i] >= x) ? 
1.f : 0.f +template <typename DType> +Tensor operator>=(const Tensor &t, const DType x); +template <typename DType> +void GE(const Tensor &t, DType x, Tensor *ret); /// Element-wise opeartion, ret[i]=t[i]^x -template <typename DType> Tensor Pow(const Tensor &t, DType x); +template <typename DType> +Tensor Pow(const Tensor &t, DType x); /// Element-wise opeartion, ret[i]=t[i]^x -template <typename DType> void Pow(const Tensor &t, DType x, Tensor *ret); +template <typename DType> +void Pow(const Tensor &t, DType x, Tensor *ret); /// Element-wise opeartion, ret[i]=baes[i]^exp[i] Tensor Pow(const Tensor &base, Tensor exp); /// Element-wise opeartion, ret[i]=baes[i]^exp[i] @@ -256,18 +295,25 @@ void EltwiseMult(const Tensor &lhs, const Tensor &rhs, Tensor *ret); Tensor operator/(const Tensor &lhs, const Tensor &rhs); void Div(const Tensor &lhs, const Tensor &rhs, Tensor *ret); -template <typename DType> Tensor operator+(const Tensor &t, DType x); -template <typename DType> void Add(const Tensor &t, DType x, Tensor *ret); +template <typename DType> +Tensor operator+(const Tensor &t, DType x); +template <typename DType> +void Add(const Tensor &t, DType x, Tensor *ret); -template <typename DType> Tensor operator-(const Tensor &t, DType x); -template <typename DType> void Sub(const Tensor &t, DType x, Tensor *ret); +template <typename DType> +Tensor operator-(const Tensor &t, DType x); +template <typename DType> +void Sub(const Tensor &t, DType x, Tensor *ret); -template <typename DType> Tensor operator*(const Tensor &t, DType x); +template <typename DType> +Tensor operator*(const Tensor &t, DType x); template <typename DType> void EltwiseMult(const Tensor &t, DType x, Tensor *ret); -template <typename DType> Tensor operator/(const Tensor &t, DType x); -template <typename DType> void Div(const Tensor &t, DType x, Tensor *ret); +template <typename DType> +Tensor operator/(const Tensor &t, DType x); +template <typename DType> +void Div(const Tensor &t, DType x, Tensor *ret); // ================Blas operations============================================ // We fix the scalar argument type to be float. 
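The comparison operators declared above do not return booleans: each produces a new float tensor holding 1.0f/0.0f per element (dispatched to the LT/LE/GT/GE functions). That is what lets the new activation layer express the ReLU gradient as plain tensor arithmetic. A minimal sketch of the pattern, with an illustrative helper name that is not part of this commit:

    #include "singa/core/tensor.h"

    // dy is dL/dy and x is the input saved during Forward; (x > 0.f)
    // evaluates GT element-wise and yields a 0/1 mask, so the ReLU
    // gradient reduces to an element-wise product.
    singa::Tensor ReluGrad(const singa::Tensor &x, const singa::Tensor &dy) {
      return dy * (x > 0.f);
    }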
@@ -301,6 +347,7 @@ void Uniform(float low, float high, Tensor *t); void Gaussian(float mean, float std, Tensor *t); // follow the consistency guide +// https://issues.apache.org/jira/browse/SINGA-182 // ============Matrix vector operations======================================= /// Add column 'v' with each column of matrix M void AddColumn(const Tensor &v, Tensor *M); @@ -329,12 +376,28 @@ void SumRows(const Tensor &M, Tensor *out); void SumColumns(const Tensor &M, Tensor *out); /// For each element x of Tensor 'in', compute alpha/x -template <typename SType> Tensor Div(const SType alpha, const Tensor &in); +template <typename SType> +Tensor Div(const SType alpha, const Tensor &in); /// For each element x of Tensor 'in', compute alpha/x into 'out' template <typename SType> void Div(const SType alpha, const Tensor &in, Tensor *out); +/* +/// Multiply each column of the lhs matrix with the rhs column +Tensor MultColumn(const Tensor &lhs, const Tensor &rhs); +void MultColumn(const Tensor &lhs, const Tensor &rhs, Tensor *ret); +/// Multiply each row of the lhs matrix with the rhs row +Tensor MultRow(const Tensor &lhs, const Tensor &rhs); +void MultRow(const Tensor &lhs, const Tensor &rhs, Tensor *ret); +/// Div each row of the lhs matrix with the rhs column +Tensor DivColumn(const Tensor &lhs, const Tensor &rhs); +void DivColumn(const Tensor &lhs, const Tensor &rhs, Tensor *ret); +/// Divide each row of the lhs matrix by the rhs row +Tensor DivRow(const Tensor &lhs, const Tensor &rhs); +void DivRow(const Tensor &lhs, const Tensor &rhs, Tensor *ret); +*/ + } // namespace singa #endif // SINGA_CORE_TENSOR_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/math_kernel.cu ---------------------------------------------------------------------- diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu index 88041b1..aed6add 100644 --- a/src/core/tensor/math_kernel.cu +++ b/src/core/tensor/math_kernel.cu @@ -32,7 +32,7 @@ #define CU1DBLOCK 1024 #define CU1DBLOCKF 1024.0 -namespace singa{ +namespace singa { // Cuda Kernel Functions namespace cuda { __global__ void kernel_softmax_loss(const float *prob, const int *label, @@ -147,7 +147,8 @@ __global__ void kernel_add_vec_row(const float *src_vec_data, des_mat_data[index] = src_mat_data[index] + src_vec_data[i]; } } -__global__ void kernel_add(const float *src1, const float *src2, float*out, int n) { +__global__ void kernel_add(const float *src1, const float *src2, float *out, + int n) { int index = blockIdx.x * blockDim.x + threadIdx.x; int num_threads = blockDim.x * gridDim.x; for (; index < n; index += num_threads) { @@ -155,7 +156,8 @@ __global__ void kernel_add(const float *src1, const float *src2, float*out, int } } -__global__ void kernel_sub(const float *src1, const float *src2, float*out, int n) { +__global__ void kernel_sub(const float *src1, const float *src2, float *out, + int n) { int index = blockIdx.x * blockDim.x + threadIdx.x; int num_threads = blockDim.x * gridDim.x; for (; index < n; index += num_threads) { @@ -323,42 +325,28 @@ __global__ void kernel_threshold(const float *src_data, float *des_data, des_data[index] = src_data[index] < alpha ? 
1.0f : 0.0f; } } - -/* -void softmaxloss_forward(int n, int dim, const float *prob, - const int *label, float *loss) { - kernel_softmax_loss<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(prob, label, loss, n, - dim); -} - -void softmaxloss_backward(int n, int dim, float scale, - const int *label, float *grad) { - kernel_softmax_gradient<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(grad, label, n, - dim, scale); -} -*/ void sum(int n, const float *in, float *out) { int threads_per_block = n > CU1DBLOCK ? CU1DBLOCK : n; // here, we only need one block int num_blocks = 1; - kernel_sum_vec<<<num_blocks, threads_per_block>>>(in, out, n); + kernel_sum_vec << <num_blocks, threads_per_block>>> (in, out, n); } void sum_row(int rows, int cols, int stride, const float *in, float *out) { int threads_per_block = rows > CU1DBLOCK ? CU1DBLOCK : rows; int num_blocks = cols; - kernel_sum_row<<<num_blocks, threads_per_block>>>(in, out, rows, cols, - stride); + kernel_sum_row << <num_blocks, threads_per_block>>> + (in, out, rows, cols, stride); } void sum_col(int rows, int cols, int stride, const float *in, float *out) { int threads_per_block = cols > CU1DBLOCK ? CU1DBLOCK : cols; int num_blocks = rows; - kernel_sum_col<<<num_blocks, threads_per_block>>>(in, out, - rows, cols, stride); + kernel_sum_col << <num_blocks, threads_per_block>>> + (in, out, rows, cols, stride); } void add_row(int rows, int cols, int stride, const float *in_row, const float *in_mat, float *out) { @@ -366,92 +354,91 @@ void add_row(int rows, int cols, int stride, const float *in_row, dim3 num_blocks( cols / threads_per_block.x + (cols % threads_per_block.x == 0 ? 0 : 1), rows / threads_per_block.y + (rows % threads_per_block.y == 0 ? 0 : 1)); - kernel_add_vec_row<<<num_blocks, threads_per_block>>>(in_row, in_mat, out, - rows, cols, stride); + kernel_add_vec_row << <num_blocks, threads_per_block>>> + (in_row, in_mat, out, rows, cols, stride); } void add(int n, const float *a, const float *b, float *out) { - kernel_add<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n); + kernel_add << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n); } void sub(int n, const float *a, const float *b, float *out) { - kernel_sub<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n); + kernel_sub << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n); } void exp(int n, const float *in, float *out) { - kernel_exp<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_exp << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void log(int n, const float *in, float *out) { - kernel_log<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_log << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void sigmoid(int n, const float *in, float *out) { - kernel_sigmoid<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_sigmoid << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void sigmoid_grad(int n, const float *in, float *out) { - kernel_sigmoid_grad<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_sigmoid_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void relu(int n, const float *in, float *out) { - kernel_relu<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_relu << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void relu_grad(int n, const float *in, float *out) { - kernel_relu_grad<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_relu_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void tanh(int n, const float *in, float *out) { - kernel_tanh<<<ceil(n / CU1DBLOCKF), 
CU1DBLOCKF>>>(in, out, n); + kernel_tanh << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void tanh_grad(int n, const float *in, float *out) { - kernel_tanh_grad<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_tanh_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void softplus(int n, const float *in, float *out) { - kernel_softplus<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_softplus << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void softplus_grad(int n, const float *in, float *out) { - kernel_softplus_grad<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_softplus_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void square(int n, const float *in, float *out) { - kernel_square<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_square << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void square_grad(int n, const float *in, float *out) { - kernel_square_grad<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_square_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void sqrt(int n, const float *in, float *out) { - kernel_sqrt<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, n); + kernel_sqrt << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); } void pow(int n, const float *a, const float *b, float *out) { - kernel_pow<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n); + kernel_pow << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n); } void mult(int n, const float *a, const float *b, float *out) { - kernel_mult<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n); + kernel_mult << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n); } void mult(int n, const float *a, const float x, float *out) { - kernel_mult<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, x, out, n); + kernel_mult << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, x, out, n); } void div(int n, const float *a, const float *b, float *out) { - kernel_div<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(a, b, out, n); + kernel_div << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n); } void set_value(int n, float v, float *out) { - kernel_set_value<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(out, v, n); + kernel_set_value << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (out, v, n); } void threshold(int n, float alpha, const float *in, float *out) { - kernel_threshold<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(in, out, alpha, n); + kernel_threshold << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, alpha, n); } - // follow the consistency guide for math API __global__ void KernelDiv(const size_t num, const float alpha, const float *in, float *out) { @@ -461,6 +448,36 @@ __global__ void KernelDiv(const size_t num, const float alpha, const float *in, } } +__global__ void KernelGE(const int num, const float *in, const float x, + float *out) { + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num; + idx += blockDim.x * gridDim.x) { + out[idx] = in[idx] >= x ? 1.0f : 0.0f; + } +} +__global__ void KernelGT(const int num, const float *in, const float x, + float *out) { + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num; + idx += blockDim.x * gridDim.x) { + out[idx] = in[idx] > x ? 1.0f : 0.0f; + } +} +__global__ void KernelLE(const int num, const float *in, const float x, + float *out) { + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num; + idx += blockDim.x * gridDim.x) { + out[idx] = in[idx] <= x ? 
1.0f : 0.0f; + } +} + +__global__ void KernelLT(const int num, const float *in, const float x, + float *out) { + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num; + idx += blockDim.x * gridDim.x) { + out[idx] = in[idx] < x ? 1.0f : 0.0f; + } +} + __global__ void KernelSet(const size_t num, const float x, float *out) { for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num; idx += blockDim.x * gridDim.x) { @@ -468,14 +485,31 @@ __global__ void KernelSet(const size_t num, const float x, float *out) { } } +void Set(const size_t num, const float x, float *out, cudaStream_t s) { + KernelSet << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, x, out); +} void Div(const size_t num, float alpha, const float *in, float *out, cudaStream_t s) { - KernelDiv<<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>>(num, alpha, in, out); + KernelDiv << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, alpha, in, out); } -void Set(const size_t num, const float x, float *out, cudaStream_t s) { - KernelSet<<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>>(num, x, out); +void GT(const size_t num, const float *in, const float x, float *out, + cudaStream_t s) { + KernelGT << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out); +} +void GE(const size_t num, const float *in, const float x, float *out, + cudaStream_t s) { + KernelGE << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out); } +void LT(const size_t num, const float *in, const float x, float *out, + cudaStream_t s) { + KernelLT << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out); +} +void LE(const size_t num, const float *in, const float x, float *out, + cudaStream_t s) { + KernelLE << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out); +} + } // namespace cuda } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/math_kernel.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h index 925346e..5c906a9 100644 --- a/src/core/tensor/math_kernel.h +++ b/src/core/tensor/math_kernel.h @@ -86,7 +86,11 @@ void threshold(int n, float alpha, const float *in, float *out); void Div(const size_t num, const float x, const float *in, float *out, cudaStream_t s); void Set(const size_t num, const float x, float *out, cudaStream_t s); -} // cuda +void GT(size_t num, const float *in, const float x, float *out, cudaStream_t s); +void GE(size_t num, const float *in, const float x, float *out, cudaStream_t s); +void LT(size_t num, const float *in, const float x, float *out, cudaStream_t s); +void LE(size_t num, const float *in, const float x, float *out, cudaStream_t s); +} // cuda } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index fcf42c2..5ae375c 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -142,7 +142,7 @@ void Tensor::CopyData(const Tensor &src) { } } -Tensor Tensor::Clone() { +Tensor Tensor::Clone() const { Tensor t(shape_, device_, data_type_); t.transpose_ = transpose_; t.CopyData(*this); @@ -200,28 +200,28 @@ Tensor Reshape(const Tensor &in, Shape &&s) { return out; } -#define GenUnaryTensorArgMemberFunction(op, fn) \ +#define GenUnaryTensorArgMemberFn(op, fn) \ Tensor &Tensor::op(const Tensor &t) { \ fn(*this, t, this); \ return *this; \ } 
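// Each invocation below stamps out one compound-assignment operator from the
// macro defined above; e.g., GenUnaryTensorArgMemberFn(operator+=, Add)
// expands to roughly:
//
//   Tensor &Tensor::operator+=(const Tensor &t) {
//     Add(*this, t, this);  // element-wise; result written back into *this
//     return *this;
//   }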
-GenUnaryTensorArgMemberFunction(operator+=, Add); -GenUnaryTensorArgMemberFunction(operator-=, Sub); -GenUnaryTensorArgMemberFunction(operator*=, EltwiseMult); -GenUnaryTensorArgMemberFunction(operator/=, Div); +GenUnaryTensorArgMemberFn(operator+=, Add); +GenUnaryTensorArgMemberFn(operator-=, Sub); +GenUnaryTensorArgMemberFn(operator*=, EltwiseMult); +GenUnaryTensorArgMemberFn(operator/=, Div); -#define GenUnaryScalarArgMemberFunction(op, fn) \ +#define GenUnaryScalarArgMemberFn(op, fn) \ template <typename DType> Tensor &Tensor::op(DType x) { \ fn(*this, x, this); \ return *this; \ } \ template Tensor &Tensor::op<float>(float x) -GenUnaryScalarArgMemberFunction(operator-=, Sub); -GenUnaryScalarArgMemberFunction(operator+=, Add); -GenUnaryScalarArgMemberFunction(operator*=, EltwiseMult); -GenUnaryScalarArgMemberFunction(operator/=, Div); +GenUnaryScalarArgMemberFn(operator-=, Sub); +GenUnaryScalarArgMemberFn(operator+=, Add); +GenUnaryScalarArgMemberFn(operator*=, EltwiseMult); +GenUnaryScalarArgMemberFn(operator/=, Div); // ====================Tensor Operations======================================= void CopyDataToFrom(Tensor *dst, const Tensor &src, size_t num, @@ -325,34 +325,35 @@ template <typename SType> void Tensor::SetValue(const SType x) { } template void Tensor::SetValue<float>(const float x); -#define EltwiseUnaryTensorFn(fn, t, ret) \ - do { \ - TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { \ - ret->device()->Exec( \ - [t, ret](Context *ctx) { \ - fn<DType, Lang>(t.Size(), t.blob(), ret->blob(), ctx); \ - }, \ - {t.blob()}, {ret->blob()}); \ - }); \ +#define EltwiseUnaryTensorFn(fn, t, ret) \ + do { \ + TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { \ + ret->device()->Exec( \ + [t, ret](Context* ctx) { \ + fn<DType, Lang>(t.Size(), t.blob(), ret->blob(), ctx); \ + }, \ + {t.blob()}, {ret->blob()}); \ + }); \ } while (0) -#define GenUnaryTensorFunction(fn) \ - Tensor fn(const Tensor &t) { \ - Tensor ret(t.shape(), t.device(), t.data_type()); \ - auto *retptr = &ret; \ - EltwiseUnaryTensorFn(fn, t, retptr); \ - return ret; \ - } - -GenUnaryTensorFunction(Abs); -GenUnaryTensorFunction(Exp); -GenUnaryTensorFunction(Log); -GenUnaryTensorFunction(ReLU); -GenUnaryTensorFunction(Sigmoid); -GenUnaryTensorFunction(Sign); -GenUnaryTensorFunction(Sqrt); -GenUnaryTensorFunction(Square); -GenUnaryTensorFunction(Tanh); +#define GenUnaryTensorFn(fn) \ + Tensor fn(const Tensor &t) { \ + Tensor ret(t.shape(), t.device(), t.data_type()); \ + auto *retptr = &ret; \ + EltwiseUnaryTensorFn(fn, t, retptr); \ + return ret; \ + } \ + void fn(const Tensor &in, Tensor *out) { EltwiseUnaryTensorFn(fn, in, out); } + +GenUnaryTensorFn(Abs); +GenUnaryTensorFn(Exp); +GenUnaryTensorFn(Log); +GenUnaryTensorFn(ReLU); +GenUnaryTensorFn(Sigmoid); +GenUnaryTensorFn(Sign); +GenUnaryTensorFn(Sqrt); +GenUnaryTensorFn(Square); +GenUnaryTensorFn(Tanh); // TODO(wangwei) conside async exec template <> float Sum<float>(const Tensor &t) { @@ -402,28 +403,25 @@ Tensor Average(const Tensor &t, int axis) { } } -Tensor Softmax(const Tensor &t, int axis) { - Tensor ret(t.shape(), t.device(), t.data_type()); - Softmax(t, &ret, axis); - return ret; +Tensor SoftMax(const Tensor &in, int axis) { + Tensor out(in.shape(), in.device(), in.data_type()); + SoftMax(in, axis, &out); + return out; } -void Softmax(const Tensor &t, Tensor *ret, int axis) { - int nrow = 1, ncol = t.Size(), size = ncol; - CHECK_GE(axis, -1); - CHECK_GT(t.shape().size(), 0u); - if (axis > -1) { - nrow = 
Product(t.shape(), 0, axis + 1); - CHECK_EQ(size % nrow, 0) << "Size = " << size << " nrow = " << nrow; +void SoftMax(const Tensor &in, int axis, Tensor *out) { + size_t nrow = 1, ncol = in.Size(), size = ncol; + CHECK_GE(axis, 0); + if (axis > 0) { + nrow = Product(in.shape(), 0, axis); + CHECK_EQ(size % nrow, 0u) << "Size = " << size << " nrow = " << nrow; ncol = size / nrow; } - TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { - ret->device()->Exec( - [nrow, ncol, t, ret](Context *ctx) { - Softmax<DType, Lang>(nrow, ncol, t.blob(), ret->blob(), ctx); - }, - {t.blob()}, {ret->blob()}); - }); + Exp(in, out); + out->Reshape(Shape{nrow, ncol}); + Tensor sum(Shape{nrow}, in.device(), in.data_type()); + SumColumns(*out, &sum); + DivColumn(sum, out); } #define EltwiseBinaryTensorFn(fn, lhs, rhs, ret) \ @@ -439,7 +437,7 @@ void Softmax(const Tensor &t, Tensor *ret, int axis) { }); \ } while (0) -#define GenBinaryTensorFunction(op, fn) \ +#define GenBinaryTensorFn(op, fn) \ Tensor op(const Tensor &lhs, const Tensor &rhs) { \ Tensor ret(lhs.shape(), lhs.device(), lhs.data_type()); \ fn(lhs, rhs, &ret); \ @@ -449,11 +447,11 @@ void Softmax(const Tensor &t, Tensor *ret, int axis) { EltwiseBinaryTensorFn(fn, lhs, rhs, ret); \ } -GenBinaryTensorFunction(operator+, Add); -GenBinaryTensorFunction(operator-, Sub); -GenBinaryTensorFunction(operator*, EltwiseMult); -GenBinaryTensorFunction(operator/, Div); -GenBinaryTensorFunction(Pow, Pow); +GenBinaryTensorFn(operator+, Add); +GenBinaryTensorFn(operator-, Sub); +GenBinaryTensorFn(operator*, EltwiseMult); +GenBinaryTensorFn(operator/, Div); +GenBinaryTensorFn(Pow, Pow); #define EltwiseTensorScalarFn(fn, t, x, ret) \ do { \ @@ -468,7 +466,7 @@ GenBinaryTensorFunction(Pow, Pow); }); \ } while (0) -#define GenTensorScalarFunction(op, fn) \ +#define GenTensorScalarFn(op, fn) \ template <typename SType> Tensor op(const Tensor &t, SType x) { \ Tensor ret(t.shape(), t.device(), t.data_type()); \ fn(t, x, &ret); \ @@ -480,11 +478,15 @@ GenBinaryTensorFunction(Pow, Pow); template Tensor op<float>(const Tensor &t, float x); \ template void fn<float>(const Tensor &t, const float x, Tensor *ret) -GenTensorScalarFunction(operator+, Add); -GenTensorScalarFunction(operator-, Sub); -GenTensorScalarFunction(operator*, EltwiseMult); -GenTensorScalarFunction(operator/, Div); -GenTensorScalarFunction(Pow, Pow); +GenTensorScalarFn(operator+, Add); +GenTensorScalarFn(operator-, Sub); +GenTensorScalarFn(operator*, EltwiseMult); +GenTensorScalarFn(operator/, Div); +GenTensorScalarFn(Pow, Pow); +GenTensorScalarFn(operator<, LT); +GenTensorScalarFn(operator<=, LE); +GenTensorScalarFn(operator>, GT); +GenTensorScalarFn(operator>=, GE); // ================Blas operations============================================ Tensor Mult(const Tensor &lhs, const Tensor &rhs) { @@ -633,8 +635,8 @@ void DivRow(const Tensor &v, Tensor *M) { /// Multiply column 'v' and each column of matrix M; write results into 'out' void MultColumn(const Tensor &v, Tensor *M) { CHECK(!M->transpose()) << "Not supported yet"; - CHECK_EQ(M->nDim(), 2); - CHECK_EQ(v.nDim(), 1); + CHECK_EQ(M->nDim(), 2u); + CHECK_EQ(v.nDim(), 1u); CHECK_EQ(v.Size(), M->shape(0)); CheckDataTypeAndLang(*M, v); TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, { @@ -650,8 +652,8 @@ void MultColumn(const Tensor &v, Tensor *M) { /// Multiply row 'v' with each row of matrix M; write results into 'out' void MultRow(const Tensor &v, Tensor *M) { CHECK(!M->transpose()) << "Not supported yet"; - 
CHECK_EQ(M->nDim(), 2); - CHECK_EQ(v.nDim(), 1); + CHECK_EQ(M->nDim(), 2u); + CHECK_EQ(v.nDim(), 1u); CHECK_EQ(v.Size(), M->shape(1)); CheckDataTypeAndLang(*M, v); TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, { @@ -673,8 +675,8 @@ void SumColumns(const Tensor &M, Tensor *v) { Tensor X = M.T(); SumRows(X, v); } else { - CHECK_EQ(M.nDim(), 2); - CHECK_EQ(v->nDim(), 1); + CHECK_EQ(M.nDim(), 2u); + CHECK_EQ(v->nDim(), 1u); size_t nb_row = M.shape().at(0), nb_col = M.shape().at(1); CHECK_EQ(nb_row, v->Size()); @@ -688,8 +690,8 @@ void SumRows(const Tensor &M, Tensor *v) { Tensor X = M.T(); SumColumns(X, v); } else { - CHECK_EQ(M.nDim(), 2); - CHECK_EQ(v->nDim(), 1); + CHECK_EQ(M.nDim(), 2u); + CHECK_EQ(v->nDim(), 1u); size_t nb_row = M.shape(0), nb_col = M.shape(1); CHECK_EQ(nb_col, v->Size()); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/tensor_math.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h index 98d91bf..ff865e0 100644 --- a/src/core/tensor/tensor_math.h +++ b/src/core/tensor/tensor_math.h @@ -220,6 +220,27 @@ void Outer(int m, int n, const Blob *lhs, const Blob *rhs, Blob *ret, LOG(FATAL) << "Not Implemented"; } +/// ret[i]=(input[i]<x)?1.f:0.f +template <typename DType, typename Lang> +void LT(int count, const Blob *input, float x, Blob *ret, Context *ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// ret[i]=(input[i]<=x)?1.f:0.f +template <typename DType, typename Lang> +void LE(int count, const Blob *input, float x, Blob *ret, Context *ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// ret[i]=(input[i]>x)?1.f:0.f +template <typename DType, typename Lang> +void GT(int count, const Blob *input, float x, Blob *ret, Context *ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// ret[i]=(input[i]>x)?1.f:0.f +template <typename DType, typename Lang> +void GE(int count, const Blob *input, float x, Blob *ret, Context *ctx) { + LOG(FATAL) << "Not Implemented"; +} + // ===== BLAS functions, ref to http://docs.nvidia.com/cuda/cublas // ===== Level 1 /// return the index of the element with the max value. 
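The templates in this header are deliberate fallbacks: each aborts with LOG(FATAL) << "Not Implemented" unless a backend supplies an explicit specialization keyed on <DType, Lang>. A condensed sketch of the pairing, using LT together with its CPU specialization as it appears in tensor_math_cpp.h further below:

    // Generic fallback: selected only when no backend specialization exists.
    template <typename DType, typename Lang>
    void LT(const size_t num, const Blob *in, const DType x, Blob *out,
            Context *ctx) {
      LOG(FATAL) << "Not Implemented";
    }

    // CPU specialization: chosen for <float, lang::Cpp>, so operator< on a
    // CPU float tensor never reaches the fallback.
    template <>
    void LT<float, lang::Cpp>(const size_t num, const Blob *in, const float x,
                              Blob *out, Context *ctx) {
      const float *inPtr = static_cast<const float *>(in->data());
      float *outPtr = static_cast<float *>(out->mutable_data());
      for (size_t i = 0; i < num; i++) outPtr[i] = (inPtr[i] < x) ? 1.f : 0.f;
    }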
@@ -319,6 +340,30 @@ void GEMM(const bool transA, const bool transB, const size_t nrowA, Context *ctx) { LOG(FATAL) << "Not Implemented"; } -} // namespace singa +/// ret[i]=(input[i]<x)?1.f:0.f +template <typename DType, typename Lang> +void LT(const size_t num, const Blob *in, const DType x, Blob *out, + Context *ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// ret[i]=(input[i]<=x)?1.f:0.f +template <typename DType, typename Lang> +void LE(const size_t num, const Blob *in, const DType x, Blob *out, + Context *ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// ret[i]=(input[i]>x)?1.f:0.f +template <typename DType, typename Lang> +void GT(const size_t num, const Blob *in, const DType x, Blob *out, + Context *ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// ret[i]=(input[i]>=x)?1.f:0.f +template <typename DType, typename Lang> +void GE(const size_t num, const Blob *in, const DType x, Blob *out, + Context *ctx) { + LOG(FATAL) << "Not Implemented"; +} +} // namespace singa #endif // SINGA_CORE_MATH_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/tensor_math_cpp.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h index 97da896..693f09c 100644 --- a/src/core/tensor/tensor_math_cpp.h +++ b/src/core/tensor/tensor_math_cpp.h @@ -19,6 +19,7 @@ #define SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_ #include "./tensor_math.h" #include "singa/core/common.h" +#include <math.h> #ifdef USE_CBLAS #include <cblas.h> @@ -51,6 +52,16 @@ void Add<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs, } template <> +void Add<float, lang::Cpp>(int count, const Blob *input, float x, Blob *ret, + Context *ctx) { + float *dptr = static_cast<float *>(ret->mutable_data()); + const float *lptr = static_cast<const float *>(input->data()); + for (int i = 0; i < count; i++) { + dptr[i] = lptr[i] + x; + } +} + +template <> void Sub<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs, Blob *ret, Context *ctx) { // CHECK_EQ(ctx->stream, nullptr); @@ -61,6 +72,7 @@ void Sub<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs, dptr[i] = lptr[i] - rptr[i]; } } + // sum all elements of input into ret // TODO(wangwei) optimize using omp template <> @@ -74,53 +86,96 @@ void Sum<float, lang::Cpp>(int count, const Blob *input, float *ret, *ret = s; } -// TODO(wangwei) optimize using omp template <> -void SumRows<float, lang::Cpp>(int nrow, int ncol, const Blob *input, Blob *ret, - Context *ctx) { +void EltwiseMult<float, lang::Cpp>(int count, const Blob *input, float x, + Blob *ret, Context *ctx) { float *dptr = static_cast<float *>(ret->mutable_data()); - const float *in = static_cast<const float *>(input->data()); - memset(dptr, 0, ncol * sizeof(float)); - for (int r = 0; r < nrow; r++) { - for (int c = 0; c < ncol; c++) { - dptr[c] += in[r * ncol + c]; - } + const float *lptr = static_cast<const float *>(input->data()); + for (int i = 0; i < count; i++) { + dptr[i] = lptr[i] * x; } } -// Sum the rows of the input matrix into a vector -// TODO(wangwei) optimize using omp template <> -void SumColumns<float, lang::Cpp>(int nrow, int ncol, const Blob *input, - Blob *ret, Context *ctx) { +void EltwiseMult<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs, + Blob *ret, Context *ctx) { float *dptr = static_cast<float *>(ret->mutable_data()); - const float *in = static_cast<const float *>(input->data()); - memset(dptr, 0, ncol * sizeof(float)); - for (int r = 0; r 
< nrow; r++) { - for (int c = 0; c < ncol; c++) { - dptr[r] += in[r * ncol + c]; - } + const float *lptr = static_cast<const float *>(lhs->data()); + const float *rptr = static_cast<const float *>(rhs->data()); + for (int i = 0; i < count; i++) { + dptr[i] = lptr[i] * rptr[i]; } } template <> -void EltwiseMult<float, lang::Cpp>(int count, const Blob *input, float x, - Blob *ret, Context *ctx) { +void Exp<float, lang::Cpp>(int count, const Blob *input, Blob *ret, + Context *ctx) { float *dptr = static_cast<float *>(ret->mutable_data()); const float *lptr = static_cast<const float *>(input->data()); for (int i = 0; i < count; i++) { - dptr[i] = lptr[i] * x; + dptr[i] = exp(lptr[i]); } } template <> -void EltwiseMult<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs, - Blob *ret, Context *ctx) { +void Log<float, lang::Cpp>(int count, const Blob *input, Blob *ret, + Context *ctx) { + float *dptr = static_cast<float *>(ret->mutable_data()); + const float *lptr = static_cast<const float *>(input->data()); + for (int i = 0; i < count; i++) { + CHECK_GT(lptr[i], 0.f); + dptr[i] = log(lptr[i]); + } +} + +template <> +void Tanh<float, lang::Cpp>(int count, const Blob *input, Blob *ret, + Context *ctx) { + float *dptr = static_cast<float *>(ret->mutable_data()); + const float *lptr = static_cast<const float *>(input->data()); + for (int i = 0; i < count; i++) { + dptr[i] = tanh(lptr[i]); + } +} + +template <> +void ReLU<float, lang::Cpp>(int count, const Blob *input, Blob *ret, + Context *ctx) { + float *dptr = static_cast<float *>(ret->mutable_data()); + const float *lptr = static_cast<const float *>(input->data()); + for (int i = 0; i < count; i++) { + dptr[i] = (lptr[i] >= 0.f) ? lptr[i] : 0.f; + } +} + +template <> +void Sigmoid<float, lang::Cpp>(int count, const Blob *input, Blob *ret, + Context *ctx) { + float *dptr = static_cast<float *>(ret->mutable_data()); + const float *lptr = static_cast<const float *>(input->data()); + for (int i = 0; i < count; i++) { + dptr[i] = 1.f / (1.f + exp(-lptr[i])); + } +} + +template <> +void Pow<float, lang::Cpp>(int count, const Blob *input, float x, Blob *ret, + Context *ctx) { + float *dptr = static_cast<float *>(ret->mutable_data()); + const float *lptr = static_cast<const float *>(input->data()); + for (int i = 0; i < count; i++) { + dptr[i] = pow(lptr[i], x); + } +} + +template <> +void Pow<float, lang::Cpp>(int count, const Blob *lhs, const Blob *rhs, + Blob *ret, Context *ctx) { float *dptr = static_cast<float *>(ret->mutable_data()); const float *lptr = static_cast<const float *>(lhs->data()); const float *rptr = static_cast<const float *>(rhs->data()); for (int i = 0; i < count; i++) { - dptr[i] = lptr[i] * rptr[i]; + dptr[i] = pow(lptr[i], rptr[i]); } } @@ -159,8 +214,15 @@ void Div<float, lang::Cpp>(const size_t num, const float alpha, const Blob *in, Blob *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); + for (size_t i = 0; i < num; i++) outPtr[i] = alpha / inPtr[i]; +} +template <> +void LT<float, lang::Cpp>(const size_t num, const Blob *in, const float x, + Blob *out, Context *ctx) { + float *outPtr = static_cast<float *>(out->mutable_data()); + const float *inPtr = static_cast<const float *>(in->data()); for (size_t i = 0; i < num; i++) { - outPtr[i] = alpha / inPtr[i]; + outPtr[i] = (inPtr[i] < x) ? 
1.f : 0.f; } } @@ -192,9 +254,38 @@ template <> void Set<float, lang::Cpp>(const size_t num, const float x, Blob *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); - for (size_t i = 0; i < num; i++) - outPtr[i] = x; + for (size_t i = 0; i < num; i++) outPtr[i] = x; +} +template <> +void LE<float, lang::Cpp>(const size_t num, const Blob *in, const float x, + Blob *out, Context *ctx) { + float *outPtr = static_cast<float *>(out->mutable_data()); + const float *inPtr = static_cast<const float *>(in->data()); + for (size_t i = 0; i < num; i++) { + outPtr[i] = (inPtr[i] <= x) ? 1.f : 0.f; + } +} + +template <> +void GT<float, lang::Cpp>(const size_t num, const Blob *in, const float x, + Blob *out, Context *ctx) { + float *outPtr = static_cast<float *>(out->mutable_data()); + const float *inPtr = static_cast<const float *>(in->data()); + for (size_t i = 0; i < num; i++) { + outPtr[i] = (inPtr[i] > x) ? 1.f : 0.f; + } +} + +template <> +void GE<float, lang::Cpp>(const size_t num, const Blob *in, const float x, + Blob *out, Context *ctx) { + float *outPtr = static_cast<float *>(out->mutable_data()); + const float *inPtr = static_cast<const float *>(in->data()); + for (size_t i = 0; i < num; i++) { + outPtr[i] = (inPtr[i] >= x) ? 1.f : 0.f; + } } + #ifdef USE_CBLAS template <> void Dot<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2, @@ -224,7 +315,6 @@ void GEMM<float, lang::Cpp>(const bool transA, const bool transB, #endif // USE_CBLAS - } // namespace singa #endif // SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/core/tensor/tensor_math_cuda.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h index 26299ba..4a2ba66 100644 --- a/src/core/tensor/tensor_math_cuda.h +++ b/src/core/tensor/tensor_math_cuda.h @@ -73,25 +73,6 @@ void Sum<float, lang::Cuda>(int count, const Blob *input, float *ret, cuda::sum(count, in, ret); } -// TODO(wangwei) optimize using stream -template <> -void SumRows<float, lang::Cuda>(int nrow, int ncol, const Blob *input, - Blob *ret, Context *ctx) { - float *dptr = static_cast<float *>(ret->mutable_data()); - const float *in = static_cast<const float *>(input->data()); - cuda::sum_row(nrow, ncol, ncol, in, dptr); -} - -// Sum the rows of the input matrix into a vector -// TODO(wangwei) optimize using stream -template <> -void SumColumns<float, lang::Cuda>(int nrow, int ncol, const Blob *input, - Blob *ret, Context *ctx) { - float *dptr = static_cast<float *>(ret->mutable_data()); - const float *in = static_cast<const float *>(input->data()); - cuda::sum_col(nrow, ncol, ncol, in, dptr); -} - // follow the consistency guide of math API template <> void Div<float, lang::Cuda>(const size_t num, const float alpha, const Blob *in, @@ -144,7 +125,42 @@ void GEMM<float, lang::Cuda>(const bool transA, const bool transB, CUBLAS_CHECK(cublasSgemm(handle, transb, transa, ncolB, nrowA, ncolA, &alpha, BPtr, ldb, APtr, lda, &beta, CPtr, ldc)); } + +template <> +void GE<float, lang::Cuda>(const size_t num, const Blob* in, const float x, + Blob* out, Context *ctx) { + float* outPtr = static_cast<float*>(out->mutable_data()); + const float* inPtr = static_cast<const float*>(in->data()); + cuda::GE(num, inPtr, x, outPtr, ctx->stream); +} +template <> +void GT<float, lang::Cuda>(const size_t num, const Blob* in, const float x, + Blob* out, Context *ctx) { + float* outPtr = 
static_cast<float*>(out->mutable_data()); + const float* inPtr = static_cast<const float*>(in->data()); + cuda::GT(num, inPtr, x, outPtr, ctx->stream); +} +template <> +void LE<float, lang::Cuda>(const size_t num, const Blob* in, const float x, + Blob* out, Context *ctx) { + float* outPtr = static_cast<float*>(out->mutable_data()); + const float* inPtr = static_cast<const float*>(in->data()); + cuda::LE(num, inPtr, x, outPtr, ctx->stream); +} +template <> +void LT<float, lang::Cuda>(const size_t num, const Blob* in, const float x, + Blob* out, Context *ctx) { + float* outPtr = static_cast<float*>(out->mutable_data()); + const float* inPtr = static_cast<const float*>(in->data()); + cuda::LT(num, inPtr, x, outPtr, ctx->stream); +} + + + + + } // namespace singa #endif // USE_CUDA #endif // SINGA_CORE_TENSOR_TENSOR_MATH_CUDA_H_ + http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/activation.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/activation.cc b/src/model/layer/activation.cc new file mode 100644 index 0000000..464e24d --- /dev/null +++ b/src/model/layer/activation.cc @@ -0,0 +1,67 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "singa/model/layer.h" +#include "./activation.h" +namespace singa { + +void Activation::Setup(const LayerConf& conf) { + Layer::Setup(conf); + mode_ = conf.type(); + if (mode_ == "RELU") { + neg_slope_ = conf.relu_conf().negative_slope(); + } +} + +const Tensor Activation::Forward(int flag, const Tensor& input) { + Tensor output; + if (mode_ == "SIGMOID") { + output = Sigmoid(input); + buf_.push(output); + } else if (mode_ == "TANH") { + output = Tanh(input); + buf_.push(output); + } else if (mode_ == "RELU") { + output = ReLU(input); + buf_.push(input); + } else { + LOG(FATAL) << "Unkown activation: " << mode_; + } + return output; +} + +const std::pair<Tensor, vector<Tensor>> Activation::Backward( + int flag, const Tensor& grad) { + vector<Tensor> param_grad; + // inout means either input or output, but only one is valid for an + // activation. 
+ Tensor input_grad, inout = buf_.top(); + buf_.pop(); + if (mode_ == "SIGMOID") { + input_grad = grad * inout * (inout * (-1.f) + 1.f); + } else if (mode_ == "TANH") { + input_grad = grad * (inout * inout * (-1.f) + 1.f); + } else if (mode_ == "RELU") { + input_grad = grad * (inout > 0.f) + (inout <= 0.f) * neg_slope_; + } else { + LOG(FATAL) << "Unkown activation: " << mode_; + } + return std::make_pair(input_grad, param_grad); +} + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/activation.h ---------------------------------------------------------------------- diff --git a/src/model/layer/activation.h b/src/model/layer/activation.h new file mode 100644 index 0000000..1747577 --- /dev/null +++ b/src/model/layer/activation.h @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SINGA_MODEL_LAYER_ACTIVATION_H_ +#define SINGA_MODEL_LAYER_ACTIVATION_H_ +#include <utility> +#include <string> +#include <vector> +#include "singa/model/layer.h" + +namespace singa { +class Activation : public Layer { + public: + /// \copydoc Layer::layer_type() + const std::string layer_type() const override { return "Activation"; } + + /// \copydoc Layer::Setup(const LayerConf&); + void Setup(const LayerConf& conf) override; + + /// \copydoc Layer::Forward(int flag, const Tensor&) + const Tensor Forward(int flag, const Tensor& input) override; + + /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&); + const std::pair<Tensor, vector<Tensor>> Backward(int flag, + const Tensor& grad) override; + + const std::string Mode() const { return mode_; } + + const float Negative_slope() const { return neg_slope_; } + + protected: + std::string mode_; + std::stack<Tensor> buf_; + float neg_slope_; +}; +} // namespace singa +#endif // SINGA_MODEL_LAYER_ACTIVATION_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/cudnn_activation.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_activation.cc b/src/model/layer/cudnn_activation.cc new file mode 100644 index 0000000..73c70d7 --- /dev/null +++ b/src/model/layer/cudnn_activation.cc @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "singa_config.h" +#ifdef USE_CUDNN +#include "./cudnn_activation.h" +#include <cudnn.h> + +#include "./cudnn_utils.h" +#include "singa/core/common.h" +#include "singa/utils/logging.h" + +namespace singa { +CudnnActivation::~CudnnActivation() { + if (acti_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyActivationDescriptor(acti_desc_)); + if (desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_)); +} + +void CudnnActivation::InitCudnn(size_t size, DataType dtype) { + CHECK(!has_init_cudnn_); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_)); + CUDNN_CHECK(cudnnCreateActivationDescriptor(&acti_desc_)); + + CUDNN_CHECK(cudnnSetTensor4dDescriptor( + desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, 1, 1, size)); + + if (mode_ == "SIGMOID") + cudnn_mode_ = CUDNN_ACTIVATION_SIGMOID; + else if (mode_ == "TANH") + cudnn_mode_ = CUDNN_ACTIVATION_TANH; + else if (mode_ == "RELU") + cudnn_mode_ = CUDNN_ACTIVATION_RELU; + else + LOG(FATAL) << "Unkown activation: " << mode_; + + nan_opt_ = CUDNN_PROPAGATE_NAN; + CUDNN_CHECK( + cudnnSetActivationDescriptor(acti_desc_, cudnn_mode_, nan_opt_, 0.0f)); + has_init_cudnn_ = true; +} + +const Tensor CudnnActivation::Forward(int flag, const Tensor& input) { + auto size = input.Size(); + DataType dtype = input.data_type(); + if (!has_init_cudnn_) { + InitCudnn(size, dtype); + } + Tensor output; + output.ResetLike(input); + output.device()->Exec([input, output, this](Context* ctx) { + Blob* inblob = input.blob(), * outblob = output.blob(); + float alpha = 1.0f, beta = 0.0f; +#if CUDNN_VERSION_MAJOR == 5 + CUDNN_CHECK(cudnnActivationForward( + ctx->cudnn_handle, this->acti_desc_, &alpha, this->desc_, + inblob->data(), &beta, this->desc_, outblob->mutable_data())); +#elif CUDNN_VERSION_MAJOR == 4 + CUDNN_CHECK(cudnnActivationForward_v4( + ctx->cudnn_handle, this->acti_desc_, &alpha, this->desc_, + inblob->data(), &beta, this->desc_, outblob->mutable_data())); +#endif + }, {input.blob()}, {output.blob()}); + if (cudnn_mode_ == CUDNN_ACTIVATION_SIGMOID || + cudnn_mode_ == CUDNN_ACTIVATION_TANH) { + buf_.push(output); + } else if (cudnn_mode_ == CUDNN_ACTIVATION_RELU) { + buf_.push(input); + } + return output; +} + +const std::pair<Tensor, vector<Tensor>> CudnnActivation::Backward( + int flag, const Tensor& grad) { + vector<Tensor> param_grad; + Tensor dx; // inout = buf_.top(); + // inout means either used as input or output, only one is valid for one type + // of activation + Tensor inout = buf_.top(); + buf_.pop(); + dx.ResetLike(grad); + dx.device()->Exec([dx, grad, inout, this](Context* ctx) { + Blob* dyblob = grad.blob(), * dxblob = dx.blob(), * yblob = inout.blob(), + * xblob = inout.blob(); + float alpha = 1.0f, beta = 0.0f; +#if CUDNN_VERSION_MAJOR == 5 + CUDNN_CHECK(cudnnActivationBackward( + ctx->cudnn_handle, this->acti_desc_, &alpha, this->desc_, yblob->data(), + this->desc_, dyblob->data(), this->desc_, xblob->data(), &beta, + this->desc_, dxblob->mutable_data())); +#elif CUDNN_VERSION_MAJOR == 4 + CUDNN_CHECK(cudnnActivationBackward_v4( + ctx->cudnn_handle, this->acti_desc_, &alpha, this->desc_, yblob->data(), + 
this->desc_, dyblob->data(), this->desc_, xblob->data(), &beta, + this->desc_, dxblob->mutable_data())); +#endif + }, {grad.blob(), inout.blob()}, {dx.blob()}); + return std::make_pair(dx, param_grad); +} +} // namespace singa +#endif // USE_CUDNN http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/cudnn_activation.h ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_activation.h b/src/model/layer/cudnn_activation.h new file mode 100644 index 0000000..b572db7 --- /dev/null +++ b/src/model/layer/cudnn_activation.h @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SINGA_MODEL_LAYER_CUDNN_ACTIVATION_H_ +#define SINGA_MODEL_LAYER_CUDNN_ACTIVATION_H_ +#include "singa_config.h" +#ifdef USE_CUDNN +#include <cudnn.h> +#include <utility> +#include <string> +#include <vector> + +#include "./activation.h" +#include "singa/core/common.h" +#include "singa/model/layer.h" +#include "singa/proto/core.pb.h" + +namespace singa { +class CudnnActivation : public Activation { + public: + ~CudnnActivation(); + /// \copydoc Layer::layer_type() + const std::string layer_type() const override { return "CudnnActivation"; } + + const Tensor Forward(int flag, const Tensor& input) override; + const std::pair<Tensor, vector<Tensor>> Backward(int flag, + const Tensor& grad) override; + + /// Init cudnn related data structures. + void InitCudnn(size_t size, DataType dtype); + + const cudnnActivationMode_t CudnnMode() const { return cudnn_mode_; } + + private: + bool has_init_cudnn_ = false; + cudnnActivationDescriptor_t acti_desc_; + cudnnTensorDescriptor_t desc_; + cudnnNanPropagation_t nan_opt_; + cudnnActivationMode_t cudnn_mode_; +}; +} // namespace +#endif // USE_CUDNN +#endif // SINGA_MODEL_LAYER_CUDNN_ACTIVATION_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/cudnn_softmax.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_softmax.cc b/src/model/layer/cudnn_softmax.cc new file mode 100644 index 0000000..bc7fe78 --- /dev/null +++ b/src/model/layer/cudnn_softmax.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "singa_config.h" +#include "./cudnn_softmax.h" +#ifdef USE_CUDNN +#include <cudnn.h> +#include "./cudnn_utils.h" +#include "singa/utils/logging.h" +namespace singa { +CudnnSoftmax::~CudnnSoftmax() { + if (desc_ != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_)); +} + +void CudnnSoftmax::InitCudnn(size_t size, DataType dtype) { + CHECK(!has_init_cudnn_); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_)); + + CUDNN_CHECK(cudnnSetTensor4dDescriptor( + desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, 1, 1, size)); + + algorithm_ = CUDNN_SOFTMAX_ACCURATE; + mode_ = CUDNN_SOFTMAX_MODE_INSTANCE; + has_init_cudnn_ = true; +} + +const Tensor CudnnSoftmax::Forward(int flag, const Tensor& input) { + auto size = input.Size(); + DataType dtype = input.data_type(); + if (!has_init_cudnn_) { + InitCudnn(size, dtype); + } + Tensor output; + output.ResetLike(input); + output.device()->Exec([input, output, this](Context* ctx) { + Blob* inblob = input.blob(), * outblob = output.blob(); + float alpha = 1.0f, beta = 0.0f; + cudnnSoftmaxForward(ctx->cudnn_handle, this->algorithm_, this->mode_, + &alpha, this->desc_, inblob->data(), &beta, this->desc_, + outblob->mutable_data()); + }, {input.blob()}, {output.blob()}); + buf_.push(output); + return output; +} + +const std::pair<Tensor, vector<Tensor>> CudnnSoftmax::Backward( + int flag, const Tensor& grad) { + vector<Tensor> param_grad; + Tensor dx, output = buf_.top(); + buf_.pop(); + dx.ResetLike(grad); + dx.device()->Exec([dx, grad, output, this](Context* ctx) { + Blob* dyblob = grad.blob(), * dxblob = dx.blob(), * yblob = output.blob(); + float alpha = 1.0f, beta = 0.0f; + cudnnSoftmaxBackward(ctx->cudnn_handle, this->algorithm_, this->mode_, + &alpha, this->desc_, yblob->data(), this->desc_, + dyblob->data(), &beta, this->desc_, + dxblob->mutable_data()); + }, {grad.blob(), output.blob()}, {dx.blob()}); + return std::make_pair(dx, param_grad); +} +} // namespace singa +#endif // USE_CUDNN http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/cudnn_softmax.h ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_softmax.h b/src/model/layer/cudnn_softmax.h new file mode 100644 index 0000000..ee92d6f --- /dev/null +++ b/src/model/layer/cudnn_softmax.h @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SINGA_MODEL_LAYER_CUDNN_SOFTMAX_H_ +#define SINGA_MODEL_LAYER_CUDNN_SOFTMAX_H_ +#include "singa_config.h" +#ifdef USE_CUDNN +#include <cudnn.h> +#include <utility> +#include <string> +#include <vector> + +#include "./softmax.h" +#include "singa/core/common.h" +#include "singa/model/layer.h" +#include "singa/proto/core.pb.h" + +namespace singa { +class CudnnSoftmax : public Softmax { + public: + ~CudnnSoftmax(); + /// \copydoc Layer::layer_type() + const std::string layer_type() const override { return "CudnnSoftmax"; } + + const Tensor Forward(int flag, const Tensor& input) override; + const std::pair<Tensor, vector<Tensor>> Backward(int flag, + const Tensor& grad) override; + + /// Init cudnn related data structures. + void InitCudnn(size_t size, DataType dtype); + + private: + bool has_init_cudnn_ = false; + cudnnTensorDescriptor_t desc_; + cudnnSoftmaxAlgorithm_t algorithm_; + cudnnSoftmaxMode_t mode_; +}; +} // namespace singa +#endif // USE_CUDNN +#endif // SINGA_MODEL_LAYER_CUDNN_SOFTMAX_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/softmax.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/softmax.cc b/src/model/layer/softmax.cc new file mode 100644 index 0000000..813ebf0 --- /dev/null +++ b/src/model/layer/softmax.cc @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "./softmax.h" +namespace singa { + +void Softmax::Setup(const LayerConf& conf) { + Layer::Setup(conf); + axis_ = conf.softmax_conf().axis(); // default is 1 +} + +const Tensor Softmax::Forward(int flag, const Tensor& input) { + if (input.nDim() == 1) { + Tensor tmp = Reshape(input, Shape{1, input.Size()}); + buf_.push(SoftMax(tmp, 0)); + } else { + buf_.push(SoftMax(input, axis_)); + } + return buf_.top(); +} + +const std::pair<Tensor, vector<Tensor>> Softmax::Backward(int flag, + const Tensor& grad) { + size_t nrow = 1, ncol = grad.Size(); + if (grad.nDim() > 1 && axis_ > 0) { + nrow = Product(grad.shape(), 0, axis_); + ncol = Product(grad.shape(), axis_, grad.nDim()); + } + Tensor input_grad = grad.Clone(); + input_grad.Reshape(Shape{nrow, ncol}); + Tensor y = buf_.top(); + buf_.pop(); + CHECK(y.shape() == input_grad.shape()); + Tensor sigma = input_grad * y; + Tensor sum(Shape{nrow}, grad.device(), grad.data_type()); + SumColumns(sigma, &sum); + // dL / dy_j = grad_j + // dy_j / dx_i = y_i - y_i * y_i, if i == j + // dy_j / dx_i = -y_i * y_j, if i != j + // dL / dx_i = sum_j((dL / dy_j) * (dy_j / dx_i)) + // = y_i * (grad_i - sum), where sum = sum_j(grad_j * y_j) + SubColumn(sum, &input_grad); + input_grad = input_grad * y; + // Mult(input_grad, y, &input_grad); + vector<Tensor> param_grad; + return std::make_pair(input_grad, param_grad); +} + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/src/model/layer/softmax.h ---------------------------------------------------------------------- diff --git a/src/model/layer/softmax.h b/src/model/layer/softmax.h new file mode 100644 index 0000000..ea3a70a --- /dev/null +++ b/src/model/layer/softmax.h @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef SINGA_MODEL_LAYER_SOFTMAX_H_ +#define SINGA_MODEL_LAYER_SOFTMAX_H_ +#include "singa/model/layer.h" +#include <stack> +namespace singa { +class Softmax : public Layer { + public: + /// \copydoc Layer::layer_type() + const std::string layer_type() const override { return "Softmax"; } + + /// \copydoc Layer::Setup(const LayerConf&); + void Setup(const LayerConf& conf) override; + + /// \copydoc Layer::Forward(int flag, const Tensor&) + const Tensor Forward(int flag, const Tensor& input) override; + + /// \copydoc Layer::Backward(int flag, const Tensor&) + const std::pair<Tensor, vector<Tensor>> Backward(int flag, + const Tensor& grad) override; + + const int Axis() const { return axis_; } + + protected: + int axis_; + std::stack<Tensor> buf_; +}; +} // namespace singa +#endif // SINGA_MODEL_LAYER_SOFTMAX_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/test/singa/test_activation.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_activation.cc b/test/singa/test_activation.cc new file mode 100644 index 0000000..9e34282 --- /dev/null +++ b/test/singa/test_activation.cc @@ -0,0 +1,133 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License.
+* +*************************************************************/ + +#include "../src/model/layer/activation.h" +#include "gtest/gtest.h" +#include <math.h> // exp, tanh + +using singa::Activation; +TEST(Activation, Setup) { + Activation acti; + EXPECT_EQ("Activation", acti.layer_type()); + + singa::LayerConf conf; + conf.set_type("RELU"); + singa::ReLUConf* reluconf = conf.mutable_relu_conf(); + reluconf->set_negative_slope(0.5); + + acti.Setup(conf); + EXPECT_EQ("RELU", acti.Mode()); + EXPECT_EQ(0.5f, acti.Negative_slope()); +} + +TEST(Activation, Forward) { + const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -4.0}; + size_t n = sizeof(x) / sizeof(float); + singa::Tensor in(singa::Shape{n}); + in.CopyDataFromHostPtr<float>(x, n); + + float neg_slope = 0.5f; + std::string types[] = {"SIGMOID", "TANH", "RELU"}; + for (int j = 0; j < 3; j++) { + Activation acti; + singa::LayerConf conf; + std::string layertype = types[j]; + conf.set_type(layertype); + if (layertype == "RELU") { + singa::ReLUConf* reluconf = conf.mutable_relu_conf(); + reluconf->set_negative_slope(neg_slope); + } + acti.Setup(conf); + + singa::Tensor out = acti.Forward(0, in); + + const float* yptr = out.data<const float*>(); + EXPECT_EQ(n, out.Size()); + + float* y = new float[n]; + if (acti.Mode() == "SIGMOID") { + for (size_t i = 0; i < n; i++) + y[i] = 1.f / (1.f + exp(-x[i])); + } + else if (acti.Mode() == "TANH") { + for (size_t i = 0; i < n; i++) + y[i] = tanh(x[i]); + } + else if (acti.Mode() == "RELU") { + for (size_t i = 0; i < n; i++) + y[i] = (x[i] >= 0.f) ? x[i] : 0.f; + } + else + LOG(FATAL) << "Unknown activation: " << acti.Mode(); + EXPECT_FLOAT_EQ(y[0], yptr[0]); + EXPECT_FLOAT_EQ(y[4], yptr[4]); + EXPECT_FLOAT_EQ(y[5], yptr[5]); + } +} + +TEST(Activation, Backward) { + const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -4.0}; + size_t n = sizeof(x) / sizeof(float); + singa::Tensor in(singa::Shape{n}); + in.CopyDataFromHostPtr<float>(x, n); + + float neg_slope = 0.5f; + std::string types[] = {"SIGMOID", "TANH", "RELU"}; + for (int j = 0; j < 3; j++) { + Activation acti; + singa::LayerConf conf; + std::string layertype = types[j]; + conf.set_type(layertype); + if (layertype == "RELU") { + singa::ReLUConf* reluconf = conf.mutable_relu_conf(); + reluconf->set_negative_slope(neg_slope); + } + acti.Setup(conf); + + singa::Tensor out = acti.Forward(0, in); + const float* yptr = out.data<const float*>(); + + const float grad[] = {2.0f, -3.0f, 1.0f, 3.0f, -1.0f, -2.0}; + singa::Tensor out_diff(singa::Shape{n}); + out_diff.CopyDataFromHostPtr<float>(grad, n); + const auto in_diff = acti.Backward(0, out_diff); + const float* xptr = in_diff.first.data<const float*>(); + + float* dx = new float[n]; + if (acti.Mode() == "SIGMOID") { + for (size_t i = 0; i < n; i++) + dx[i] = grad[i] * yptr[i] * (1.
- yptr[i]); + } + else if (acti.Mode() == "TANH") { + for (size_t i = 0; i < n; i++) + dx[i] = grad[i] * (1 - yptr[i] * yptr[i]); + } + else if (acti.Mode() == "RELU") { + for (size_t i = 0; i < n; i++) + dx[i] = grad[i] * (x[i] > 0.f) + acti.Negative_slope() * (x[i] <= 0.f); + } + else + LOG(FATAL) << "Unknown activation: " << acti.Mode(); + EXPECT_FLOAT_EQ(dx[0], xptr[0]); + EXPECT_FLOAT_EQ(dx[4], xptr[4]); + EXPECT_FLOAT_EQ(dx[5], xptr[5]); + } +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/test/singa/test_cudnn_activation.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_cudnn_activation.cc b/test/singa/test_cudnn_activation.cc new file mode 100644 index 0000000..ee9f9b5 --- /dev/null +++ b/test/singa/test_cudnn_activation.cc @@ -0,0 +1,136 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ +#include "singa_config.h" +#ifdef USE_CUDNN + +#include "singa/proto/core.pb.h" +#include "../src/model/layer/cudnn_activation.h" +#include "gtest/gtest.h" +#include <math.h> // exp, tanh +#include <cudnn.h> + +using singa::CudnnActivation; +TEST(TCudnnActivation, Setup) { + CudnnActivation acti; + EXPECT_EQ("CudnnActivation", acti.layer_type()); + + singa::LayerConf conf; + conf.set_type("RELU"); + singa::ReLUConf* reluconf = conf.mutable_relu_conf(); + reluconf->set_negative_slope(0.5f); + + acti.Setup(conf); + acti.InitCudnn(1, singa::kFloat32); + EXPECT_EQ(CUDNN_ACTIVATION_RELU, acti.CudnnMode()); + EXPECT_EQ(0.5f, acti.Negative_slope()); +} + +TEST(TCudnnActivation, Forward) { + const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -4.0}; + size_t n = sizeof(x) / sizeof(float); + singa::CudaGPU cuda(0, 1); + singa::Tensor in(singa::Shape{n}, &cuda); + in.CopyDataFromHostPtr<float>(x, n); + + float neg_slope = 0.5f; + std::string types[] = {"SIGMOID", "TANH", "RELU"}; + for (int j = 0; j < 3; j++) { + CudnnActivation acti; + singa::LayerConf conf; + std::string layertype = types[j]; + conf.set_type(layertype); + if (layertype == "RELU") { + singa::ReLUConf* reluconf = conf.mutable_relu_conf(); + reluconf->set_negative_slope(neg_slope); + } + acti.Setup(conf); + // acti.InitCudnn(n, singa::kFloat32); + + singa::Tensor out = acti.Forward(0, in); + EXPECT_EQ(n, out.Size()); + singa::CppCPU host(0, 1); + out.ToDevice(&host); + const float* yptr = out.data<const float*>(); + float* y = new float[n]; + if (acti.Mode() == "SIGMOID") { + for (size_t i = 0; i < n; i++) y[i] = 1.f / (1.f + exp(-x[i])); + } else if (acti.Mode() == "TANH") { + for (size_t i = 0; i < n; i++) y[i] = tanh(x[i]); + } else if (acti.Mode() == "RELU") { + for (size_t i = 0; i < n; i++) y[i] =
(x[i] >= 0.f) ? x[i] : 0.f; + } else + LOG(FATAL) << "Unknown activation: " << acti.Mode(); + EXPECT_FLOAT_EQ(y[0], yptr[0]); + EXPECT_FLOAT_EQ(y[4], yptr[4]); + EXPECT_FLOAT_EQ(y[5], yptr[5]); + } +} + +TEST(TCudnnActivation, Backward) { + const float x[] = {2.0f, 3.0f, 3.0f, 7.f, 0.0f, 5.0, 1.5, 2.5, -2.5, 1.5}; + size_t n = sizeof(x) / sizeof(float); + singa::CudaGPU cuda(0, 1); + singa::Tensor in(singa::Shape{n}, &cuda); + in.CopyDataFromHostPtr<float>(x, n); + float neg_slope = 0.5f; + std::string types[] = {"SIGMOID", "TANH", "RELU"}; + for (int j = 0; j < 3; j++) { + CudnnActivation acti; + singa::LayerConf conf; + std::string layertype = types[j]; + conf.set_type(layertype); + if (layertype == "RELU") { + singa::ReLUConf* reluconf = conf.mutable_relu_conf(); + reluconf->set_negative_slope(neg_slope); + } + acti.Setup(conf); + acti.InitCudnn(n, singa::kFloat32); + singa::Tensor out = acti.Forward(0, in); + EXPECT_EQ(n, out.Size()); + singa::CppCPU host(0, 1); + out.ToDevice(&host); + const float* yptr = out.data<const float*>(); + + const float grad[] = {2.0f, 1.0f, 2.0f, 0.0f, -2.0f, + -1.0, 1.5, 2.5, -1.5, -2.5}; + singa::Tensor out_diff(singa::Shape{n}, &cuda); + out_diff.CopyDataFromHostPtr<float>(grad, n); + const auto ret = acti.Backward(0, out_diff); + singa::Tensor in_diff = ret.first; + in_diff.ToDevice(&host); + const float* xptr = in_diff.data<const float*>(); + float* dx = new float[n]; + if (acti.Mode() == "SIGMOID") { + for (size_t i = 0; i < n; i++) dx[i] = grad[i] * yptr[i] * (1. - yptr[i]); + } else if (acti.Mode() == "TANH") { + for (size_t i = 0; i < n; i++) dx[i] = grad[i] * (1. - yptr[i] * yptr[i]); + } else if (acti.Mode() == "RELU") { + for (size_t i = 0; i < n; i++) + dx[i] = + grad[i] * (x[i] > 0.f); //+ acti.Negative_slope() * (x[i] <= 0.f); + } else + LOG(FATAL) << "Unknown activation: " << acti.Mode(); + for (size_t i = 0; i < n; i++) { + EXPECT_NEAR(dx[i], xptr[i], 1e-7); + } + } +} +#endif // USE_CUDNN http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/test/singa/test_cudnn_dropout.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_cudnn_dropout.cc b/test/singa/test_cudnn_dropout.cc index e1a6333..32572d0 100644 --- a/test/singa/test_cudnn_dropout.cc +++ b/test/singa/test_cudnn_dropout.cc @@ -21,7 +21,7 @@ #include "../src/model/layer/cudnn_dropout.h" #ifdef USE_CUDNN // cudnn dropout is added in cudnn 5 -#if CUDNN_MAJOR_VERSION >= 5 +#if CUDNN_VERSION_MAJOR >= 5 #include "gtest/gtest.h" http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/test/singa/test_cudnn_softmax.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_cudnn_softmax.cc b/test/singa/test_cudnn_softmax.cc new file mode 100644 index 0000000..dcbf1ed --- /dev/null +++ b/test/singa/test_cudnn_softmax.cc @@ -0,0 +1,107 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License.
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ +#include "singa_config.h" +#ifdef USE_CUDNN + +#include "../src/model/layer/cudnn_softmax.h" +#include "gtest/gtest.h" +#include <math.h> // exp +#include <cudnn.h> + +using singa::CudnnSoftmax; +TEST(CudnnSoftmax, Setup) { + CudnnSoftmax sft; + EXPECT_EQ("CudnnSoftmax", sft.layer_type()); + + singa::LayerConf conf; + singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf(); + softmaxconf->set_axis(2); + + sft.Setup(conf); + sft.InitCudnn(1, singa::kFloat32); + EXPECT_EQ(2, sft.Axis()); +} + +TEST(CudnnSoftmax, Forward) { + const float x[] = {1.0f, 2.0f, 0.0f, -2.0f, -3.0f, -1.0}; + size_t n = sizeof(x) / sizeof(float); + singa::CudaGPU cuda(0, 1); + singa::Tensor in(singa::Shape{n}, &cuda); + in.CopyDataFromHostPtr<float>(x, n); + + int axis = 1; + CudnnSoftmax sft; + singa::LayerConf conf; + singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf(); + softmaxconf->set_axis(axis); + sft.Setup(conf); + sft.InitCudnn(n, singa::kFloat32); + + singa::Tensor out = sft.Forward(0, in); + singa::CppCPU host(0, 1); + out.ToDevice(&host); + const float* yptr = out.data<const float*>(); + EXPECT_EQ(n, out.Size()); + + float* y = new float[n]; + float sigma = 0.f; + for (size_t i = 0; i < n; i++) sigma += exp(x[i]); + for (size_t i = 0; i < n; i++) y[i] = exp(x[i]) / sigma; + EXPECT_FLOAT_EQ(y[0], yptr[0]); + EXPECT_FLOAT_EQ(y[4], yptr[4]); + EXPECT_FLOAT_EQ(y[5], yptr[5]); +} + +TEST(CudnnSoftmax, Backward) { + const float x[] = {1.0f, 2.0f, 3.0f, -2.0f, -3.0f, -1.0}; + size_t n = sizeof(x) / sizeof(float); + singa::CudaGPU cuda(0, 1); + singa::Tensor in(singa::Shape{n}, &cuda); + in.CopyDataFromHostPtr<float>(x, n); + + int axis = 1; + CudnnSoftmax sft; + singa::LayerConf conf; + singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf(); + softmaxconf->set_axis(axis); + sft.Setup(conf); + singa::Tensor out = sft.Forward(0, in); + singa::CppCPU host(0, 1); + out.ToDevice(&host); + const float* yptr = out.data<const float*>(); + + const float grad[] = {2.0f, -3.0f, 1.0f, 3.0f, -1.0f, -2.0}; + singa::Tensor out_diff(singa::Shape{n}, &cuda); + out_diff.CopyDataFromHostPtr<float>(grad, n); + const auto ret = sft.Backward(0, out_diff); + singa::Tensor in_diff = ret.first; + in_diff.ToDevice(&host); + const float* xptr = in_diff.data<const float*>(); + + float* dx = new float[n]; + float sigma = 0.f; + for (size_t i = 0; i < n; i++) sigma += grad[i] * yptr[i]; + for (size_t i = 0; i < n; i++) dx[i] = (grad[i] - sigma) * yptr[i]; + EXPECT_FLOAT_EQ(dx[0], xptr[0]); + EXPECT_FLOAT_EQ(dx[4], xptr[4]); + EXPECT_FLOAT_EQ(dx[5], xptr[5]); +} +#endif // USE_CUDNN http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2507b7/test/singa/test_softmax.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_softmax.cc b/test/singa/test_softmax.cc new file mode 100644 index 0000000..da2a6ef --- /dev/null +++ b/test/singa/test_softmax.cc @@ -0,0 +1,110 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under 
one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ + +#include "../src/model/layer/softmax.h" +#include "gtest/gtest.h" +#include <math.h> // exp + +using singa::Softmax; +TEST(Softmax, Setup) { + Softmax sft; + EXPECT_EQ("Softmax", sft.layer_type()); + + singa::LayerConf conf; + singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf(); + softmaxconf->set_axis(2); + + sft.Setup(conf); + EXPECT_EQ(2, sft.Axis()); +} + +TEST(Softmax, Forward) { + const float x[] = {1.0f, 2.0f, 0.0f, -2.0f, -3.0f, -1.0}; + size_t n = sizeof(x) / sizeof(float); + size_t row = 2; + size_t col = 3; + singa::Tensor in(singa::Shape{row, col}); + in.CopyDataFromHostPtr<float>(x, n); + + int axis = 1; + Softmax sft; + singa::LayerConf conf; + singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf(); + softmaxconf->set_axis(axis); + sft.Setup(conf); + + singa::Tensor out = sft.Forward(0, in); + const float* yptr = out.data<const float*>(); + EXPECT_EQ(n, out.Size()); + + float* y = new float[n]; + float* sigma = new float[row]; + for (size_t i = 0; i < row; i++) + sigma[i] = 0.f; + for (size_t i = 0; i < n; i++) + sigma[i / col] += exp(x[i]); + //EXPECT_EQ(0, sigma[1]); + for (size_t i = 0; i < row; i++) + for (size_t j = 0; j < col; j++) + y[i * col + j] = exp(x[i * col + j]) / sigma[i]; + EXPECT_FLOAT_EQ(y[0], yptr[0]); + EXPECT_FLOAT_EQ(y[4], yptr[4]); + EXPECT_FLOAT_EQ(y[5], yptr[5]); +} + +TEST(Softmax, Backward) { + const float x[] = {1.0f, 2.0f, 0.0f, -2.0f, -3.0f, -1.0}; + size_t n = sizeof(x) / sizeof(float); + size_t row = 2; + size_t col = 3; + singa::Tensor in(singa::Shape{row, col}); + in.CopyDataFromHostPtr<float>(x, n); + + int axis = 1; + Softmax sft; + singa::LayerConf conf; + singa::SoftmaxConf* softmaxconf = conf.mutable_softmax_conf(); + softmaxconf->set_axis(axis); + sft.Setup(conf); + singa::Tensor out = sft.Forward(0, in); + const float* yptr = out.data<const float*>(); + + const float grad[] = {2.0f, -3.0f, 1.0f, 3.0f, -1.0f, -2.0}; + singa::Tensor out_diff(singa::Shape{row, col}); + out_diff.CopyDataFromHostPtr<float>(grad, n); + const auto in_diff = sft.Backward(0, out_diff); + const float* xptr = in_diff.first.data<const float*>(); + + float* dx = new float[n]; + float* sigma = new float[row]; + for (size_t i = 0; i < row; i++) + sigma[i] = 0.f; + for (size_t i = 0; i < n; i++) + sigma[i / col] += grad[i] * yptr[i]; + // EXPECT_EQ(0, sigma[0]); + // EXPECT_EQ(0, sigma[1]); + for (size_t i = 0; i < row; i++) + for (size_t j = 0; j < col; j++) + dx[i * col + j] = (grad[i * col + j] - sigma[i]) * yptr[i * col + j]; + EXPECT_FLOAT_EQ(dx[0], xptr[0]); + EXPECT_FLOAT_EQ(dx[4], xptr[4]); + EXPECT_FLOAT_EQ(dx[5], xptr[5]); +}
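
For reference, the closed form used in Softmax::Backward above, dL/dx_i = y_i * (grad_i - sum_j(grad_j * y_j)), is exactly the softmax Jacobian applied to the output gradient; it lets the layer avoid materializing the full Jacobian. The snippet below is a minimal standalone sketch (plain C++, no SINGA dependency, not part of this commit) that checks the closed form against the explicit Jacobian product for one row:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax of one row.
static std::vector<double> softmax(const std::vector<double>& x) {
  double m = *std::max_element(x.begin(), x.end());  // subtract max for stability
  std::vector<double> y(x.size());
  double sum = 0.0;
  for (size_t i = 0; i < x.size(); i++) {
    y[i] = std::exp(x[i] - m);
    sum += y[i];
  }
  for (double& v : y) v /= sum;
  return y;
}

int main() {
  const std::vector<double> x = {1.0, 2.0, 0.0};   // inputs
  const std::vector<double> g = {2.0, -3.0, 1.0};  // dL/dy
  const std::vector<double> y = softmax(x);

  // sum_j(grad_j * y_j), the quantity SumColumns(sigma, &sum) computes per row.
  double s = 0.0;
  for (size_t j = 0; j < y.size(); j++) s += g[j] * y[j];

  for (size_t i = 0; i < y.size(); i++) {
    // Closed form: dx_i = y_i * (g_i - s).
    double closed = y[i] * (g[i] - s);
    // Explicit Jacobian product: dx_i = sum_j g_j * (delta_ij * y_i - y_i * y_j).
    double jac = 0.0;
    for (size_t j = 0; j < y.size(); j++)
      jac += g[j] * ((i == j ? y[i] : 0.0) - y[i] * y[j]);
    std::printf("i=%zu closed=%.6f jacobian=%.6f\n", i, closed, jac);
  }
  return 0;
}

The two printed columns should agree to floating-point precision for every i, which is the identity the CPU layer and the test_softmax.cc expectations both rely on.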
