SINGA-80 New Blob Level and Address Level Math Operation Interface
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/4728f7ce Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/4728f7ce Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/4728f7ce Branch: refs/heads/master Commit: 4728f7ce4fb4e1e019690600550c039748de50a7 Parents: c13e037 Author: seaok <seao...@gmail.com> Authored: Tue Nov 3 10:45:17 2015 +0800 Committer: Wei Wang <wang...@comp.nus.edu.sg> Committed: Mon Nov 9 17:04:48 2015 +0800 ---------------------------------------------------------------------- include/singa/blob/math_addr.h | 34 +++- include/singa/blob/math_blob.h | 86 ++++---- include/singa/blob/math_kernel.h | 49 ++++- include/singa/blob/singa_op.h | 92 ++++++++- src/blob/math_blob.cc | 4 +- src/blob/math_kernel.cu | 371 +++++++++++++++++++++++++++++++++- src/test/test_math.cc | 128 +++++++++++- 7 files changed, 708 insertions(+), 56 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4728f7ce/include/singa/blob/math_addr.h ---------------------------------------------------------------------- diff --git a/include/singa/blob/math_addr.h b/include/singa/blob/math_addr.h index a6663ab..4895343 100644 --- a/include/singa/blob/math_addr.h +++ b/include/singa/blob/math_addr.h @@ -79,18 +79,42 @@ template<typename Op> void cpu_expand_f(const float * A,const int m, const int n void gpu_gemm(const float * A, const float * B, const int m, const int n, const int k, const float alpha, const float beta, const bool TranA, const bool TranB, float * C); void gpu_gemv(const float * A, const float * B, const int m, const int n, const float alpha, const float beta, const bool TranA, float * C); void gpu_axpy(const float * A, const int n, const float alpha, float * B); +float gpu_dot(const float * A, const float * B, const int n); //element-wise -template<typename Op> 
void gpu_e_f(const int n, const float alpha, float * A); -template<typename Op> void gpu_e_f(const int n,const float * A,const float alpha, const float beta,float * B); -template<typename Op> void gpu_e_f(const int n,const float * A,const float * B,const float alpha, const float beta,float * C); +template<typename Op> void gpu_e_f(const int n, const float alpha, float * A) +{ + Op::CudaMap(alpha, A, n); +} + +template<typename Op> void gpu_e_f(const int n,const float * A,const float alpha, float * B) +{ + Op::CudaMap(alpha, A, B, n); +} + +template<typename Op> void gpu_e_f(const int n,const float * A,const float * B,const float alpha, const float beta,float * C) +{ + Op::CudaMap(alpha, beta, A, B, C, n); +} // element-wise generalized operation defined in Op //matrix/vector expand/reduce -template<typename Op> void gpu_reduce_f(const float * A,const int m, const int n, float * B); +template<typename Op> void gpu_reduce_f(const float * A,const int m, const int n, float * B) +{ + for(int i = 0 ; i < m ; i++) + { + Op::CudaMap(A+i*n, n, B[i]); + } +} //reduce each row of A to an element of B e.g. 
the sum operation in softmax -template<typename Op> void gpu_expand_f(const float * A,const int m, const int n, float * B); +template<typename Op> void gpu_expand_f(const float * A,const int m, const int n, float * B) +{ + for(int i = 0 ; i < m ; i++) + { + Op::CudaMap(A[i], n, B+i*n); + } +} //expand each element in A into a row of B http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4728f7ce/include/singa/blob/math_blob.h ---------------------------------------------------------------------- diff --git a/include/singa/blob/math_blob.h b/include/singa/blob/math_blob.h index d5991a7..ee0fb60 100644 --- a/include/singa/blob/math_blob.h +++ b/include/singa/blob/math_blob.h @@ -207,46 +207,52 @@ void E_Func(XPU xpu, Blob<float> * A, float alpha) if(xpu == gpu) { //gpu part + int n = get_size(A->shape()); + gpu_e_f<Op>(n, alpha, A->mutable_gpu_data()); } } template<typename Op> void E_Func(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha) { - if(xpu == cpu) + if(check_shape_equal(A, *B, *B)) { - if(check_shape_equal(A, *B, *B)) - { int n = get_size(A.shape()); - cpu_e_f<Op>(n, A.cpu_data(), alpha, B->mutable_cpu_data()); - } - else{ - // report errors here - } - } - if(xpu == gpu) - { - //gpu part + if(xpu == cpu) + { + cpu_e_f<Op>(n, A.cpu_data(), alpha, B->mutable_cpu_data()); + } + + if(xpu == gpu) + { + //gpu part + gpu_e_f<Op>(n, A.gpu_data(), alpha, B->mutable_gpu_data()); + } } + else{ + // report errors here + } } template<typename Op> void E_Func(XPU xpu, const Blob<float> & A, const Blob<float> & B, Blob<float> * C, float alpha, float beta) { - if(xpu == cpu) + if(check_shape_equal(A, B, *C)) { - if(check_shape_equal(A, B, *C)) + int n = get_size(A.shape()); + + if(xpu == cpu) { - int n = get_size(A.shape()); cpu_e_f<Op>(n, A.cpu_data(), B.cpu_data(), alpha, beta, C->mutable_cpu_data()); } - else{ - // report errors here + if(xpu == gpu) + { + //gpu part + gpu_e_f<Op>(n, A.gpu_data(), B.gpu_data(), alpha, beta, C->mutable_gpu_data()); } } 
- if(xpu == gpu) - { - //gpu part + else{ + // report errors here } } @@ -394,21 +400,23 @@ void Bernoulli(XPU xpu, Blob & A, float p, int n = 1); template<typename Op> void Reduce_F(XPU xpu, const Blob<float> & A, Blob<float> * B) { - if(xpu == cpu) + if(check_shape_mv(A, *B)) { - if(check_shape_mv(A, *B)) + int m = get_size(B->shape()); + int n = get_size(A.shape()) / m; + + if(xpu == cpu) { - int m = get_size(B->shape()); - int n = get_size(A.shape()) / m; cpu_reduce_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data()); } - else{ - // report errors here + if(xpu == gpu) + { + //gpu part + gpu_reduce_f<Op>(A.gpu_data(), m, n, B->mutable_gpu_data()); } } - if(xpu == gpu) - { - //gpu part + else{ + // report errors here } } //reduce each row of A to an element of B e.g. the sum operation in softmax @@ -416,21 +424,23 @@ void Reduce_F(XPU xpu, const Blob<float> & A, Blob<float> * B) template<typename Op> void Expand_F(XPU xpu, const Blob<float> & A, Blob<float> * B) { - if(xpu == cpu) + if(check_shape_mv(*B, A)) { - if(check_shape_mv(*B, A)) + int m = get_size(A.shape()); + int n = get_size(B->shape()) / m; + + if(xpu == cpu) { - int m = get_size(A.shape()); - int n = get_size(B->shape()) / m; cpu_expand_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data()); } - else{ - // report errors here + if(xpu == gpu) + { + //gpu part + gpu_expand_f<Op>(A.gpu_data(), m, n, B->mutable_gpu_data()); } } - if(xpu == gpu) - { - //gpu part + else{ + // report errors here } } //expand each element in A into a row of B http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4728f7ce/include/singa/blob/math_kernel.h ---------------------------------------------------------------------- diff --git a/include/singa/blob/math_kernel.h b/include/singa/blob/math_kernel.h index 9aaf4c2..f5d3e34 100644 --- a/include/singa/blob/math_kernel.h +++ b/include/singa/blob/math_kernel.h @@ -4,9 +4,54 @@ namespace singa{ extern "C" { - void singa_sum_col(float *src_mat_data, float *dst_vec_data, long rows, 
long cols, long stride); + void singa_gpu_sum_vec(float *data, float *sum , long n); + + void singa_gpu_sum_col(const float *src_mat_data, float *dst_vec_data, long rows, long cols, long stride); + + void singa_gpu_add_vec_row(const float *src_vec_data, const float *src_mat_data, float *des_mat_data, long rows, long cols, long stride); + + void singa_gpu_set_value(float *data, float value, long n); + + void singa_gpu_scale(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_scale_grad(float *data, float alpha, long n); + + void singa_gpu_exp(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_exp_grad(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_sigmoid(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_sigmoid_grad(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_relu(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_relu_grad(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_tanh(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_tanh_grad(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_softplus(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_softplus_grad(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_square(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_square_grad(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_sqrt(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_threshold(const float *src_data, float *des_data, float alpha, long n); + + void singa_gpu_add(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n); + + void singa_gpu_sub(const float *src_data_a, const float 
*src_data_b, float *des_data, float alpha, float beta, long n); + + void singa_gpu_mult(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n); + + void singa_gpu_div(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n); - void singa_add_vec_row(float *src_vec_data, float *src_mat_data, float *des_mat_data, long rows, long cols, long stride); }; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4728f7ce/include/singa/blob/singa_op.h ---------------------------------------------------------------------- diff --git a/include/singa/blob/singa_op.h b/include/singa/blob/singa_op.h index b36c001..33ef4f8 100644 --- a/include/singa/blob/singa_op.h +++ b/include/singa/blob/singa_op.h @@ -3,6 +3,9 @@ #include<cmath> #include <algorithm> +#include <cuda_runtime.h> +#include "cublas_v2.h" +#include "singa/blob/math_kernel.h" namespace singa { enum XPU { cpu, gpu, any}; @@ -14,28 +17,45 @@ namespace singa_op { inline static void Map(float alpha, float & a) { a= alpha; } + inline static void CudaMap(float alpha, float * a, int n) { + singa::singa_gpu_set_value(a, alpha, n); + } }; struct Scale { inline static void Map(float alpha, const float & a, float & b) { - b = alpha*a; + b = alpha* a; + } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_scale(a,b,alpha,n); } }; + struct Scale_grad { inline static void Map(float alpha, float & output) { output = alpha; } + inline static void CudaMap(float alpha, float *output, int n) { + singa::singa_gpu_scale_grad(output,alpha,n); + } }; struct Exp { inline static void Map(float alpha, const float & a, float & b) { b = pow(a, alpha); } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_exp(a,b,alpha,n); + } }; + struct Exp_grad { inline static void Map(float alpha, const float & a, float & b) { - b = a * log(alpha); // log is the natrual log 
based on e + b = a * log(alpha); + } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_exp_grad(a,b,alpha,n); } }; @@ -43,92 +63,144 @@ namespace singa_op { inline static void Map(float alpha, const float & a, float & b) { b = 1.0f / (1.0f + expf(-a * alpha)); } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_sigmoid(a,b,alpha,n); + } }; + struct Gsigmoid_grad { inline static void Map(float alpha, const float & a, float & b) { b = alpha * a * ( 1.0f - a ); } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_sigmoid_grad(a,b,alpha,n); + } }; struct Grelu { inline static void Map(float alpha, const float & a, float & b) { b = ( 1 - alpha ) * std::max( a, 0.0f ) + alpha * a; } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_relu(a,b,alpha,n); + } }; + struct Grelu_grad { inline static void Map(float alpha, const float & a, float & b) { b = a > 0.0f ? 
1.0f : alpha; } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_relu_grad(a,b,alpha,n); + } }; struct Gtanh { inline static void Map(float alpha, const float & a, float & b) { b = tanhf( a * alpha ); } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_tanh(a,b,alpha,n); + } }; + struct Gtanh_grad { inline static void Map(float alpha, const float & a, float & b) { b = alpha * ( 1.0f - a * a ); } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_tanh_grad(a,b,alpha,n); + } }; struct Softplus { inline static void Map(float alpha, const float & a, float & b) { b = logf(1 + expf(a)); } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_softplus(a,b,alpha,n); + } }; + struct Softplus_grad { inline static void Map(float alpha, const float & a, float & b) { b = 1.0f / (1.0f + expf(-a)); } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_softplus_grad(a,b,alpha,n); + } }; struct Square { inline static void Map(float alpha, const float & a, float & b) { b = a * a; } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_square(a,b,alpha,n); + } }; struct Square_grad { inline static void Map(float alpha, const float & a, float & b) { b = 2 * sqrt(a); } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_square_grad(a,b,alpha,n); + } }; struct Sqrt { inline static void Map(float alpha, const float & a, float & b) { b = sqrt(a); } + inline static void CudaMap(float alpha, const float * a, float * b, int n) { + singa::singa_gpu_sqrt(a,b,alpha,n); + } }; struct Threshold { inline static void Map(float alpha, const float & a, float & b) { b = a < alpha ? 
1.0f : 0.0f;
     }
+    inline static void CudaMap(float alpha, const float * a, float * b, int n) {
+      singa::singa_gpu_threshold(a,b,alpha,n);
+    }
   };
 
   struct Add {
     inline static void Map(float alpha, float beta, const float & a, const float & b, float & c) {
       c = a + b;
     }
+    inline static void CudaMap(float alpha, float beta, const float * a, const float * b, float *c, int n) {
+      singa::singa_gpu_add(a,b,c,alpha,beta,n);
+    }
   };
 
   struct Sub {
     inline static void Map(float alpha, float beta, const float & a, const float & b, float & c) {
       c = a - b;
     }
+    inline static void CudaMap(float alpha, float beta, const float * a, const float * b, float *c, int n) {
+      singa::singa_gpu_sub(a,b,c,alpha,beta,n);
+    }
   };
 
   struct Mult {
     inline static void Map(float alpha, float beta, const float & a, const float & b, float & c) {
       c = a * b;
     }
+    inline static void CudaMap(float alpha, float beta, const float * a, const float * b, float *c, int n) {
+      singa::singa_gpu_mult(a,b,c,alpha,beta,n);
+    }
   };
 
   struct Div {
     inline static void Map(float alpha, float beta, const float & a, const float & b, float & c) {
       c = a / b;
     }
+    inline static void CudaMap(float alpha, float beta, const float * a, const float * b, float *c, int n) {
+      singa::singa_gpu_div(a,b,c,alpha,beta,n);
+    }
   };
 
   struct Sum {
@@ -139,6 +211,16 @@ namespace singa_op {
         b += a[i];
       }
     }
+
+    inline static void CudaMap(const float * a, int n, float & b) {
+      // b may alias device memory (gpu_reduce_f passes B[i] of a device array), so it is written only via cudaMemcpy.
+      float *sum = NULL;
+      cudaMalloc((void**)&sum, sizeof(float));  // kernel_sum_vec writes exactly one float
+      singa::singa_gpu_sum_vec(const_cast<float *>(a), sum, n);  // a is only read; the extern "C" API lacks const
+      // cudaMemcpyDefault lets UVA infer the direction, so b may live in device or host memory.
+      cudaMemcpy(&b, sum, sizeof(float), cudaMemcpyDefault);
+      cudaFree(sum);
+    }
   };
 
   struct Expand_Div {
@@ -148,6 +230,13 @@ namespace singa_op {
         b[i] /= a;
       }
     }
