SINGA-192 Implement optimization algorithms for v1; merge branch PR#164 into dev
Fix the bugs in the AdaGrad and RMSProp tests. Note: EXPECT_NEAR (with a tolerance of 1e-5) is used instead of exact comparison to avoid spurious failures from floating-point rounding. The tests still need to be run on more machines.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/5784bff3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/5784bff3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/5784bff3

Branch: refs/heads/dev
Commit: 5784bff3e5ebfb3a992624d10f03f30cd5e520a3
Parents: 6d69047 178db01
Author: Wei Wang <[email protected]>
Authored: Sun Jun 12 15:43:53 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Sun Jun 12 18:03:12 2016 +0800

----------------------------------------------------------------------
 include/singa/model/optimizer.h |  43 ++++++++++++++
 src/core/tensor/math_kernel.cu  |  14 ++---
 src/core/tensor/math_kernel.h   |   2 +-
 src/core/tensor/tensor.cc       |   3 +-
 src/model/optimizer/adagrad.cc  |  36 ++++++++++++
 src/model/optimizer/nesterov.cc |  43 ++++++++++++++
 src/model/optimizer/rmsprop.cc  |  41 ++++++++++++++
 src/proto/model.proto           |   3 +
 test/singa/test_adagrad.cc      |  96 +++++++++++++++++++++++++++++++
 test/singa/test_nesterov.cc     | 101 +++++++++++++++++++++++++++++++++
 test/singa/test_rmsprop.cc      | 106 +++++++++++++++++++++++++++++++++++
 11 files changed, 478 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
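Note that src/model/optimizer/nesterov.cc and test/singa/test_nesterov.cc appear in the diffstat above, but their hunks are not reproduced in this message. For reference only, here is a scalar sketch of textbook Nesterov momentum in the lookahead-free reformulation several frameworks use; this is the standard update, not necessarily the exact code in nesterov.cc, and all names are hypothetical:

    #include <vector>

    // Textbook Nesterov momentum, reformulated so no explicit lookahead
    // point is needed: v_t = mu * v_{t-1} + lr * g, and the parameter
    // step applied is (1 + mu) * v_t - mu * v_{t-1}.
    void NesterovStep(float lr, float mu, const std::vector<float>& grad,
                      std::vector<float>& vel, std::vector<float>& value) {
      for (size_t i = 0; i < grad.size(); ++i) {
        float prev = vel[i];                  // v_{t-1}
        vel[i] = mu * prev + lr * grad[i];    // v_t
        value[i] -= (1.0f + mu) * vel[i] - mu * prev;
      }
    }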
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/core/tensor/math_kernel.cu
----------------------------------------------------------------------
diff --cc src/core/tensor/math_kernel.cu
index b618f9b,aed6add..484868a
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@@ -236,192 -300,151 +236,192 @@@ __global__ void KernelThreshold(const s
}
}
- __global__ void KernelGE(const int num, const float *in, const float x,
-__global__ void kernel_div(const float *src_data_a, const float *src_data_b,
- float *des_data, int n) {
- int index = blockIdx.x * blockDim.x + threadIdx.x;
- int num_threads = blockDim.x * gridDim.x;
- for (; index < n; index += num_threads) {
- des_data[index] = src_data_a[index] / src_data_b[index];
++__global__ void KernelGE(const size_t num, const float *in, const float x,
+ float *out) {
+ for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+ idx += blockDim.x * gridDim.x) {
+ out[idx] = in[idx] >= x ? 1.0f : 0.0f;
}
}
- __global__ void KernelGT(const int num, const float *in, const float x,
-
-__global__ static void kernel_set_value(float *data, float value, int n) {
- int index = blockIdx.x * blockDim.x + threadIdx.x;
- int num_threads = blockDim.x * gridDim.x;
- for (; index < n; index += num_threads) {
- data[index] = value;
++__global__ void KernelGT(const size_t num, const float *in, const float x,
+ float *out) {
+ for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+ idx += blockDim.x * gridDim.x) {
+ out[idx] = in[idx] > x ? 1.0f : 0.0f;
}
}
- __global__ void KernelLE(const int num, const float *in, const float x,
-
-__global__ void kernel_threshold(const float *src_data, float *des_data,
- float alpha, int n) {
- int index = blockIdx.x * blockDim.x + threadIdx.x;
- int num_threads = blockDim.x * gridDim.x;
- for (; index < n; index += num_threads) {
- des_data[index] = src_data[index] < alpha ? 1.0f : 0.0f;
++__global__ void KernelLE(const size_t num, const float *in, const float x,
+ float *out) {
+ for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+ idx += blockDim.x * gridDim.x) {
+ out[idx] = in[idx] <= x ? 1.0f : 0.0f;
}
}
-void sum(int n, const float *in, float *out) {
- int threads_per_block = n > CU1DBLOCK ? CU1DBLOCK : n;
- // here, we only need one block
- int num_blocks = 1;
- __global__ void KernelLT(const int num, const float *in, const float x,
- kernel_sum_vec << <num_blocks, threads_per_block>>> (in, out, n);
++__global__ void KernelLT(const size_t num, const float *in, const float x,
+ float *out) {
+ for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+ idx += blockDim.x * gridDim.x) {
+ out[idx] = in[idx] < x ? 1.0f : 0.0f;
+ }
+ }
-void sum_row(int rows, int cols, int stride, const float *in, float *out) {
- int threads_per_block = rows > CU1DBLOCK ? CU1DBLOCK : rows;
- int num_blocks = cols;
+// ********************************
+// Functions call kernels
+// ********************************
- kernel_sum_row << <num_blocks, threads_per_block>>>
- (in, out, rows, cols, stride);
+void set(const size_t n, const float v, float *out, cudaStream_t s) {
+ KernelSet <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, v, out);
}
-void sum_col(int rows, int cols, int stride, const float *in, float *out) {
- int threads_per_block = cols > CU1DBLOCK ? CU1DBLOCK : cols;
- int num_blocks = rows;
+void abs(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelAbs <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+}
- kernel_sum_col << <num_blocks, threads_per_block>>>
- (in, out, rows, cols, stride);
+void sign(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelSign <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
}
-void add_row(int rows, int cols, int stride, const float *in_row,
- const float *in_mat, float *out) {
- dim3 threads_per_block(CU2DBLOCK_X, CU2DBLOCK_Y);
- dim3 num_blocks(
- cols / threads_per_block.x + (cols % threads_per_block.x == 0 ? 0 : 1),
- rows / threads_per_block.y + (rows % threads_per_block.y == 0 ? 0 : 1));
- kernel_add_vec_row << <num_blocks, threads_per_block>>>
- (in_row, in_mat, out, rows, cols, stride);
+
+void exp(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
}
-void add(int n, const float *a, const float *b, float *out) {
- kernel_add << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
+
+void log(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
}
-void sub(int n, const float *a, const float *b, float *out) {
- kernel_sub << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
+
+void sqrt(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelSqrt <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
}
-void exp(int n, const float *in, float *out) {
- kernel_exp << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+
+void square(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelSquare <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
}
-void log(int n, const float *in, float *out) {
- kernel_log << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void tanh(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelTanh <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
}
-void sigmoid(int n, const float *in, float *out) {
- kernel_sigmoid << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void relu(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelRelu <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+}
- void sigmoid(const int n, const float *in, float *out, cudaStream_t s) {
++void sigmoid(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelSigmoid <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+}
+void softplus(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelSoftplus <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, out);
+}
+void clamp(const size_t n, const float low, const float high, const float *in,
+ float *out, cudaStream_t s) {
+ KernelClamp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, low, high, in, out);
}
-void sigmoid_grad(int n, const float *in, float *out) {
- kernel_sigmoid_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void pow(const size_t n, const float *in, const float x, float *out,
+ cudaStream_t s) {
+ KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
}
-void relu(int n, const float *in, float *out) {
- kernel_relu << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void add(const size_t n, const float *in, const float x, float *out,
+ cudaStream_t s) {
+ KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
}
-void relu_grad(int n, const float *in, float *out) {
- kernel_relu_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void mult(const size_t n, const float *in, const float x, float *out,
+ cudaStream_t s) {
+ KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in, x, out);
}
-void tanh(int n, const float *in, float *out) {
- kernel_tanh << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void div(const size_t n, const float x, const float *in, float *out,
+ cudaStream_t s) {
+ KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, x, in, out);
}
-void tanh_grad(int n, const float *in, float *out) {
- kernel_tanh_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void threshold(const size_t n, const float x, const float *in, float *out,
+ cudaStream_t s) {
+ KernelThreshold <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, x, in, out);
}
-void softplus(int n, const float *in, float *out) {
- kernel_softplus << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void gt(const size_t num, const float *in, const float x, float *out,
+ cudaStream_t s) {
+ KernelGT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+}
+void ge(const size_t num, const float *in, const float x, float *out,
+ cudaStream_t s) {
+ KernelGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+}
+void lt(const size_t num, const float *in, const float x, float *out,
+ cudaStream_t s) {
+ KernelLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+}
+void le(const size_t num, const float *in, const float x, float *out,
+ cudaStream_t s) {
+ KernelLE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
}
-void softplus_grad(int n, const float *in, float *out) {
- kernel_softplus_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void pow(const size_t n, const float *in1, const float *in2, float *out,
+ cudaStream_t s) {
+ KernelPow <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
}
-void square(int n, const float *in, float *out) {
- kernel_square << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void add(const size_t n, const float *in1, const float *in2, float *out,
+ cudaStream_t s) {
+ KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
}
-void square_grad(int n, const float *in, float *out) {
- kernel_square_grad << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void sub(const size_t n, const float *in1, const float *in2, float *out,
+ cudaStream_t s) {
+ KernelSub <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
}
-void sqrt(int n, const float *in, float *out) {
- kernel_sqrt << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
+void mult(const size_t n, const float *in1, const float *in2, float *out,
+ cudaStream_t s) {
+ KernelMult <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
}
-void pow(int n, const float *a, const float *b, float *out) {
- kernel_pow << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
+void div(const size_t n, const float *in1, const float *in2, float *out,
+ cudaStream_t s) {
+ KernelDiv <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, in1, in2, out);
}
-void mult(int n, const float *a, const float *b, float *out) {
- kernel_mult << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
+void sum(const size_t n, const float *in, float *out, cudaStream_t s) {
+ int threads_per_block = n > CU1DBLOCK ? CU1DBLOCK : n;
+ // here, we only need one block
+ int num_blocks = 1;
+ KernelSum <<<num_blocks, threads_per_block>>> (n, in, out);
+}
+/*
+void square_grad(int n, const float *in, float *out, cudaStream_t s) {
+ kernel_square_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
}
-void mult(int n, const float *a, const float x, float *out) {
- kernel_mult << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, x, out, n);
+void tanh_grad(int n, const float *in, float *out, cudaStream_t s) {
+ kernel_tanh_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
}
-void div(int n, const float *a, const float *b, float *out) {
- kernel_div << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (a, b, out, n);
+
+void relu_grad(int n, const float *in, float *out, cudaStream_t s) {
+ kernel_relu_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
}
-void set_value(int n, float v, float *out) {
- kernel_set_value << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (out, v, n);
+
+void sigmoid_grad(int n, const float *in, float *out, cudaStream_t s) {
+ kernel_sigmoid_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
}
-void threshold(int n, float alpha, const float *in, float *out) {
- kernel_threshold << <ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, alpha, n);
+void softplus_grad(int n, const float *in, float *out, cudaStream_t s) {
+ kernel_softplus_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n);
}
-// follow the consistency guide for math API
-__global__ void KernelDiv(const size_t num, const float alpha, const float *in,
- float *out) {
- for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
- idx += blockDim.x * gridDim.x) {
- out[idx] = alpha / in[idx];
+
+__global__ void kernel_sum_col(const float *src_mat_data, float *dst_vec_data,
+ int rows, int cols, int stride) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int num_threads = blockDim.x * gridDim.x;
+ for (; index < rows; index += num_threads) {
+ dst_vec_data[index] = 0.0f;
+ for (int k = 0; k < cols; k++) {
+ dst_vec_data[index] += src_mat_data[index * stride + k];
+ }
+ }
}
}
@@@ -485,62 -485,30 +485,62 @@@ __global__ void kernel_sigmoid_grad(con
}
}
-void Set(const size_t num, const float x, float *out, cudaStream_t s) {
- KernelSet << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, x, out);
+
+__global__ void kernel_relu_grad(const float *src_data, float *des_data,
+ int n) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int num_threads = blockDim.x * gridDim.x;
+ for (; index < n; index += num_threads) {
+ des_data[index] = src_data[index] > 0.0f ? 1.0f : 0.0f;
+ }
}
-void Div(const size_t num, float alpha, const float *in, float *out,
- cudaStream_t s) {
- KernelDiv << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, alpha, in, out);
+
+__global__ void kernel_tanh_grad(const float *src_data, float *des_data,
+ int n) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int num_threads = blockDim.x * gridDim.x;
+ for (; index < n; index += num_threads) {
+ des_data[index] = (1.0f - src_data[index] * src_data[index]);
+ }
}
-void GT(const size_t num, const float *in, const float x, float *out,
- cudaStream_t s) {
- KernelGT << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+
+__global__ void kernel_softplus_grad(const float *src_data, float *des_data,
+ int n) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int num_threads = blockDim.x * gridDim.x;
+ for (; index < n; index += num_threads) {
+ des_data[index] = 1.0f / (1.0f + expf(-src_data[index]));
+ }
}
-void GE(const size_t num, const float *in, const float x, float *out,
- cudaStream_t s) {
- KernelGE << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
+__global__ void KernelSquareGrad(const float *src_data, float *des_data,
+ int n) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int num_threads = blockDim.x * gridDim.x;
+ for (; index < n; index += num_threads) {
+ des_data[index] = 2 * src_data[index];
+ }
}
- __global__ void kernel_softmax_loss(const float *prob, const int *label,
-void LT(const size_t num, const float *in, const float x, float *out,
- cudaStream_t s) {
- KernelLT << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
++__global__ void kernel_softmax_loss(const float *prob, const size_t *label,
+ float *loss, int n, int dim) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int num_threads = blockDim.x * gridDim.x;
+ for (; index < n; index += num_threads) {
+ float prob_of_truth = prob[index * dim + label[index]];
+ loss[index] -= std::log(max(prob_of_truth, FLT_MIN));
+ }
}
- __global__ void kernel_softmax_gradient(float *grad, const int *label, int n,
-void LE(const size_t num, const float *in, const float x, float *out,
- cudaStream_t s) {
- KernelLE << <ceil(num / CU1DBLOCKF), CU1DBLOCKF>>> (num, in, x, out);
++__global__ void kernel_softmax_gradient(float *grad, const size_t *label, int n,
+ int dim, float scale) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int num_threads = blockDim.x * gridDim.x;
+ for (; index < n; index += num_threads) {
+ int pos = index * dim + label[index];
+ grad[pos] = (grad[pos] - 1.0f) * scale;
+ }
}
+*/
+
} // namespace cuda
} // namespace singa
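All of the rewritten kernels above use the same grid-stride-loop pattern, launched over a 1-D grid of ceil(n / CU1DBLOCKF) blocks of CU1DBLOCKF threads each. A minimal self-contained sketch of that pattern (the block-size constant and the forwarding of the stream argument are illustrative assumptions; the wrappers above accept a cudaStream_t s but do not yet pass it to the launch):

    #include <cuda_runtime.h>

    // Illustrative block size; SINGA's CU1DBLOCK/CU1DBLOCKF constants are
    // defined elsewhere in the tree.
    static const int kBlockSize = 256;

    // Grid-stride loop: each thread processes idx, idx + stride, ... so
    // that any grid size covers all n elements.
    __global__ void KernelGESketch(const size_t n, const float *in,
                                   const float x, float *out) {
      for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
           idx += blockDim.x * gridDim.x) {
        out[idx] = in[idx] >= x ? 1.0f : 0.0f;
      }
    }

    void ge_sketch(const size_t n, const float *in, const float x,
                   float *out, cudaStream_t s) {
      // Integer ceiling division, equivalent to ceil(n / CU1DBLOCKF);
      // with a grid-stride loop the grid size only affects performance,
      // not correctness.
      int num_blocks = static_cast<int>((n + kBlockSize - 1) / kBlockSize);
      KernelGESketch<<<num_blocks, kBlockSize, 0, s>>>(n, in, x, out);
    }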
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/core/tensor/math_kernel.h
----------------------------------------------------------------------
diff --cc src/core/tensor/math_kernel.h
index d8a58a5,5c906a9..444f6ca
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@@ -31,66 -31,65 +31,66 @@@ namespace singa
// TODO(wangwei) make all function templates.
namespace cuda {
-void sum(int n, const float *in, float *out);
-void sum_row(int rows, int cols, int stride, const float *in, float *out);
-
-void sum_col(int rows, int cols, int stride, const float *in, float *out);
-
-void add_row(int rows, int cols, int stride, const float *in_row,
- const float *in_mat, float *out);
-
-void add(int n, const float *a, const float *b, float *out);
-
-void sub(int n, const float *a, const float *b, float *out);
-
-void exp(int n, const float *in, float *out);
-
-void log(int n, const float *in, float *out);
-
-void sigmoid(int n, const float *in, float *out);
-
-void sigmoid_grad(int n, const float *in, float *out);
-
-void relu(int n, const float *in, float *out);
-
-void relu_grad(int n, const float *in, float *out);
-
-void tanh(int n, const float *in, float *out);
-
-void tanh_grad(int n, const float *in, float *out);
+// 0 input
+void set(const size_t n, const float v, float *out, cudaStream_t s);
+
+// 1 input
+void abs(const size_t n, const float *in, float *out, cudaStream_t s);
+void sign(const size_t n, const float *in, float *out, cudaStream_t s);
+void exp(const size_t n, const float *in, float *out, cudaStream_t s);
+void log(const size_t n, const float *in, float *out, cudaStream_t s);
+void sqrt(const size_t n, const float *in, float *out, cudaStream_t s);
+void square(const size_t n, const float *in, float *out, cudaStream_t s);
+void tanh(const size_t n, const float *in, float *out, cudaStream_t s);
+void relu(const size_t n, const float *in, float *out, cudaStream_t s);
- void sigmoid(const int n, const float *in, float *out, cudaStream_t s);
++void sigmoid(const size_t n, const float *in, float *out, cudaStream_t s);
+void softplus(const size_t n, const float *in, float *out, cudaStream_t s);
+void clamp(const size_t n, const float low, const float high, const float *in,
+ float *out, cudaStream_t s);
+
+void pow(const size_t n, const float *in, const float x, float *out,
+ cudaStream_t s);
-void softplus(int n, const float *in, float *out);
+void add(const size_t n, const float *in, const float x, float *out,
+ cudaStream_t s);
-void softplus_grad(int n, const float *in, float *out);
+void mult(const size_t n, const float *in, const float x, float *out,
+ cudaStream_t s);
-void square(int n, const float *in, float *out);
+void div(const size_t n, const float x, const float *in, float *out,
+ cudaStream_t s);
-void square_grad(int n, const float *in, float *out);
+void threshold(const size_t n, const float x, const float *in, float *out,
+ cudaStream_t s);
-void sqrt(int n, const float *in, float *out);
+void gt(const size_t num, const float *in, const float x, float *out,
+ cudaStream_t s);
+void ge(const size_t num, const float *in, const float x, float *out,
+ cudaStream_t s);
+void lt(const size_t num, const float *in, const float x, float *out,
+ cudaStream_t s);
+void le(const size_t num, const float *in, const float x, float *out,
+ cudaStream_t s);
-void pow(int n, const float *a, const float *b, float *out);
+// 2 inputs
+void pow(const size_t n, const float *in1, const float *in2, float *out,
+ cudaStream_t s);
-void mult(int n, const float *a, const float *b, float *out);
+void add(const size_t n, const float *in1, const float *in2, float *out,
+ cudaStream_t s);
-void mult(int n, const float *a, const float x, float *out);
+void sub(const size_t n, const float *in1, const float *in2, float *out,
+ cudaStream_t s);
-void div(int n, const float *a, const float *b, float *out);
+void mult(const size_t n, const float *in1, const float *in2, float *out,
+ cudaStream_t s);
-void set_value(int n, float v, float *out);
+void div(const size_t n, const float *in1, const float *in2, float *out,
+ cudaStream_t s);
-void threshold(int n, float alpha, const float *in, float *out);
+void sum(const size_t n, const float *in, float *out, cudaStream_t s);
-// follow the consistency guide for math API
-void Div(const size_t num, const float x, const float *in, float *out,
- cudaStream_t s);
-void Set(const size_t num, const float x, float *out, cudaStream_t s);
-void GT(size_t num, const float *in, const float x, float *out, cudaStream_t s);
-void GE(size_t num, const float *in, const float x, float *out, cudaStream_t s);
-void LT(size_t num, const float *in, const float x, float *out, cudaStream_t s);
-void LE(size_t num, const float *in, const float x, float *out, cudaStream_t s);
} // cuda
} // namespace singa
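For callers, the new wrappers operate on raw device pointers and take the element count first. A hypothetical host-side usage (the buffer management and include path are assumptions for illustration, not code from this commit):

    #include <cuda_runtime.h>
    #include <vector>
    #include "../src/core/tensor/math_kernel.h"  // internal header; path assumed

    int main() {
      const size_t n = 4;
      std::vector<float> h_in = {0.1f, 0.5f, 0.9f, 1.3f}, h_out(n);
      float *d_in = nullptr, *d_out = nullptr;
      cudaMalloc(&d_in, n * sizeof(float));
      cudaMalloc(&d_out, n * sizeof(float));
      cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);
      // Elementwise out[i] = (in[i] >= 0.5f) ? 1 : 0 on the default stream.
      singa::cuda::ge(n, d_in, 0.5f, d_out, 0);
      cudaMemcpy(h_out.data(), d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }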
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --cc src/core/tensor/tensor.cc
index e62386a,5ae375c..e6917d8
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@@ -639,92 -701,4 +639,91 @@@ void SumRows(const Tensor &M, Tensor *v
  Mult(X, one, v);
  }
}
+// ====================Random operations=====================================
+template <typename SType>
+void Bernoulli(const SType p, Tensor *out) {
+ TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
+ auto prob = TypeCast<SType, DType>(p);
+ out->device()->Exec([prob, out](Context *ctx) {
+ Bernoulli<DType, Lang>(out->Size(), prob, out->blob(), ctx);
+ }, {}, {out->blob()}, true);
+ });
+}
+template void Bernoulli<float>(const float p, Tensor *out);
+
+template <typename SType>
+void Uniform(const SType low, const SType high, Tensor *out) {
+ TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
+ auto l = TypeCast<SType, DType>(low);
+ auto h = TypeCast<SType, DType>(high);
+ out->device()->Exec([l, h, out](Context *ctx) {
+ Uniform<DType, Lang>(out->Size(), l, h, out->blob(), ctx);
+ }, {}, {out->blob()}, true);
+ });
+}
+template void Uniform<float>(const float low, const float high, Tensor *out);
+
+template <typename SType>
+void Gaussian(const SType mean, const SType std, Tensor *out) {
+ TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
+ auto m = TypeCast<SType, DType>(mean);
+ auto s = TypeCast<SType, DType>(std);
+ out->device()->Exec([m, s, out](Context *ctx) {
+ Gaussian<DType, Lang>(out->Size(), m, s, out->blob(), ctx);
+ }, {}, {out->blob()}, true);
+ });
+}
+template void Gaussian<float>(const float mean, const float std, Tensor *out);
+
+// ================Blas operations============================================
+template <typename SType>
+void Axpy(const SType alpha, const Tensor &in, Tensor *out) {
+ TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
+ auto a = TypeCast<SType, DType>(alpha);
+ out->device()->Exec([a, in, out](Context *ctx) {
+ Axpy<DType, Lang>(in.Size(), a, in.blob(), out->blob(), ctx);
+ }, {in.blob(), out->blob()}, {out->blob()});
+ });
+}
- template <>
- void Axpy(const float alpha, const Tensor &in, Tensor *out);
++template void Axpy(const float alpha, const Tensor &in, Tensor *out);
+
+Tensor Mult(const Tensor &A, const Tensor &B) {
+ Shape s;
+ s.push_back(A.shape(0));
+ if (B.nDim() == 2) s.push_back(B.shape(1));
+ Tensor out(s, A.device(), A.data_type());
+ Mult(A, B, &out);
+ return out;
+}
+
+void Mult(const Tensor &A, const Tensor &B, Tensor *out) {
+ Mult(1.0f, A, B, 0.0f, out);
+}
+
+template <typename SType>
+void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
+ Tensor *C) {
+ CHECK_EQ(A.shape().size(), 2u);
+ if (B.nDim() == 1u) {
+ TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
+ auto a = TypeCast<SType, DType>(alpha);
+ auto b = TypeCast<SType, DType>(beta);
+ C->device()->Exec([a, A, b, B, C](Context *ctx) {
+ GEMV<DType, Lang>(A.transpose(), A.shape(0), A.shape(1), a, A.blob(),
+ B.blob(), b, C->blob(), ctx);
+ }, {A.blob(), B.blob()}, {C->blob()});
+ });
+ } else {
+ CHECK(!C->transpose());
+ TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
+ auto a = TypeCast<SType, DType>(alpha);
+ auto b = TypeCast<SType, DType>(beta);
+ C->device()->Exec([a, A, b, B, C](Context *ctx) {
+ GEMM<DType, Lang>(A.transpose(), B.transpose(), A.shape(0), B.shape(1),
+ A.shape(1), a, A.blob(), B.blob(), b, C->blob(), ctx);
+ }, {A.blob(), B.blob()}, {C->blob()});
+ });
+ }
+}
+
} // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/model/optimizer/adagrad.cc
----------------------------------------------------------------------
diff --cc src/model/optimizer/adagrad.cc
index 0000000,8bdb07c..0b8ec88
mode 000000,100644..100644
--- a/src/model/optimizer/adagrad.cc
+++ b/src/model/optimizer/adagrad.cc
@@@ -1,0 -1,35 +1,36 @@@
+ /**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+ #ifndef SRC_MODEL_OPTIMIZER_ADAGRAD_H_
+ #define SRC_MODEL_OPTIMIZER_ADAGRAD_H_
+ #include "singa/model/optimizer.h"
+ #include <functional>
+ namespace singa {
+
+ void Adagrad::Setup(const OptimizerConf& conf) { delta_ = conf.delta(); }
+
+ void Adagrad::Apply(int step, float lr, const string& name, Tensor* grad,
+ Tensor* value) {
+ if (history_gradient_.find(name) == history_gradient_.end())
+ history_gradient_[name].ResetLike(*value);
+ Tensor& history = history_gradient_[name];
- history += (*grad) * (*grad);
- (*value) -= (*grad) * lr / Sqrt(history + delta_);
++ history += Square(*grad);
++ (*grad) /= Sqrt(history + delta_);
++ Axpy(-lr, *grad, value);
+ }
+ } // namespace singa
+ #endif // SRC_MODEL_OPTIMIZER_ADAGRAD_H_
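Elementwise, the AdaGrad step above is h <- h + g^2 followed by theta <- theta - lr * g / sqrt(h + delta). A scalar sketch that mirrors Adagrad::Apply (the struct name and the std::vector parameter store are hypothetical):

    #include <cmath>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct AdagradSketch {
      float delta = 1e-8f;  // matches the new proto default below
      std::unordered_map<std::string, std::vector<float>> history;

      void Apply(float lr, const std::string& name, std::vector<float>& grad,
                 std::vector<float>& value) {
        auto& h = history[name];
        if (h.empty()) h.assign(grad.size(), 0.0f);  // like ResetLike
        for (size_t i = 0; i < grad.size(); ++i) {
          h[i] += grad[i] * grad[i];                           // history += g^2
          value[i] -= lr * grad[i] / std::sqrt(h[i] + delta);  // scaled step
        }
      }
    };

Because h only grows, the effective per-coordinate learning rate shrinks over time, which is the property the RMSProp variant below relaxes.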
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/model/optimizer/rmsprop.cc
----------------------------------------------------------------------
diff --cc src/model/optimizer/rmsprop.cc
index 0000000,cad333c..7b9934c
mode 000000,100644..100644
--- a/src/model/optimizer/rmsprop.cc
+++ b/src/model/optimizer/rmsprop.cc
@@@ -1,0 -1,38 +1,41 @@@
+ /**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+ #ifndef SRC_MODEL_OPTIMIZER_ADAGRAD_H_
+ #define SRC_MODEL_OPTIMIZER_ADAGRAD_H_
+ #include "singa/model/optimizer.h"
+ #include <functional>
+ namespace singa {
+
+ void RMSProp::Setup(const OptimizerConf& conf) {
+ delta_ = conf.delta();
- rho_ = conf.delta();
++ rho_ = conf.rho();
+ }
+
+ void RMSProp::Apply(int step, float lr, const string& name, Tensor* grad,
+ Tensor* value) {
- if (history_gradient_.find(name) == history_gradient_.end())
++ if (history_gradient_.find(name) == history_gradient_.end()) {
+ history_gradient_[name].ResetLike(*value);
++ }
+ Tensor& history = history_gradient_[name];
- history = history * rho_ + (*grad) * (*grad) * (1 - rho_);
- (*value) -= (*grad) * lr / Sqrt(history + delta_);
++ history *= rho_;
++ Axpy(1 - rho_, Square(*grad), &history);
++ (*grad) /= Sqrt(history + delta_);
++ Axpy(-lr, *grad, value);
+ }
+ } // namespace singa
+ #endif // SRC_MODEL_OPTIMIZER_ADAGRAD_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/src/proto/model.proto
----------------------------------------------------------------------
diff --cc src/proto/model.proto
index d368296,c26aa35..ca6f0cd
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@@ -86,6 -86,9 +86,9 @@@ message OptimizerConf
  // used by vanilla sgd and nesterov
  optional float momentum = 5 [default = 0.9];
+
+ // delta is used to avoid dividing zero
- optional float delta = 6 [default = 0.0000001];
++ optional float delta = 6 [default = 1e-8];
}
message ConstraintConf {
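RMSProp replaces AdaGrad's monotonically growing accumulator with an exponential moving average: h <- rho * h + (1 - rho) * g^2, then theta <- theta - lr * g / sqrt(h + delta). A scalar sketch under the same assumptions as the AdaGrad one above:

    #include <cmath>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct RMSPropSketch {
      float rho = 0.9f;     // decay rate, conf.rho() above
      float delta = 1e-8f;  // numeric-stability term, conf.delta()
      std::unordered_map<std::string, std::vector<float>> history;

      void Apply(float lr, const std::string& name, std::vector<float>& grad,
                 std::vector<float>& value) {
        auto& h = history[name];
        if (h.empty()) h.assign(grad.size(), 0.0f);
        for (size_t i = 0; i < grad.size(); ++i) {
          h[i] = rho * h[i] + (1.0f - rho) * grad[i] * grad[i];
          value[i] -= lr * grad[i] / std::sqrt(h[i] + delta);
        }
      }
    };

Because the moving average forgets old gradients, RMSProp keeps adapting on non-stationary objectives where AdaGrad's effective learning rate would decay toward zero.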
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/test/singa/test_adagrad.cc
----------------------------------------------------------------------
diff --cc test/singa/test_adagrad.cc
index 0000000,1382467..80240b1
mode 000000,100644..100644
--- a/test/singa/test_adagrad.cc
+++ b/test/singa/test_adagrad.cc
@@@ -1,0 -1,92 +1,96 @@@
+ /************************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *************************************************************/
+
+ #include "gtest/gtest.h"
+ #include "singa/model/optimizer.h"
+ #include "singa_config.h"
+ #include <cmath>
+
+ TEST(Adagrad, ApplyCPU) {
+ singa::Adagrad adagrad;
+ float lr = 0.1f;
+ const float v[4] = {0.1, 0.2, 0.3, 0.4};
+ const float g[4] = {0.01, 0.02, 0.03, 0.04};
+
+ singa::Tensor value(singa::Shape{4}), grad(singa::Shape{4});
+ value.CopyDataFromHostPtr(v, 4);
+ grad.CopyDataFromHostPtr(g, 4);
+
++ singa::OptimizerConf conf;
++ adagrad.Setup(conf);
+ adagrad.Apply(0, lr, "xx", &grad, &value);
+
+ singa::Tensor v1 = value.Clone();
+ const float* newv1 = v1.data<const float*>();
+ float history[4];
+ for (int i = 0; i < 4; ++i) history[i] = g[i] * g[i];
+ for (int i = 0; i < 4; ++i)
- EXPECT_FLOAT_EQ(newv1[i],
- v[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++ EXPECT_NEAR(newv1[i], v[i] - lr * g[i] / sqrt(history[i] + conf.delta()),
++ 1e-5);
+
+ grad.CopyDataFromHostPtr(g, 4);
+ adagrad.Apply(1, lr, "xx", &grad, &value);
+ singa::Tensor v2 = value.Clone();
+ const float* newv2 = v2.data<const float*>();
+ for (int i = 0; i < 4; ++i) history[i] += g[i] * g[i];
+
+ for (int i = 0; i < 4; ++i)
- EXPECT_FLOAT_EQ(newv2[i],
- newv1[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++ EXPECT_NEAR(newv2[i],
++ newv1[i] - lr * g[i] / sqrt(history[i] + conf.delta()), 1e-5);
+ }
+
+ #ifdef USE_CUDA
+ TEST(Adagrad, ApplyCUDA) {
+ singa::Adagrad adagrad;
+ float lr = 0.1f;
+ const float v[4] = {0.1, 0.2, 0.3, 0.4};
+ const float g[4] = {0.01, 0.02, 0.03, 0.04};
+
+ singa::CudaGPU dev;
+ singa::Tensor value(singa::Shape{4}, &dev), grad(singa::Shape{4}, &dev);
+ value.CopyDataFromHostPtr(v, 4);
+ grad.CopyDataFromHostPtr(g, 4);
+
++ singa::OptimizerConf conf;
++ adagrad.Setup(conf);
+ adagrad.Apply(0, lr, "xx", &grad, &value);
+
+ singa::Tensor v1 = value.Clone();
+ v1.ToHost();
+ const float* newv1 = v1.data<const float*>();
+ float history[4];
+ for (int i = 0; i < 4; ++i) history[i] = g[i] * g[i];
+ for (int i = 0; i < 4; ++i)
- EXPECT_FLOAT_EQ(newv1[i],
- v[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++ EXPECT_NEAR(newv1[i], v[i] - lr * g[i] / sqrt(history[i] + conf.delta()),
++ 1e-5);
+
+ grad.CopyDataFromHostPtr(g, 4);
+ adagrad.Apply(1, lr, "xx", &grad, &value);
+ singa::Tensor v2 = value.Clone();
+ v2.ToHost();
+ const float* newv2 = v2.data<const float*>();
+ for (int i = 0; i < 4; ++i) history[i] += g[i] * g[i];
+
+ for (int i = 0; i < 4; ++i)
+ EXPECT_FLOAT_EQ(newv2[i],
- newv1[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++ newv1[i] - lr * g[i] / sqrt(history[i] + conf.delta()));
+ }
+ #endif
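The test above shows the EXPECT_NEAR change described in the commit message (note that the final CUDA assertion still uses EXPECT_FLOAT_EQ). The distinction, in a minimal self-contained gtest sketch with illustrative values:

    #include <cmath>
    #include "gtest/gtest.h"

    TEST(Tolerance, NearVsExact) {
      float expected = 0.1f - 0.1f * 0.01f / std::sqrt(1e-4f + 1e-8f);
      float computed = expected * (1.0f + 1e-6f);  // stand-in for a GPU result
      // EXPECT_FLOAT_EQ allows only about 4 ULPs, so a device that rounds
      // differently (e.g. via rsqrt or fused multiply-add) can fail it.
      // An absolute tolerance of 1e-5 absorbs such rounding differences.
      EXPECT_NEAR(computed, expected, 1e-5);
    }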
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5784bff3/test/singa/test_rmsprop.cc
----------------------------------------------------------------------
diff --cc test/singa/test_rmsprop.cc
index 0000000,62101f7..8104f50
mode 000000,100644..100644
--- a/test/singa/test_rmsprop.cc
+++ b/test/singa/test_rmsprop.cc
@@@ -1,0 -1,103 +1,106 @@@
+ /************************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *************************************************************/
+
+ #include "gtest/gtest.h"
+ #include "singa/model/optimizer.h"
+ #include "singa_config.h"
+ #include <cmath>
+
+ TEST(RMSProp, ApplyCPU) {
+ singa::RMSProp rmsprop;
+ float lr = 0.1f;
- float rho = 0.002f;
++ float rho = 0.9;
+ const float v[4] = {0.1, 0.2, 0.3, 0.4};
+ const float g[4] = {0.01, 0.02, 0.03, 0.04};
+
+ singa::OptimizerConf conf;
+ conf.set_rho(rho);
++ conf.set_delta(1E-8);
+
+ singa::Tensor value(singa::Shape{4}), grad(singa::Shape{4});
+ value.CopyDataFromHostPtr(v, 4);
+ grad.CopyDataFromHostPtr(g, 4);
+
+ rmsprop.Setup(conf);
+ rmsprop.Apply(0, lr, "xx", &grad, &value);
+
+ singa::Tensor v1 = value.Clone();
+ const float* newv1 = v1.data<const float*>();
+ float history[4];
+ for (int i = 0; i < 4; ++i) history[i] = g[i] * g[i] * (1 - rho);
+ for (int i = 0; i < 4; ++i)
- EXPECT_FLOAT_EQ(newv1[i],
- v[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++ EXPECT_NEAR(newv1[i], v[i] - g[i] * lr / sqrt(history[i] + (float)1E-8),
++ 1e-5);
+
+ grad.CopyDataFromHostPtr(g, 4);
+ rmsprop.Apply(1, lr, "xx", &grad, &value);
+ singa::Tensor v2 = value.Clone();
+ const float* newv2 = v2.data<const float*>();
+ for (int i = 0; i < 4; ++i)
- history[i] += history[i] * rho + g[i] * g[i] * (1 - rho);
++ history[i] = history[i] * rho + g[i] * g[i] * (1 - rho);
+
+ for (int i = 0; i < 4; ++i)
- EXPECT_FLOAT_EQ(newv2[i],
- newv1[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++ EXPECT_NEAR(newv2[i], newv1[i] - lr * g[i] / sqrt(history[i] + (float)1E-8),
++ 1e-5);
+ }
+
+ #ifdef USE_CUDA
+ TEST(RMSProp, ApplyCUDA) {
+ singa::RMSProp rmsprop;
+ float lr = 0.1f;
- float rho = 0.002f;
++ float rho = 0.02;
+ const float v[4] = {0.1, 0.2, 0.3, 0.4};
+ const float g[4] = {0.01, 0.02, 0.03, 0.04};
+
+ singa::OptimizerConf conf;
+ conf.set_rho(rho);
++ conf.set_delta(1e-8);
+
+ singa::CudaGPU dev;
+ singa::Tensor value(singa::Shape{4}, &dev), grad(singa::Shape{4}, &dev);
+ value.CopyDataFromHostPtr(v, 4);
+ grad.CopyDataFromHostPtr(g, 4);
+
++ rmsprop.Setup(conf);
+ rmsprop.Apply(0, lr, "xx", &grad, &value);
+
+ singa::Tensor v1 = value.Clone();
+ v1.ToHost();
+ const float* newv1 = v1.data<const float*>();
+ float history[4];
+ for (int i = 0; i < 4; ++i) history[i] = g[i] * g[i] * (1 - rho);
+ for (int i = 0; i < 4; ++i)
- EXPECT_FLOAT_EQ(newv1[i],
- v[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++ EXPECT_NEAR(newv1[i], v[i] - lr * g[i] / sqrt(history[i] + conf.delta()),
++ 1e-5);
+
+ grad.CopyDataFromHostPtr(g, 4);
+ rmsprop.Apply(1, lr, "xx", &grad, &value);
+ singa::Tensor v2 = value.Clone();
+ v2.ToHost();
+ const float* newv2 = v2.data<const float*>();
+ for (int i = 0; i < 4; ++i)
- history[i] += history[i] * rho + g[i] * g[i] * (1 - rho);
++ history[i] = history[i] * rho + g[i] * g[i] * (1 - rho);
+
+ for (int i = 0; i < 4; ++i)
- EXPECT_FLOAT_EQ(newv2[i],
- newv1[i] - lr * g[i] / sqrt(history[i] + (float)1E-8));
++ EXPECT_NEAR(newv2[i],
++ newv1[i] - lr * g[i] / sqrt(history[i] + conf.delta()), 1e-5);
+ }
+ #endif
