SINGA-80 New Blob Level and Address Level Math Operation Interface
----
specified MVAdd and MVSum in row/col direction and added transpose support

todo: add a hard transpose function to transpose the matrix at the storage level


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/d333cbad
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/d333cbad
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/d333cbad

Branch: refs/heads/master
Commit: d333cbad02dc68f70ffba00b71b08f19cd959506
Parents: 98f5256
Author: jinyangturbo <[email protected]>
Authored: Mon Nov 9 06:10:09 2015 -0800
Committer: jinyangturbo <[email protected]>
Committed: Mon Nov 9 06:10:09 2015 -0800

----------------------------------------------------------------------
 include/singa/utils/blob.h      |   9 +-
 include/singa/utils/math_blob.h | 157 ++++++++++++++++++++++++++---------
 2 files changed, 122 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d333cbad/include/singa/utils/blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/blob.h b/include/singa/utils/blob.h
index f7c29af..d20f318 100644
--- a/include/singa/utils/blob.h
+++ b/include/singa/utils/blob.h
@@ -251,8 +251,8 @@ class Blob {
     CHECK(data_);
     return static_cast<Dtype*>(data_->mutable_gpu_data());
   }
-  inline void set_transpose() {
-    transpose_ = true;
+  inline void set_transpose(bool val) {
+    transpose_ = val;
   }
   inline bool transpose() const {
     return transpose_;
@@ -325,12 +325,13 @@ Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim0, int dim1, int dim2,
 /**
  * @return a new Blob which share all internal members with the input Blob
- * except that the transpose_ field is set to true.
+ * except that the transpose_ field is set to the opposite value.
  */
 template <typename Dtype>
 Blob<Dtype>* Transpose(const Blob<Dtype> & A) {
   Blob<Dtype>* res = new Blob<Dtype>(A);
-  res->set_transpose();
+  bool origin = A.transpose();
+  res->set_transpose(!origin);
   return res;
 }


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d333cbad/include/singa/utils/math_blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/math_blob.h b/include/singa/utils/math_blob.h
index 2d53d32..1aa4be0 100644
--- a/include/singa/utils/math_blob.h
+++ b/include/singa/utils/math_blob.h
@@ -23,6 +23,7 @@
 #define SINGA_UTILS_MATH_BLOB_H_

 #include <vector>
+#include <algorithm>
 #include "singa/utils/blob.h"
 #include "singa/utils/singa_op.h"
 #include "singa/utils/math_addr.h"
@@ -115,8 +116,8 @@ void GEMV(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype>& A,
  * @param[out] C output vector
  */
 template <typename Dtype>
-void MVDot(XPU xpu, const Blob<Dtype>& A, const Blob<Dtype>& B, Blob<Dtype>* C)
-{
+void MVDot(XPU xpu, const Blob<Dtype>& A, const Blob<Dtype>& B,
+    Blob<Dtype>* C) {
   GEMV(xpu, Dtype(1), Dtype(0), A, B, C);
 }
@@ -179,8 +180,8 @@ void GEMM(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype>& A,
  * @param[out] C output matrix
  */
 template <typename Dtype>
-void MMDot(XPU xpu, const Blob<Dtype>& A, const Blob<Dtype>& B, Blob<Dtype>* C)
-{
+void MMDot(XPU xpu, const Blob<Dtype>& A, const Blob<Dtype>& B,
+    Blob<Dtype>* C) {
   GEMM(xpu, Dtype(1), Dtype(0), A, B, C);
 }
@@ -320,9 +321,9 @@ void Map(XPU xpu, Dtype alpha, const Blob<Dtype>& A, const Blob<Dtype>& B,
 template<typename Dtype>
 void Copy(XPU xpu, const Blob<Dtype>& A, const Blob<Dtype>* B) {
   CHECK_EQ(A.count(), B->count()) << "Blobs must have the same size";
-  if (xpu == cpu)
+  if (xpu == cpu) {
     std::copy(A.cpu_data(), A.cpu_data() + A.count(), B->mutable_cpu_data());
-  else {
+  } else {
     LOG(FATAL) << "Not implemented";
   }
 }
@@ -372,82 +373,158 @@ void Div(XPU xpu, const Blob<Dtype> & A, const Blob<Dtype> & B,
 }

 /*************************1D<-->2D op/transform***************************/
 /**
- * Add each row of B with A, i.e., Bij = alpha*Ai + beta*Bij
+ * Add A to each column of B, i.e., Bij = alpha*Ai + beta*Bij
  * Loose shape checking, B.count() % A.count() == 0.
- * # rows of B = B.count() / A.count().
- * Transpose is disabled.
+ * # columns of B = B.count() / A.count().
  */
 template<typename Dtype>
-void MVAdd(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype> & A,
+void MVAddCol(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype> & A,
     Blob<Dtype> * B) {
-  CHECK_EQ(B->count() % A.count(), 0) << "#col of B not match length of A";
-  int m = A.count(), n = B->count() / m;
-  if (xpu == cpu) {
-    Blob<Dtype> one(n);
-    one.SetValue(1);
-    cpu_gemm(A.cpu_data(), one.cpu_data(), m, n, 1, alpha, beta,
-        false, false, B->mutable_cpu_data());
-  }
+  if (B->transpose()) {
+    Blob<Dtype>* tmp = Transpose(*B);
+    MVAddRow(xpu, alpha, beta, A, tmp);
+    delete tmp;
+  } else {
+    CHECK_EQ(B->count() % A.count(), 0) << "#col of B not match length of A";
+    int m = A.count(), n = B->count() / m;
+    if (xpu == cpu) {
+      Blob<Dtype> one(n);
+      one.SetValue(1);
+      cpu_gemm(A.cpu_data(), one.cpu_data(), m, n, 1, alpha, beta,
+          false, false, B->mutable_cpu_data());
+    }
 #ifdef USE_GPU
-  if (xpu == gpu) {
-    singa_gpu_add_vec_row(B->gpu_data(),
-        A.gpu_data(), A.gpu_data(), m, n, n);
-    // gpu part
-  }
+    if (xpu == gpu) {
+      singa_gpu_add_vec_row(B->gpu_data(),
+          A.gpu_data(), A.gpu_data(), m, n, n);
+      // gpu part
+    }
 #endif  // USE_GPU
+  }
 }
+
+/**
+ * Add A to each column of B, i.e., Bij = Ai + Bij
+ * Loose shape checking, B.count() % A.count() == 0.
+ * # columns of B = B.count() / A.count().
+ */
+template<typename Dtype>
+void MVAddCol(XPU xpu, const Blob<Dtype> & A, Blob<Dtype>* B) {
+  MVAddCol(xpu, Dtype(1), Dtype(1), A, B);
+}
+
+/**
+ * Add A to each row of B, i.e., Bij = alpha*Aj + beta*Bij
+ * Loose shape checking, B.count() % A.count() == 0.
+ * # rows of B = B.count() / A.count().
+ */
+template<typename Dtype>
+void MVAddRow(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype> & A,
+    Blob<Dtype> * B) {
+  if (B->transpose()) {
+    Blob<Dtype>* tmp = Transpose(*B);
+    MVAddCol(xpu, alpha, beta, A, tmp);
+    delete tmp;
+  } else {
+    CHECK_EQ(B->count() % A.count(), 0) << "#col of B not match length of A";
+    int m = A.count(), n = B->count() / m;
+    if (xpu == cpu) {
+      Blob<Dtype> one(n);
+      one.SetValue(1);
+      cpu_gemm(one.cpu_data(), A.cpu_data(), n, m, 1, alpha, beta,
+          false, false, B->mutable_cpu_data());
+    }
+#ifdef USE_GPU
+    if (xpu == gpu) {
+      // gpu part
+    }
+#endif  // USE_GPU
+  }
+}

 /**
- * Add each row of B with A, i.e., Bij = Ai + Bij
+ * Add A to each row of B, i.e., Bij = Aj + Bij
  * Loose shape checking, B.count() % A.count() == 0.
  * # rows of B = B.count() / A.count().
- * Transpose is disabled.
  */
 template<typename Dtype>
-void MVAdd(XPU xpu, const Blob<Dtype> & A, Blob<Dtype>* B) {
-  MVAdd(xpu, Dtype(1), Dtype(1), A, B);
+void MVAddRow(XPU xpu, const Blob<Dtype> & A, Blob<Dtype>* B) {
+  MVAddRow(xpu, Dtype(1), Dtype(1), A, B);
 }

 /**
- * Copy A to each row of B
+ * Copy A to each column of B, i.e., Bij = Ai
+ * Loose shape checking, B.count() % A.count() == 0,
+ * # columns of B = B.count() / A.count().
+ */
+template<typename Dtype>
+void RepmatCol(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+  MVAddCol(xpu, Dtype(1), Dtype(0), A, B);
+}
+
+/**
+ * Copy A to each row of B, i.e., Bij = Aj
  * Loose shape checking, B.count() % A.count() == 0,
  * # rows of B = B.count() / A.count().
- * Transpose is disabled.
  */
 template<typename Dtype>
-void Repmat(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
-  MVAdd(xpu, Dtype(1), Dtype(0), A, B);
+void RepmatRow(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
+  MVAddRow(xpu, Dtype(1), Dtype(0), A, B);
 }

 /**
- * Add each col of matrix A to vector B, i.e., Bi = \sum_j {alpha*Aij}+beta*Bi
+ * Sum all columns of matrix A to a column vector B,
+ * i.e., Bi = \sum_j {alpha*Aij}+beta*Bi
  * Loose shape checking, A.count() % B.count() == 0.
- * # rows of A = A.count() / B.count().
- * Transpose is disabled.
+ * # columns of A = A.count() / B.count().
  */
 template<typename Dtype>
-void MVSum(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype> & A,
+void MVSumCol(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype> & A,
     Blob<Dtype> * B) {
   CHECK_EQ(A.count() % B->count(), 0) << "length of B must = # of cols of A";
   int m = B->count(), n = A.count() / m;
   if (xpu == cpu) {
     Blob<Dtype> one(n);
     one.SetValue(1);
     cpu_gemm(A.cpu_data(), one.cpu_data(), m, 1, n, alpha, beta,
-        false, false, B->mutable_cpu_data());
+        A.transpose(), false, B->mutable_cpu_data());
   }
 #ifdef USE_GPU
   if (xpu == gpu) {
     singa_gpu_sum_col(A.gpu_data(), B->gpu_data(), m, n, n);
-    // gpu part
+    // gpu part (TODO check transpose case)
   }
 #endif  // USE_GPU
 }
+
 /**
- * Reduce each row of A to an element of B.
+ * Sum all rows of matrix A to a row vector B,
+ * i.e., Bj = \sum_i {alpha*Aij}+beta*Bj
  * Loose shape checking, A.count() % B.count() == 0.
  * # rows of A = A.count() / B.count().
  */
+template<typename Dtype>
+void MVSumRow(XPU xpu, Dtype alpha, Dtype beta, const Blob<Dtype> & A,
+    Blob<Dtype> * B) {
+  CHECK_EQ(A.count() % B->count(), 0) << "length of B must = # of cols of A";
+  int m = B->count(), n = A.count() / m;
+  if (xpu == cpu) {
+    Blob<Dtype> one(n);
+    one.SetValue(1);
+    cpu_gemm(one.cpu_data(), A.cpu_data(), 1, m, n, alpha, beta,
+        A.transpose(), false, B->mutable_cpu_data());
+  }
+#ifdef USE_GPU
+  if (xpu == gpu) {
+    singa_gpu_sum_col(A.gpu_data(), B->gpu_data(), m, n, n);
+    // gpu part (TODO check transpose case)
+  }
+#endif  // USE_GPU
+}
+
+/**
+ * Reduce each row of A to an element of B.
+ * Loose shape checking, A.count() % B.count() == 0.
+ * # columns of A = A.count() / B.count().
+ */
 template<typename Op, typename Dtype>
 void Reduce2D(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
   CHECK_EQ(A.count() % B->count(), 0) << "Row size not match B length";
@@ -465,7 +542,7 @@ void Reduce2D(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
 /**
  * Duplicate each element of A into a row of B.
  * Loose shape checking, B.count() % A.count() == 0.
- * # rows of B = B.count() / A.count().
+ * # columns of B = B.count() / A.count().
  */
 template<typename Op, typename Dtype>
 void Expand2D(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {
   CHECK_EQ(B->count() % A.count(), 0) << "Row size not match B length";
@@ -484,4 +561,4 @@ void Expand2D(XPU xpu, const Blob<Dtype> & A, Blob<Dtype> * B) {

 }  // end of namespace singa

-#endif  // SINGA_BLOB_MATH_BLOB_H_
+#endif  // SINGA_UTILS_MATH_BLOB_H_
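----------------------------------------------------------------------

A minimal usage sketch of the renamed row/col interface (not part of the
commit). It relies only on calls visible in this diff: the single-int
Blob constructor, SetValue(), Transpose(), the MVAddRow/MVAddCol/MVSumCol
templates, and the XPU value `cpu` used inside math_blob.h; the example
function name and blob sizes are illustrative, and it assumes the file is
compiled against the SINGA headers.

  #include "singa/utils/blob.h"
  #include "singa/utils/math_blob.h"

  void Example() {
    using singa::Blob;
    // The row/col ops only use count() ("loose shape checking"), so a
    // flat 12-element Blob can serve as a 4x3 matrix here.
    Blob<float> A(12);        // 4 rows x 3 columns
    Blob<float> row(3);       // one entry per column of A
    Blob<float> colsum(4);    // one entry per row of A
    A.SetValue(2.0f);
    row.SetValue(1.0f);
    colsum.SetValue(0.0f);

    // Bij = Aj + Bij: add `row` to every row of A.
    singa::MVAddRow(singa::cpu, row, &A);

    // colsum_i = \sum_j A_ij; alpha = 1, beta = 0 overwrites colsum.
    singa::MVSumCol(singa::cpu, 1.0f, 0.0f, A, &colsum);

    // Transpose() only flips the transpose_ flag and shares the data,
    // so MVAddCol on the transposed view is forwarded back to MVAddRow
    // on the original layout, i.e. the same effect as the call above.
    Blob<float>* At = singa::Transpose(A);
    singa::MVAddCol(singa::cpu, row, At);
    delete At;
  }

The CPU paths are rank-1 GEMM updates with a ones vector (B = alpha*A*1^T
+ beta*B for the column case, and the analogous form for rows), and
Transpose() remains a "soft" transpose that only toggles the flag, which
is why the commit message keeps a hard, storage-level transpose as a todo.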