+    inline static void CudaMap(const float & a, int n, float * b) {
+      // a may alias device memory (gpu_expand_f passes A[i] of a device array),
+      // so fetch it with cudaMemcpy instead of dereferencing it on the host.
+      float divisor = 0.0f;
+      cudaMemcpy(&divisor, &a, sizeof(float), cudaMemcpyDefault);
+      singa::singa_gpu_scale(b, b, 1.0f / divisor, n);  // b[i] /= a, as in Map (was b[i] *= a)
+    }
   };
 
   struct Repmat {
@@ -157,6 +246,12 @@ namespace singa_op {
         b[i] = a;
       }
     }
+    inline static void CudaMap(const float & a, int n, float * b) {
+      // a may alias device memory; fetch it before using it as the fill value.
+      float value = 0.0f;
+      cudaMemcpy(&value, &a, sizeof(float), cudaMemcpyDefault);
+      singa::singa_gpu_set_value(b, value, n);
+    }
   };
 
 }; // namespace op
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4728f7ce/src/blob/math_blob.cc ---------------------------------------------------------------------- diff --git a/src/blob/math_blob.cc b/src/blob/math_blob.cc index 9421367..bd0e6ee 100644 --- a/src/blob/math_blob.cc +++ b/src/blob/math_blob.cc @@ -166,7 +166,7 @@ void MVAdd(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha, float b if(xpu == gpu) { - singa_add_vec_row(B->gpu_data(),A.gpu_data(),A.gpu_data(),m,n,n); + singa_gpu_add_vec_row(B->gpu_data(),A.gpu_data(),A.gpu_data(),m,n,n); //gpu part } } @@ -192,7 +192,7 @@ void MVSum(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha, float b } if(xpu == gpu) { - singa_sum_col(A.gpu_data(),B->gpu_data(),m,n,n); + singa_gpu_sum_col(A.gpu_data(),B->gpu_data(),m,n,n); //gpu part } } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4728f7ce/src/blob/math_kernel.cu ---------------------------------------------------------------------- diff --git a/src/blob/math_kernel.cu b/src/blob/math_kernel.cu index 6b2a709..4c828d5 100644 --- a/src/blob/math_kernel.cu +++ b/src/blob/math_kernel.cu @@ -1,3 +1,4 @@ +#include <cmath> #include "singa/blob/math_kernel.h" #define CU2DBLOCK_X 32 @@ -8,8 +9,42 @@ //Cuda Kernel Functions + __global__ -void kernel_sum_col(float *src_mat_data, float *dst_vec_data, long rows, long cols, long stride) +void kernel_sum_vec(float *data, float *sum , long n) +{ + int THREADS = blockDim.x; + + __shared__ float aux[CU1DBLOCK]; + int steps = (n - 1) / THREADS + 1; + aux[threadIdx.x] = data[threadIdx.x]; + + for(int i=1; i<steps; ++i) { + if(threadIdx.x+i*THREADS < n) { + aux[threadIdx.x] += data[threadIdx.x+i*THREADS]; + } + } + + int total_threads = THREADS; + __syncthreads(); + + while(total_threads > 1) { + int half_point = ((1+total_threads) >> 1); + if (threadIdx.x < half_point) { + if(threadIdx.x+half_point < total_threads) { + aux[threadIdx.x] += aux[threadIdx.x + half_point]; + } + } + __syncthreads(); + 
total_threads = ((total_threads+1) >> 1);
+  }
+
+  __syncthreads();
+  if (threadIdx.x == 0) *sum = aux[0];  // single writer; aux[0] holds the block total after the halving loop
+}
+
+__global__
+void kernel_sum_col(const float *src_mat_data, float *dst_vec_data, long rows, long cols, long stride)
 {
   int j = blockIdx.x;
   int THREADS = blockDim.x;
@@ -44,7 +79,7 @@ void kernel_sum_col(float *src_mat_data, float *dst_vec_data, long rows, long co
 }
 
 __global__
-void kernel_add_vec_row(float *src_vec_data, float *src_mat_data, float* des_mat_data,long rows, long cols, long stride)
+void kernel_add_vec_row(const float *src_vec_data, const float *src_mat_data, float* des_mat_data,long rows, long cols, long stride)
 {
   long i = blockIdx.x * blockDim.x + threadIdx.x;
   long j = blockIdx.y * blockDim.y + threadIdx.y;
@@ -57,10 +92,230 @@ void kernel_add_vec_row(float *src_vec_data, float *src_mat_data, float* des_mat
   }
 }
 
+__global__ static
+void kernel_set_value(float *data, float value, long n)
+{
+  long index = blockIdx.x * blockDim.x + threadIdx.x;
+  long num_threads = blockDim.x * gridDim.x;
+  for(; index<n; index+=num_threads) {
+    data[index] = value;
+  }
+}
+
+__global__
+void kernel_scale(const float *src_data, float *des_data, float alpha, long n)
+{
+  long index = blockIdx.x * blockDim.x + threadIdx.x;
+  long num_threads = blockDim.x * gridDim.x;
+  for(; index<n; index+=num_threads) {
+    des_data[index] = src_data[index] * alpha;
+  }
+}
+
+__global__
+void kernel_scale_grad(float *data, float alpha, long n)
+{
+  long index = blockIdx.x * blockDim.x + threadIdx.x;
+  long num_threads = blockDim.x * gridDim.x;
+  for(; index<n; index+=num_threads) {
+    data[index] = alpha;
+  }
+}
+
+__global__
+void kernel_exp(const float *src_data, float *des_data, float alpha, long n)
+{
+  long index = blockIdx.x * blockDim.x + threadIdx.x;
+  long num_threads = blockDim.x * gridDim.x;
+  for(; index<n; index+=num_threads) {
+    des_data[index] = powf(src_data[index], alpha);  // same formula as singa_op::Exp::Map (CPU): pow(a, alpha); input was wrongly negated
+  }
+}
+
+__global__
+void kernel_exp_grad(const float *src_data, float *des_data, float alpha, long n)
+{
+  long index = blockIdx.x * blockDim.x + threadIdx.x;
+  long num_threads = blockDim.x * gridDim.x;
+  for(; index<n; index+=num_threads) {
+    des_data[index] = src_data[index] * log(alpha);
+  }
+}
+
+__global__
+void kernel_sigmoid(const float *src_data, float *des_data, float alpha, long n)
+{
+  long index = blockIdx.x * blockDim.x + threadIdx.x;
+  long num_threads = blockDim.x * gridDim.x;
+  for(; index<n; index+=num_threads) {
+    des_data[index] = 1.0f / (1.0f + expf(-src_data[index] * alpha));  // same formula as singa_op::Gsigmoid::Map: alpha scales the input, not exp()
+  }
+}
+
+__global__
+void kernel_sigmoid_grad(const float *src_data, float *des_data, float alpha, long n)
+{
+  long index = blockIdx.x * blockDim.x + threadIdx.x;
+  long num_threads = blockDim.x * gridDim.x;
+  for(; index<n; index+=num_threads) {
+    des_data[index] = src_data[index] * (1.0f - src_data[index]) * alpha;
+  }
+}
+
+__global__
+void kernel_relu(const float *src_data, float *des_data, float alpha, long n)
+{
+  long index = blockIdx.x * blockDim.x + threadIdx.x;
+  long num_threads = blockDim.x * gridDim.x;
+  for(; index<n; index+=num_threads) {
+    des_data[index] = ( 1 - alpha ) * max( src_data[index], 0.0f ) + alpha * src_data[index];  // same formula as singa_op::Grelu::Map (slope alpha for negative inputs)
+  }
+}
+
+__global__
+void kernel_relu_grad(const float *src_data, float *des_data, float alpha, long n)
+{
+  long index = blockIdx.x * blockDim.x + threadIdx.x;
+  long num_threads = blockDim.x * gridDim.x;
+  for(; index<n; index+=num_threads) {
+    des_data[index] = src_data[index] > 0.0f ?
1.0f : alpha; + } +} + + +__global__ +void kernel_tanh(const float *src_data, float *des_data, float alpha, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = tanhf( src_data[index] * alpha ); + } +} + +__global__ +void kernel_tanh_grad(const float *src_data, float *des_data, float alpha, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = alpha * (1.0f - src_data[index] * src_data[index] ); + } +} + +__global__ +void kernel_softplus(const float *src_data, float *des_data, float alpha, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = logf(1 + expf(src_data[index])); + } +} + +__global__ +void kernel_softplus_grad(const float *src_data, float *des_data, float alpha, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = 1.0f / (1.0f + expf(-src_data[index])); + } +} + +__global__ +void kernel_square(const float *src_data, float *des_data, float alpha, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = src_data[index] * src_data[index]; + } +} + +__global__ +void kernel_square_grad(const float *src_data, float *des_data, float alpha, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = 2 * sqrt(src_data[index]); + } +} + +__global__ +void kernel_sqrt(const float *src_data, float *des_data, float alpha, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long 
num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = sqrt(src_data[index]); + } +} + +__global__ +void kernel_threshold(const float *src_data, float *des_data, float alpha, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = src_data[index] < alpha ? 1.0f : 0.0f; + } +} + +__global__ +void kernel_add(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = src_data_a[index] + src_data_b[index]; + } +} + +__global__ +void kernel_sub(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = src_data_a[index] - src_data_b[index]; + } +} + +__global__ +void kernel_mult(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = src_data_a[index] * src_data_b[index]; + } +} + +__global__ +void kernel_div(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n) +{ + long index = blockIdx.x * blockDim.x + threadIdx.x; + long num_threads = blockDim.x * gridDim.x; + for(; index<n; index+=num_threads) { + des_data[index] = src_data_a[index] / src_data_b[index]; + } +} + // namespace singa{ -void singa_sum_col(float *src_mat_data, float *dst_vec_data, long rows, long cols, long stride) +void singa_gpu_sum_vec(float *data, float *sum , long n) +{ + long threads_per_block = n > 
CU1DBLOCK ? CU1DBLOCK : n; + // here, we only need one block + long num_blocks = 1; + + kernel_sum_vec<<<num_blocks, threads_per_block>>>(data, sum, n); +} + +void singa_gpu_sum_col(const float *src_mat_data, float *dst_vec_data, long rows, long cols, long stride) { long threads_per_block = rows > CU1DBLOCK ? CU1DBLOCK : rows; long num_blocks = cols; @@ -68,11 +323,117 @@ void singa_sum_col(float *src_mat_data, float *dst_vec_data, long rows, long col kernel_sum_col<<<num_blocks, threads_per_block>>>(src_mat_data, dst_vec_data, rows, cols, stride); } -void singa_add_vec_row(float *src_vec_data, float *src_mat_data, float *des_mat_data ,long rows, long cols, long stride) +void singa_gpu_add_vec_row(const float *src_vec_data, const float *src_mat_data, float *des_mat_data ,long rows, long cols, long stride) { dim3 threads_per_block(CU2DBLOCK_X, CU2DBLOCK_Y); dim3 num_blocks(cols/threads_per_block.x + (cols%threads_per_block.x == 0 ? 0 : 1), rows/threads_per_block.y + (rows%threads_per_block.y == 0 ? 
0 : 1)); kernel_add_vec_row<<<num_blocks, threads_per_block>>>(src_vec_data, src_mat_data, des_mat_data,rows, cols, stride); } -}//namespace singa +void singa_gpu_set_value(float *data, float value, long n) +{ + kernel_set_value<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(data, value, n); +} + +void singa_gpu_scale(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_scale<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_scale_grad(float *data, float alpha, long n) +{ + kernel_scale_grad<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(data, alpha, n); +} + +void singa_gpu_exp(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_exp<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_exp_grad(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_exp_grad<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_sigmoid(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_sigmoid<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_sigmoid_grad(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_sigmoid_grad<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_relu(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_relu<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_relu_grad(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_relu_grad<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_tanh(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_tanh<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_tanh_grad(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_tanh_grad<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, 
des_data, alpha, n); +} + +void singa_gpu_softplus(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_softplus<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_softplus_grad(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_softplus_grad<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_square(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_square<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_square_grad(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_square_grad<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_sqrt(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_sqrt<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_threshold(const float *src_data, float *des_data, float alpha, long n) +{ + kernel_threshold<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data, des_data, alpha, n); +} + +void singa_gpu_add(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n) +{ + kernel_add<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data_a, src_data_b, des_data, alpha, beta, n); +} + +void singa_gpu_sub(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n) +{ + kernel_sub<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data_a, src_data_b, des_data, alpha, beta, n); +} + +void singa_gpu_mult(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n) +{ + kernel_mult<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data_a, src_data_b, des_data, alpha, beta, n); +} + +void singa_gpu_div(const float *src_data_a, const float *src_data_b, float *des_data, float alpha, float beta, long n) +{ + kernel_div<<<ceil(n/CU1DBLOCKF), CU1DBLOCKF>>>(src_data_a, src_data_b, des_data, 
alpha, beta, n); +} + + +}//namespace singa_gpu http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/4728f7ce/src/test/test_math.cc ---------------------------------------------------------------------- diff --git a/src/test/test_math.cc b/src/test/test_math.cc index 3856d1d..a8a9490 100644 --- a/src/test/test_math.cc +++ b/src/test/test_math.cc @@ -100,6 +100,42 @@ TEST(MathTest, TestAxpyCPU) { } } +TEST(MathTest, TestEopCPU) { + + float A[10] = {}; + float B[10] = {}; + float C[10] = {}; + float D[10] = {}; + float O[10] = {}; + + for(int i = 0; i < 10; i++) + { + A[i] = i; + B[i] = -i; + C[i] = i; + + } + + cpu_e_f<singa_op::Set>(5, 15, O); + for(int i = 0; i < 5; i++) + { + ASSERT_EQ(O[i]-15,0); + } + for(int i = 5; i < 10; i++) + { + ASSERT_EQ(O[i],0); + } + cpu_e_f<singa_op::Scale>(10, C, 2, C); + for(int i = 0; i < 10; i++) + { + ASSERT_EQ(C[i]-2*i,0); + } + cpu_e_f<singa_op::Add>(10, A, B, 0, 0, O); + for(int i = 0; i < 10; i++) + { + ASSERT_EQ(O[i],0); + } +} TEST(MathTest, TestGemmGPU) { float A[3][2] = {}; @@ -314,7 +350,7 @@ TEST(MathTest, TestSingaSumColGPU) { cudaMalloc((void**)&B_gpu, 4*sizeof(float)); cudaMemcpy(A_gpu,A,12*sizeof(float),cudaMemcpyHostToDevice); - singa_sum_col(A_gpu,B_gpu,3,4,4); + singa_gpu_sum_col(A_gpu,B_gpu,3,4,4); cudaMemcpy(B,B_gpu,4*sizeof(float),cudaMemcpyDeviceToHost); @@ -367,7 +403,7 @@ TEST(MathTest, TestSingaAddVecRowGPU) { cudaMemcpy(A_gpu,A,3*4*sizeof(float),cudaMemcpyHostToDevice); cudaMemcpy(B_gpu,B,4*sizeof(float),cudaMemcpyHostToDevice); - singa_add_vec_row(B_gpu,A_gpu,C_gpu,3,4,4); + singa_gpu_add_vec_row(B_gpu,A_gpu,C_gpu,3,4,4); cudaMemcpy(C,C_gpu,3*4*sizeof(float),cudaMemcpyDeviceToHost); @@ -383,3 +419,91 @@ TEST(MathTest, TestSingaAddVecRowGPU) { cudaFree(B_gpu); cudaFree(C_gpu); } + + +TEST(MathTest, TestSingaSetValueGPU) { + + float A[3][4]; + + float* A_gpu=NULL; + float* B_gpu=NULL; + + cudaMalloc((void**)&A_gpu, 3*4*sizeof(float)); + + cudaMemcpy(A_gpu,A,3*4*sizeof(float),cudaMemcpyHostToDevice); 
+ + singa_gpu_set_value(A_gpu,4.0,3*4); + + cudaMemcpy(A,A_gpu,3*4*sizeof(float),cudaMemcpyDeviceToHost); + + for(int i = 0; i < 3; i++) + { + for(int j = 0; j < 4; j++) + { + ASSERT_EQ(A[i][j],4.0f); + } + } + + cudaFree(A_gpu); +} + + +TEST(MathTest, TestEopGPU) { + + float A[10] = {}; + float B[10] = {}; + float C[10] = {}; + float D[10] = {}; + float O[10] = {}; + + for(int i = 0; i < 10; i++) + { + A[i] = i; + B[i] = -i; + C[i] = i; + O[i] = 0.0f; + + } + + float* A_gpu=NULL; + float* B_gpu=NULL; + float* C_gpu=NULL; + float* O_gpu=NULL; + + cudaMalloc((void**)&A_gpu, 10*sizeof(float)); + cudaMalloc((void**)&B_gpu, 10*sizeof(float)); + cudaMalloc((void**)&C_gpu, 10*sizeof(float)); + cudaMalloc((void**)&O_gpu, 10*sizeof(float)); + + cudaMemcpy(A_gpu,A,10*sizeof(float),cudaMemcpyHostToDevice); + cudaMemcpy(B_gpu,B,10*sizeof(float),cudaMemcpyHostToDevice); + cudaMemcpy(C_gpu,C,10*sizeof(float),cudaMemcpyHostToDevice); + cudaMemcpy(O_gpu,O,10*sizeof(float),cudaMemcpyHostToDevice); + + gpu_e_f<singa_op::Set>(5, 15, O_gpu); + cudaMemcpy(O,O_gpu,10*sizeof(float),cudaMemcpyDeviceToHost); + + for(int i = 0; i < 5; i++) + { + ASSERT_EQ(O[i]-15,0); + } + for(int i = 5; i < 10; i++) + { + ASSERT_EQ(O[i],0); + } + gpu_e_f<singa_op::Scale>(10, C_gpu, 2, C_gpu); + cudaMemcpy(C,C_gpu,10*sizeof(float),cudaMemcpyDeviceToHost); + + for(int i = 0; i < 10; i++) + { + ASSERT_EQ(C[i]-2*i,0); + } + + gpu_e_f<singa_op::Add>(10, A_gpu, B_gpu, 0, 0, O_gpu); + cudaMemcpy(O,O_gpu,10*sizeof(float),cudaMemcpyDeviceToHost); + + for(int i = 0; i < 10; i++) + { + ASSERT_EQ(O[i],0); + } +}