SINGA-80 New Blob Level and Address Level Math Operation Interface

---

Header files for the new Blob-level and address-level math operation interface.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/bbb7dbc6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/bbb7dbc6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/bbb7dbc6

Branch: refs/heads/master
Commit: bbb7dbc6adff793d8cee568e4b8ff82818779690
Parents: b99de6c
Author: jinyangturbo <[email protected]>
Authored: Thu Oct 22 01:11:22 2015 -0700
Committer: Wei Wang <[email protected]>
Committed: Mon Nov 9 17:04:48 2015 +0800

----------------------------------------------------------------------
 include/singa/blob/math_addr.cc |  33 ---
 include/singa/blob/math_addr.h  |  78 +++---
 include/singa/blob/math_blob.h  | 479 +++++++++++++++++++++++++++++------
 include/singa/blob/singa_op.h   |  22 +-
 include/singa/blob/test.cc      | 165 ------------
 include/singa/utils/blob.h      |  16 ++
 6 files changed, 473 insertions(+), 320 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bbb7dbc6/include/singa/blob/math_addr.cc
----------------------------------------------------------------------
diff --git a/include/singa/blob/math_addr.cc b/include/singa/blob/math_addr.cc
deleted file mode 100644
index a03ce60..0000000
--- a/include/singa/blob/math_addr.cc
+++ /dev/null
@@ -1,33 +0,0 @@
-extern "C"
-{
-   #include <cblas.h>
-}
-
-#include "math_addr.h"
-#include "singa_op.h"
-
-void cpu_gemm(const float * A, const float * B, const int m, const int n, const int k, const float alpha, const float beta, const bool TranA, const bool TranB, float * C)
-{
-                int lda, ldb;
-                CBLAS_TRANSPOSE tA, tB;
-                lda = TranA ? m : k;
-                ldb = TranB ? k : n;
-                tA = TranA ? CblasTrans : CblasNoTrans;
-                tB = TranB ? CblasTrans : CblasNoTrans;
-                cblas_sgemm(CblasRowMajor, tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, n);
-}
-
-void cpu_gemv(const float * A, const float * B, const int m, const int n, const float alpha, const float beta, const bool TranA, float * C)
-{
-                CBLAS_TRANSPOSE tA;
-                tA = TranA ? CblasTrans : CblasNoTrans;
-                cblas_sgemv(CblasRowMajor, tA, m, n, alpha, A, n, B, 1, beta, C, 1);
-}
-}
-
-void cpu_axpy(const float * A, const int n, const float alpha, float * B)
-{
-                cblas_saxpy(n, alpha, A, 1, B, 1);
-}
-
-
-

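A small usage sketch for cpu_gemm, modelled on the removed test_gemm1 below (not part of the commit; the wrapper only forwards to cblas_sgemm, so this assumes a cBLAS/OpenBLAS build and the declarations now kept in math_addr.h under namespace singa):

#include "singa/blob/math_addr.h"

// C = alpha * A^T * B + beta * C in row-major storage:
// A is stored k x m (here 3 x 2), B is k x n (3 x 2), C is m x n (2 x 2).
void gemm_sketch() {
  float A[3][2] = {{0, 1}, {1, 2}, {2, 3}};
  float B[3][2] = {{0, 1}, {1, 1}, {2, 1}};
  float C[2][2] = {};
  singa::cpu_gemm(A[0], B[0], 2, 2, 3, 1.0f, 0.0f, true, false, C[0]);
  // expected result: C = {{5, 3}, {8, 6}}
}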
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bbb7dbc6/include/singa/blob/math_addr.h
----------------------------------------------------------------------
diff --git a/include/singa/blob/math_addr.h b/include/singa/blob/math_addr.h
index c732060..a6663ab 100644
--- a/include/singa/blob/math_addr.h
+++ b/include/singa/blob/math_addr.h
@@ -1,20 +1,26 @@
 #ifndef MATH_ADDR_H
 #define MATH_ADDR_H
-
-void cpu_gemm(const float * A, const float * B, const int m, const int n, const int k, const float alpha, const float beta, const bool TranA, const bool TranB, float * C);
-
+
+namespace singa{
+
+const float * cpu_uni_vec(const int n);
+
+void cpu_gemm(const float * A, const float * B, const int m, const int n, const int k, const float alpha, const float beta, const bool TranA, const bool TranB, float * C);
+
 void cpu_gemv(const float * A, const float * B, const int m, const int n, const float alpha, const float beta, const bool TranA, float * C);
-// should be very careful : m is the length of B, and n is the length of C , A is a n*m matrix
-
-void cpu_axpy(const float * A, const int n, const float alpha, float * B);
-
+// should be very careful : m is the length of B, and n is the length of C , A is a n*m matrix
+
+void cpu_axpy(const float * A, const int n, const float alpha, float * B);
+
+float cpu_dot(const float * A, const float * B, const int n);
+
 /*
-//element-wise
-template<typename Op> void cpu_e_f(const int n, const float alpha, float * A);
-template<typename Op> void cpu_e_f(const int n,const float * A,const float alpha, float * B);
-template<typename Op> void cpu_e_f(const int n,const float * A,const float * B,const float alpha, const float beta,float * C);
+//element-wise
+template<typename Op> void cpu_e_f(const int n, const float alpha, float * A);
+template<typename Op> void cpu_e_f(const int n,const float * A,const float alpha, float * B);
+template<typename Op> void cpu_e_f(const int n,const float * A,const float * B,const float alpha, const float beta,float * C);
 // element-wise generalized operation defined in Op
-*/
+*/
 
 //element-wise
 template<typename Op> void cpu_e_f(const int n, const float alpha, float * A)
@@ -41,13 +47,13 @@ template<typename Op> void cpu_e_f(const int n,const float * A,const float * B,c
                 }
 }
 // element-wise generalized operation defined in Op
-
+
 /*
-//matrix/vector expand/reduce
-
-template<typename Op> void cpu_reduce_f(const float * A,const int m, const int n, float * B);
-//reduce each row of A to an element of B e.g. the sum operation in softmax
-template<typename Op> void cpu_expand_f(const float * A,const int m, const int n, float * B);
+//matrix/vector expand/reduce
+
+template<typename Op> void cpu_reduce_f(const float * A,const int m, const int n, float * B);
+//reduce each row of A to an element of B e.g. the sum operation in softmax
+template<typename Op> void cpu_expand_f(const float * A,const int m, const int n, float * B);
 //expand each element in A into a row of B
 */
 
@@ -68,23 +74,25 @@ template<typename Op> void cpu_expand_f(const float * A,const int m, const int n
                                 Op::Map(A[i], n, B+i*n);
                 }
 }
-//expand each element in A into a row of B
-
-void gpu_gemm(const float * A, const float * B, const int m, const int n, const int k, const float alpha, const float beta, const bool TranA, const bool TranB, float * C);
-void gpu_gemv(const float * A, const float * B, const int m, const int n, const float alpha, const float beta, const bool TranA, float * C);
-void gpu_axpy(const float * A, const int n, const float alpha, float * B);
-
-//element-wise
-template<typename Op> void gpu_e_f(const int n, const float alpha, float * A);
-template<typename Op> void gpu_e_f(const int n,const float * A,const float alpha, const float beta,float * B);
-template<typename Op> void gpu_e_f(const int n,const float * A,const float * B,const float alpha, const float beta,float * C);
-// element-wise generalized operation defined in Op
-
-//matrix/vector expand/reduce
-
-template<typename Op> void gpu_reduce_f(const float * A,const int m, const int n, float * B);
-//reduce each row of A to an element of B e.g. the sum operation in softmax
-template<typename Op> void gpu_expand_f(const float * A,const int m, const int n, float * B);
 //expand each element in A into a row of B
 
+void gpu_gemm(const float * A, const float * B, const int m, const int n, const int k, const float alpha, const float beta, const bool TranA, const bool TranB, float * C);
+void gpu_gemv(const float * A, const float * B, const int m, const int n, const float alpha, const float beta, const bool TranA, float * C);
+void gpu_axpy(const float * A, const int n, const float alpha, float * B);
+
+//element-wise
+template<typename Op> void gpu_e_f(const int n, const float alpha, float * A);
+template<typename Op> void gpu_e_f(const int n,const float * A,const float alpha, const float beta,float * B);
+template<typename Op> void gpu_e_f(const int n,const float * A,const float * B,const float alpha, const float beta,float * C);
+// element-wise generalized operation defined in Op
+
+//matrix/vector expand/reduce
+
+template<typename Op> void gpu_reduce_f(const float * A,const int m, const int n, float * B);
+//reduce each row of A to an element of B e.g. the sum operation in softmax
+template<typename Op> void gpu_expand_f(const float * A,const int m, const int n, float * B);
+//expand each element in A into a row of B
+
+
+}  // namespace singa
 #endif // MATH_ADDR_H

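The cpu_gemv comment above ("m is the length of B, and n is the length of C") is easy to misread, so here is a short usage sketch modelled on the removed test_gemv (the values are illustrative only): with TranA=true the m x n buffer A is applied as its transpose, so B must hold m elements and C must hold n.

#include "singa/blob/math_addr.h"

void gemv_sketch() {
  float A[4][3];                // stored 4 x 3; applied as A^T because TranA = true
  float B[4] = {0, 1, 2, 3};    // length m = 4
  float C[3] = {0, 0, 0};       // length n = 3
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 3; ++j)
      A[i][j] = static_cast<float>(i + j);
  // C = 1.0 * A^T * B + 1.0 * C, i.e. C[j] = sum_i A[i][j] * B[i]
  singa::cpu_gemv(A[0], B, 4, 3, 1.0f, 1.0f, true, C);
}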
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bbb7dbc6/include/singa/blob/math_blob.h
----------------------------------------------------------------------
diff --git a/include/singa/blob/math_blob.h b/include/singa/blob/math_blob.h
index e26e4d4..d5991a7 100644
--- a/include/singa/blob/math_blob.h
+++ b/include/singa/blob/math_blob.h
@@ -1,139 +1,452 @@
+#ifndef MATH_BLOB_H
+#define MATH_BLOB_H
+
+#include <vector>
+#include "singa/utils/blob.h"
+#include "singa/blob/singa_op.h"
+#include "singa/blob/math_addr.h"
+
+
+namespace singa{
 /*********************Level-2 interface, called by user code*******************/
 // c++ usually uses const & for input arguments, and * for output arguments.
 // ww: maybe we represent Blob's shape using int s[4]+dim? currently we use a vector, which may
 // not be convenient as int array.
 
+
+int get_size(const std::vector<int>& shape);
+
+template <typename Dtype>
+bool check_shape_mv(const Blob<Dtype> & A, const Blob<Dtype> & B)
+{
+       if(A.shape().size() != 2) return false;
+       if(B.shape().size() != 1) return false;
+       if(A.shape().at(0) != B.shape().at(0)) return false;
+       return true;
+}
+
+template <typename Dtype>
+bool check_shape_equal(const Blob<Dtype> & A, const Blob<Dtype> & B, const Blob<Dtype> & C)
+{
+  int asize, bsize, csize;
+  asize = get_size(A.shape());
+  bsize = get_size(B.shape());
+  csize = get_size(C.shape());
+  if(asize != bsize) return false;
+  if(asize != csize) return false;
+  return true;
+}
+
+template <typename Dtype>
+bool check_shape_mmm(const Blob<Dtype> & A, const Blob<Dtype> & B, const Blob<Dtype> & C)
+{
+  if(A.shape().size() != 2) return false;
+  if(B.shape().size() != 2) return false;
+  if(C.shape().size() != 2) return false;
+  int a1, a2, b1, b2, c1, c2;
+  if(C.isTranspose()) return false;
+  a1 = A.isTranspose() ? A.shape().at(1) : A.shape().at(0);
+  a2 = A.isTranspose() ? A.shape().at(0) : A.shape().at(1);
+  b1 = B.isTranspose() ? B.shape().at(1) : B.shape().at(0);
+  b2 = B.isTranspose() ? B.shape().at(0) : B.shape().at(1);
+  c1 = C.shape().at(0);
+  c2 = C.shape().at(1);
+  if(a2 != b1) return false;
+  if(a1 != c1) return false;
+  if(b2 != c2) return false;
+  return true;
+}
+
+template <typename Dtype>
+bool check_shape_vvm(const Blob<Dtype> & A, const Blob<Dtype> & B, const Blob<Dtype> & C)
+{
+  if(A.shape().size() != 1) return false;
+  if(B.shape().size() != 1) return false;
+  if(C.shape().size() != 2) return false;
+  int a1, b1, c1, c2;
+  if(C.isTranspose()) return false;
+  a1 = A.shape().at(0);
+  b1 = B.shape().at(0);
+  c1 = C.shape().at(0);
+  c2 = C.shape().at(1);
+  if(a1 != c2) return false;
+  if(b1 != c1) return false;
+  return true;
+}
+
+template <typename Dtype>
+bool check_shape_mvv(const Blob<Dtype> & A, const Blob<Dtype> & B, const Blob<Dtype> & C)
+{
+  if(A.shape().size() != 2) return false;
+  if(B.shape().size() != 1) return false;
+  if(C.shape().size() != 1) return false;
+  int a1, a2, b1, c1;
+  a1 = A.isTranspose() ? A.shape().at(1) : A.shape().at(0);
+  a2 = A.isTranspose() ? A.shape().at(0) : A.shape().at(1);
+  b1 = B.shape().at(0);
+  c1 = C.shape().at(0);
+  if(a2 != b1) return false;
+  if(a1 != c1) return false;
+  return true;
+}
+
+/**********************************************************************************/
+// blob transformation
+
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, const std::vector<int>& shape)
+{
+  Blob<Dtype>* res = new Blob<Dtype>();
+  res->Mirror(A);
+  res->Reshape(shape);
+  return res;
+}
+
+// the current reshape in blob.h is: void Reshape(const std::vector<int>& shape);
+
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim1)
+{
+       std::vector<int> tmpshape;
+       tmpshape.push_back(dim1);
+       return Reshape(A, tmpshape);
+}
+
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim1, int dim2)
+{
+       std::vector<int> tmpshape;
+       tmpshape.push_back(dim1);
+       tmpshape.push_back(dim2);;
+       return Reshape(A, tmpshape);
+}
+
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim1, int dim2, int dim3)
+{
+       std::vector<int> tmpshape;
+       tmpshape.push_back(dim1);
+       tmpshape.push_back(dim2);
+       tmpshape.push_back(dim3);
+       return Reshape(A, tmpshape);
+}
+
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim1, int dim2, int dim3, int dim4)
+{
+       std::vector<int> tmpshape;
+       tmpshape.push_back(dim1);
+       tmpshape.push_back(dim2);
+       tmpshape.push_back(dim3);
+       tmpshape.push_back(dim4);
+       return Reshape(A, tmpshape);
+}
+
+template <typename Dtype>
+Blob<Dtype>* Reshape(const Blob<Dtype> & A, int dim1, int dim2, int dim3, int dim4, int dim5)
+{
+       std::vector<int> tmpshape;
+       tmpshape.push_back(dim1);
+       tmpshape.push_back(dim2);
+       tmpshape.push_back(dim3);
+       tmpshape.push_back(dim4);
+       tmpshape.push_back(dim5);
+       return Reshape(A, tmpshape);
+}
+
+template <typename Dtype>
+Blob<Dtype>* Transpose(const Blob<Dtype> & A)
+{
+       Blob<Dtype>* res = new Blob<Dtype>();
+       res->Mirror(A);
+       res->setTranspose();
+       return res;
+}
+// return A^T
+
+
 
/**********************************************************************************/
 // class1 matrix operation
 
-void MMDot(const Blob & A, const Blob & B, Blob & C);
+
+void MMDot(XPU xpu, const Blob<float> & A, const Blob<float> & B, Blob<float> * C);
 // A,B and C are matrix
 
-void MVDot(const Blob & A, const Blob & B, Blob & C);
+
+void MVDot(XPU xpu, const Blob<float> & A, const Blob<float> & B, Blob<float> * C);
 // A is matrix,B and C are vector
 
-void VVDot(const Blob & A, const Blob & B, Blob & C);
+
+void VVDot(XPU xpu, const Blob<float> & A, const Blob<float> & B, Blob<float> * C);
 // C is matrix,A and B are vector
 
-float VVdot(const Blob & A, const Blob & B);
+
+float VVdot(XPU xpu, const Blob<float> & A, const Blob<float> & B);
 //A and B are vectors
 
-void GEMM(const Blob & A, const Blob & B, Blob & C, float alpha = 1, float beta = 1);
+
+void GEMM(XPU xpu, const Blob<float> & A, const Blob<float> & B, Blob<float> * C, float alpha = 1, float beta = 1);
 //C = alpha*A*B+beta*C, A, B and C are matrix
 
-Blob Reshape(const Blob & A, const std::vector<int>& shape);
-// the current reshape in blob.h is: void Reshape(const std::vector<int>& shape);
-// return the reference of the reshaped blob
 
-Blob Transpose(const Blob & A);
-// return A^T, only reference to the blob A
-// ww: just add a bool field in Blob, e.g., transpose_
+
 
/**********************************************************************************/
 // class2 element-wise operation
 
-void Set(Blob & A,float alpha);
-// element-wise operation: Ai = alpha
-
-void AXPY(const Blob & A, Blob & B, float alpha);
-// element-wise operation: Bi = alpha*Ai+Bi  A and B should have the same size
+// element-wise generalized operation defined in Op
 
-void Add(const Blob & A, const Blob & B, Blob & C);
-// element-wise operation: Ci = Ai+Bi  A,B and C should have the same size
 
-void Sub(const Blob & A, const Blob & B, Blob & C);
-// element-wise operation: Ci = Ai-Bi  A,B and C should have the same size
+template<typename Op> 
+void E_Func(XPU xpu, Blob<float> * A, float alpha)
+{
+       if(xpu == cpu)
+       {
+               int n = get_size(A->shape());
+               cpu_e_f<Op>(n, alpha, A->mutable_cpu_data());
+       }
+       if(xpu == gpu)
+       {
+               //gpu part
+       }
+}
+
+template<typename Op>
+void E_Func(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha)
+{
+       if(xpu == cpu)
+       {
+               if(check_shape_equal(A, *B, *B))
+               {
+                       int n = get_size(A.shape());
+                       cpu_e_f<Op>(n, A.cpu_data(), alpha, B->mutable_cpu_data());
+               }
+               else{
+                       // report errors here
+               }       
+       }
+       if(xpu == gpu)
+       {
+               //gpu part
+       }
+}
+
+template<typename Op>
+void E_Func(XPU xpu, const Blob<float> & A, const Blob<float> & B, Blob<float> * C, float alpha, float beta)
+{
+       if(xpu == cpu)
+       {
+               if(check_shape_equal(A, B, *C))
+               {
+                       int n = get_size(A.shape());
+                       cpu_e_f<Op>(n, A.cpu_data(), B.cpu_data(), alpha, beta, C->mutable_cpu_data());
+               }
+               else{
+                       // report errors here
+               }
+       }
+       if(xpu == gpu)
+       {
+               //gpu part
+       }
+}
+
+
+inline void Set(XPU xpu, Blob<float> * A,float alpha)
+{
+       E_Func<singa_op::Set>(xpu, A, alpha);
+}
+// element-wise operation: Ai = alpha
 
-void Mult(const Blob & A, const Blob & B, Blob & C);
-// element-wise operation: Ci = Ai*Bi  A,B and C should have the same size
 
-void Div(const Blob & A, const Blob & B, Blob & C);
-// element-wise operation: Ci = Ai/Bi  A,B and C should have the same size
-
-void Scale(const Blob & A, Blob & B, float alpha);
+inline void Scale(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha)
+{
+       E_Func<singa_op::Scale>(xpu, A, B, alpha);
+}
 // element-wise operation: Bi = alpha*Ai
 
-void Sigmoid(const Blob & A, Blob & B,float t);
-// element-wise operation: Bi = 1/(1+e^(-Ai*t))
-
-void Relu(const Blob & A, Blob & B,float t = 0);
-// element-wise operation: Bi = ((1-t)abs(Ai) + (1+t)Ai)/2
-
-void Tanh(const Blob & A, Blob & B,float t);
-// element-wise operation: Bi = tanh(Ai*t)
-
-void Exp(const Blob & A, Blob & B, float alpha = 2.71);
+inline void Exp(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha = 2.71)
+{
+       E_Func<singa_op::Exp>(xpu, A, B, alpha);
+}
 // element-wise operation: Bi = alpha^Ai
-// ww: there are many element-wise operations, e.g., log, square, sqrt. 
-// If MKL/OpenBLAS or other libraries do not optimize these operations in ad-hoc manner, 
-// then we can implement then using the E_Func below by passing the basic log/square/sqrt operations.
 
+inline void Exp_grad(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha = 2.71)
+{
+       E_Func<singa_op::Exp_grad>(xpu, A, B, alpha);
+}
+// element-wise operation: Bi = Ai*log(alpha)
+
+inline void Gsigmoid(XPU xpu, const Blob<float> & A, Blob<float> * B,float alpha)
+{
+       E_Func<singa_op::Gsigmoid>(xpu, A, B, alpha);
+}
+// element-wise operation: b = 1.0f / (1.0f + expf(-a * alpha));
+
+inline void Gsigmoid_grad(XPU xpu, const Blob<float> & A, Blob<float> * B,float alpha)
+{
+       E_Func<singa_op::Gsigmoid_grad>(xpu, A, B, alpha);
+}
+// element-wise operation: b = alpha * a * ( 1.0f - a );
+
+inline void Grelu(XPU xpu, const Blob<float> & A, Blob<float> * B,float alpha = 0)
+{
+       E_Func<singa_op::Grelu>(xpu, A, B, alpha);
+}
+// element-wise operation: b = ( 1 - alpha ) * std::max( a, 0.0f ) + alpha * a;
+
+inline void Grelu_grad(XPU xpu, const Blob<float> & A, Blob<float> * B,float alpha = 0)
+{
+       E_Func<singa_op::Grelu_grad>(xpu, A, B, alpha);
+}
+// element-wise operation: b = a > 0.0f ? 1.0f : alpha;
+
+inline void Gtanh(XPU xpu, const Blob<float> & A, Blob<float> * B,float alpha)
+{
+       E_Func<singa_op::Gtanh>(xpu, A, B, alpha);
+}
+// element-wise operation: b = tanhf( a * alpha );
+
+inline void Gtanh_grad(XPU xpu, const Blob<float> & A, Blob<float> * B,float alpha)
+{
+       E_Func<singa_op::Gtanh_grad>(xpu, A, B, alpha);
+}
+// element-wise operation: b = alpha * ( 1.0f - a * a );
+        
+inline void Softplus(XPU xpu, const Blob<float> & A, Blob<float> * B)
+{
+       E_Func<singa_op::Softplus>(xpu, A, B, 0);
+}
+// element-wise operation: b = logf(1 + expf(a));
+
+inline void Softplus_grad(XPU xpu, const Blob<float> & A, Blob<float> * B)
+{
+       E_Func<singa_op::Softplus_grad>(xpu, A, B, 0);
+}
+// element-wise operation: b = 1.0f / (1.0f + expf(-a));
+
+inline void Square(XPU xpu, const Blob<float> & A, Blob<float> * B)
+{
+       E_Func<singa_op::Square>(xpu, A, B, 0);
+}
+// element-wise operation: b = a * a;
+
+inline void Square_grad(XPU xpu, const Blob<float> & A, Blob<float> * B)
+{
+       E_Func<singa_op::Square_grad>(xpu, A, B, 0);
+}
+// element-wise operation: b = 2 * sqrt(a);
+
+inline void Sqrt(XPU xpu, const Blob<float> & A, Blob<float> * B)
+{
+       E_Func<singa_op::Sqrt>(xpu, A, B, 0);
+}
+// element-wise operation: b = sqrt(a);
+
+inline void Threshold(XPU xpu, const Blob<float> & A, float alpha, Blob<float> * B)
+{
+       E_Func<singa_op::Threshold>(xpu, A, B, alpha);
+}
+// element-wise operation: b =  a < alpha ? 1.0f : 0.0f;
+
+inline void Add(XPU xpu, const Blob<float> & A, const Blob<float> & B, Blob<float> * C)
+{
+       E_Func<singa_op::Add>(xpu, A, B, C, 0, 0);
+}
+// element-wise operation: Ci = Ai+Bi  A,B and C should have the same size
 
+inline void Sub(XPU xpu, const Blob<float> & A, const Blob<float> & B, Blob<float> * C)
+{
+       E_Func<singa_op::Sub>(xpu, A, B, C, 0, 0);
+}
+// element-wise operation: Ci = Ai-Bi  A,B and C should have the same size
 
-template<typename Op> void E_Func(Blob & A, float alpha);
-template<typename Op> void E_Func(const Blob & A, Blob & B, float alpha, float beta);
-template<typename Op> void E_Func(const Blob & A, const Blob & B, Blob & C, float alpha, float beta);
-// element-wise generalized operation defined in Op
+inline void Mult(XPU xpu, const Blob<float> & A, const Blob<float> & B, Blob<float> * C)
+{
+       E_Func<singa_op::Mult>(xpu, A, B, C, 0, 0);
+}
+// element-wise operation: Ci = Ai*Bi  A,B and C should have the same size
+
+inline void Div(XPU xpu, const Blob<float> & A, const Blob<float> & B, Blob<float> * C)
+{
+       E_Func<singa_op::Div>(xpu, A, B, C, 0, 0);
+}
+// element-wise operation: Ci = Ai/Bi  A,B and C should have the same size
 
 
-// ww: the following functions may require thread specific variables, e.g., seed or random stream state.
+void AXPY(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha);
+// element-wise operation: Bi = alpha*Ai+Bi  A and B should have the same size
 
-void Gaussian(Blob & A, float mu, float sigma);
+//todo: random part
+/*
+void Gaussian(XPU xpu, Blob & A, float mu, float sigma);
 // element-wise operation: initialize each element in A following distribution Gaussian(mu, sigma)
 
-void Uniform(Blob & A, float low, float high);
+void Uniform(XPU xpu, Blob & A, float low, float high);
 // element-wise operation: initialize each element in A following uniform distribution from low to high
 
-void Bernoulli(Blob & A, float p, int n = 1);
+void Bernoulli(XPU xpu, Blob & A, float p, int n = 1);
 // element-wise operation: initialize each element in A following distribution Bernoulli(n,p)
-
+*/
 
 
/**********************************************************************************/
 //class3 matrix-vector expand/reduce operation
 
-template<typename Op> void Reduce_F(const Blob & A, Blob & B);
+template<typename Op> 
+void Reduce_F(XPU xpu, const Blob<float> & A, Blob<float> * B)
+{
+       if(xpu == cpu)
+       {
+               if(check_shape_mv(A, *B))
+               {
+                       int m = get_size(B->shape());
+                       int n = get_size(A.shape()) / m;
+                       cpu_reduce_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data());
+               }
+               else{
+                       // report errors here
+               }
+       }
+       if(xpu == gpu)
+       {
+               //gpu part
+       }
+}
 //reduce each row of A to an element of B e.g. the sum operation in softmax
-template<typename Op> void Expand_F(const Blob & A, Blob & B);
+
+template<typename Op> 
+void Expand_F(XPU xpu, const Blob<float> & A, Blob<float> * B)
+{
+       if(xpu == cpu)
+       {
+               if(check_shape_mv(*B, A))
+               {
+                       int m = get_size(A.shape());
+                       int n = get_size(B->shape()) / m;
+                       cpu_expand_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data());
+               }
+               else{
+                       // report errors here
+               }
+       }
+       if(xpu == gpu)
+       {
+               //gpu part
+       }
+}
 //expand each element in A into a row of B
 
-void Repmat(const Blob & A, Blob & B);
+void Repmat(XPU xpu, const Blob<float> & A, Blob<float> * B);
 // A is a vector, B is a matrix , let each row of B to be A
-// just copy memory, will be faster
 
-// ww may rename to MVAdd, MVSum to be consistent with the MVDot, MMDot, VVDot.
-void MVAdd(const Blob & A, Blob & B, float alpha, float beta);
+void MVAdd(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha, float beta);
 // A is a vector, B is a matrix , Bij = alpha*Ai+beta*Bij
 // will use gemm. faster than general expand_f
 
-void MVSum(const Blob & A, Blob & B, float alpha, float beta);
+void MVSum(XPU xpu, const Blob<float> & A, Blob<float> * B, float alpha, float beta);
 // A is a vector, B is a matrix , Ai = \sigma_j_{alpha*Bij}+beta*Ai
 // will use gemm. faster than general reduce_f
 
-void Softmax(const Blob & A,Blob & B,float alpha);
-// Bij = e^(alpha*Aij) / \sigma_i_{e^(alpha*Aij)}
 
-/**********************************************************************************/
-//class4 convolution operation
-
-void Conv(const Blob & A,const Blob & B,Blob & C);
-// A is the data, B is the parameter, C is the result
-
-void Pool(const Blob & A,Blob & B, int method);
-// A is the data, B is the result, should indicate max or ave pooling
-
-// jy: need to provide grad compute function respectively?
-
-// ww: The conv/pool operations cannot be declared as above, 
-// because they require other parameters, e.g., filter size, num of filters, pad, stride, etc.
-// Caffe and mxnet use cudnn in layer implementations instead of implementing low-level operations.
-// Maybe we can follow caffe? layer implementation is not our contribution, we can use others' code directly.
-// For CPU version, we use im2col and col2im for the conv layer. 
-// For GPU version, we use cudnn for the conv layer. Similarly for other layers. 
-// In conclude, we may not implement Blob level Conv and Pool operations. 
-// Instead, we implement CaffeConvLayer, CaffePoolLayer, cudnnConvLayer, cudnnPoolLayer.
-// Later we may add IntelConvLayer (cpu), NeonConvLayer (gpu).
-
-Blob setcolspace(const Blob & A);
-void im2col(const Blob & A,Blob & B);
-void col2im(const Blob & A,Blob & B);
-//given an img, use setcolspace to generate colspace Blob
-//use pack/unpack to get data in col/img
+} // end of namespace singa
+
+#endif // MATH_BLOB_H

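The Level-2 functions above are meant to be called with an explicit XPU tag and output pointers. A minimal usage sketch of the element-wise path (it assumes Blob has a shape-vector constructor and that SyncedMemory allocates on first access; neither is shown in this diff):

#include <vector>
#include "singa/blob/math_blob.h"

void elementwise_sketch() {
  using namespace singa;
  std::vector<int> shape{2, 3};
  Blob<float> A(shape), B(shape), C(shape);

  Set(cpu, &A, 1.0f);       // Ai = 1
  Set(cpu, &B, 2.0f);       // Bi = 2
  Add(cpu, A, B, &C);       // Ci = Ai + Bi; check_shape_equal() guards the sizes
  Scale(cpu, C, &C, 0.5f);  // in-place: dispatched through E_Func<singa_op::Scale>
}

Note that Reshape() and Transpose() return heap-allocated mirrors, so callers own the returned pointer.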
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bbb7dbc6/include/singa/blob/singa_op.h
----------------------------------------------------------------------
diff --git a/include/singa/blob/singa_op.h b/include/singa/blob/singa_op.h
index 7ca06db..b36c001 100644
--- a/include/singa/blob/singa_op.h
+++ b/include/singa/blob/singa_op.h
@@ -4,7 +4,12 @@
 #include<cmath>
 #include <algorithm>
 
-namespace op {
+namespace singa {
+       enum XPU { cpu, gpu, any};
+
+}
+
+namespace singa_op {
         struct Set {
             inline static void Map(float alpha, float & a) {
                 a= alpha;
@@ -97,8 +102,8 @@ namespace op {
         };
 
         struct Threshold {
-            inline static void Map(float alpha, float beta, const float & a, const float & b, float & c) {
-                c =  a < b ? 1.0f : 0.0f;
+            inline static void Map(float alpha, const float & a, float & b) {
+                b =  a < alpha ? 1.0f : 0.0f;
             }
         };
 
@@ -134,7 +139,16 @@ namespace op {
                             b += a[i];
                 }
             }
-            };
+        };
+
+        struct Expand_Div {
+            inline static void Map(const float & a, int n, float * b) {
+                for(int i = 0 ; i < n ; i++)
+                {
+                            b[i] /= a;
+                }
+            }
+        };
 
         struct Repmat {
             inline static void Map(const float & a, int n, float * b) {

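Each Op struct is just a static Map functor that cpu_e_f / cpu_expand_f instantiate; the new Threshold signature drops the unused beta and second input. A small sketch of the functors in isolation (values are illustrative only):

#include "singa/blob/singa_op.h"

void op_sketch() {
  singa::XPU dev = singa::cpu;           // new device tag used by the math_blob.h interface
  (void)dev;

  float a = 0.5f, b = 0.0f;
  singa_op::Set::Map(3.0f, a);           // a = 3.0
  singa_op::Threshold::Map(1.0f, a, b);  // b = (a < 1.0f) ? 1.0f : 0.0f  -> 0.0
}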
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bbb7dbc6/include/singa/blob/test.cc
----------------------------------------------------------------------
diff --git a/include/singa/blob/test.cc b/include/singa/blob/test.cc
deleted file mode 100644
index d13ed5e..0000000
--- a/include/singa/blob/test.cc
+++ /dev/null
@@ -1,165 +0,0 @@
-#include <iostream>
-
-#include "singa_op.h"
-#include "math_addr.h"
-
-using namespace std;
-
-void test_gemm1()
-{
-            float A[3][2] = {};
-            float B[3][2] = {};
-            float C[2][2] = {};
-            for(int i = 0; i < 3; i++)
-                for(int j = 0; j < 2; j++)
-                {
-                A[i][j] = i+j;
-                B[i][j] = i+j - i*j;
-                }
-            cpu_gemm(A[0], B[0], 2, 2, 3 , 1, 0, true, false, C[0]);
-            float D[2][2] = {};
-            for(int i = 0; i < 2; i++)
-                for(int j = 0; j < 2; j++)
-                {
-                    D[i][j] = 0;
-                    for(int k = 0; k < 3; k++)
-                    D[i][j] += A[k][i]*B[k][j];
-                }
-            for(int i = 0; i < 2; i++)
-                for(int j = 0; j < 2; j++)
-                {
-                cout<<C[i][j] - D[i][j]<<endl;
-                }
-}
-
-
-void test_gemm2()
-{
-            float A[2][3] = {};
-            float B[3][2] = {};
-            float C[2][2] = {};
-            for(int i = 0; i < 3; i++)
-                for(int j = 0; j < 2; j++)
-                {
-                A[j][i] = i-j;
-                B[i][j] = i+j + i*j;
-                }
-            cpu_gemm(A[0], B[0], 2, 2, 3 , 1, 0, false, false, C[0]);
-            float D[2][2] = {};
-            for(int i = 0; i < 2; i++)
-                for(int j = 0; j < 2; j++)
-                {
-                    D[i][j] = 0;
-                    for(int k = 0; k < 3; k++)
-                    D[i][j] += A[i][k]*B[k][j];
-                }
-            for(int i = 0; i < 2; i++)
-                for(int j = 0; j < 2; j++)
-                {
-                cout<<C[i][j] - D[i][j]<<endl;
-                }
-}
-
-
-void test_gemv()
-{
-        float A[4][3] = {};
-        float B[4]= {};
-        float C[3] = {};
-        float D[3] = {};
-        for(int i = 0; i < 4; i++)
-        {
-            for(int j = 0; j < 3; j++)
-                    {
-                    A[j][i] = i-j + i*j;
-                    }
-        }
-        for(int i = 0; i < 4; i++)B[i] = i;
-        for(int i = 0; i < 3; i++)C[i] = 10;
-        cpu_gemv(A[0], B, 4, 3, 1, 1, true, C);
-        for(int i = 0; i < 3; i++)
-                for(int j = 0; j < 4; j++)
-                {
-                    D[i] += A[j][i]*B[j];
-                }
-        for(int i = 0; i < 3; i++)cout<<C[i] - D[i] - 10<<endl;
-}
-
-void test_axpy()
-{
-        float A[4][3] = {};
-        float C[4][3] = {};
-        float B[3][4] = {};
-        float D[3][4] = {};
-        for(int i = 0; i < 4; i++)
-        {
-            for(int j = 0; j < 3; j++)
-                    {
-                    A[i][j] = i-j + i*j;
-                    B[j][i] = i-j + i*j;
-                    C[i][j] = A[i][j];
-                    D[j][i] = B[j][i];
-                    }
-        }
-        cpu_axpy(A[0], 12, 2, B[0]);
-        for(int i = 0; i < 12; i++)D[0][i] += 2*C[0][i];
-        for(int i = 0; i < 3; i++)
-        {
-            for(int j = 0; j < 4; j++)
-                    {
-                    cout<<B[i][j] - D[i][j]<<endl;
-                    }
-        }
-}
-
-void test_eop()
-{
-        float A[10] = {};
-        float B[10] = {};
-        float C[10] = {};
-        float D[10] = {};
-        float O[10] = {};
-        for(int i = 0; i < 10; i++)
-        {
-            A[i] = i;
-            B[i] = -i;
-            C[i] = i;
-        }
-        cpu_e_f<op::Set>(5, 15, O);
-        for(int i = 0; i < 5; i++)cout<<O[i] - 15<<endl;
-        for(int i = 5; i < 10; i++)cout<<O[i]<<endl;
-        cpu_e_f<op::Scale>(10, C, 2, C);
-        for(int i = 0; i < 10; i++)cout<<C[i] - 2* i<<endl;
-        cpu_e_f<op::Add>(10, A, B, 0, 0, O);
-        for(int i = 0; i < 10; i++)cout<<O[i]<<endl;
-}
-
-void test_exrd()
-{
-        float A[3][10] = {};
-        float B[3] = {};
-        for(int i = 0; i < 3; i++)
-            for(int j = 0; j < 10; j++)
-            {
-                A[i][j] = (i + 1)*j;
-            }
-        cpu_reduce_f<op::Sum>(A[0], 3, 10, B);
-        for(int i = 0; i < 3; i++) B[i] -= 45*(i+1);
-        for(int i = 0; i < 3; i++)cout<<B[i]<<endl;
-        cpu_expand_f<op::Repmat>(B, 3, 10, A[0]);
-        cpu_reduce_f<op::Sum>(A[0], 3, 10, B);
-        for(int i = 0; i < 3; i++)cout<<B[i]<<endl;
-}
-
-int main()
-{
-    test_gemm1()  ;
-       test_gemm2();
-       test_gemv();
-       test_axpy();
-       test_eop();
-       test_exrd();
-    return 0;
-}
-
-

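The removed test.cc was an ad-hoc main() that printed differences rather than asserting them. A sketch of how one of those checks could be rewritten against the namespaced interface (this file is not part of the commit):

#include <cassert>
#include "singa/blob/math_addr.h"

int main() {
  float A[6] = {1, 2, 3, 4, 5, 6};   // 2 x 3, row-major
  float B[6] = {1, 0, 0, 1, 1, 1};   // 3 x 2, row-major
  float C[4] = {};
  // C = A * B  (m = 2, n = 2, k = 3, no transpose)
  singa::cpu_gemm(A, B, 2, 2, 3, 1.0f, 0.0f, false, false, C);
  assert(C[0] == 4 && C[1] == 5 && C[2] == 10 && C[3] == 11);
  return 0;
}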
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bbb7dbc6/include/singa/utils/blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/blob.h b/include/singa/utils/blob.h
index b6bcc3e..0ebf8fd 100644
--- a/include/singa/utils/blob.h
+++ b/include/singa/utils/blob.h
@@ -186,6 +186,21 @@ class Blob {
   /// @brief Compute the sum of absolute values (L1 norm) of the data.
   Dtype asum_data() const;
   Dtype sum_data() const;
+  inline void setTranspose() {
+    isTranspose_ = !isTranspose_;
+  }
+  inline bool isTranspose() const {
+    return isTranspose_;
+  }
+  inline void Mirror(const Blob<Dtype> & other) {
+    data_ = other.data_;
+    shape_ = other.shape_;
+    count_ = other.count_;
+    capacity_ = other.capacity_;
+    version_ = other.version_;
+    isTranspose_ = other.isTranspose_;
+  }
+
 
  protected:
   std::shared_ptr<SyncedMemory> data_ = nullptr;
@@ -193,6 +208,7 @@ class Blob {
   int count_ = 0;
   int capacity_ = 0;
   int version_ = -1;
+  bool isTranspose_ = false;
 };  // class Blob
 
 using namespace mshadow;

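The new Blob members are what make the copy-free Reshape()/Transpose() helpers in math_blob.h possible: Mirror() shares the SyncedMemory and metadata, and setTranspose() flips the flag that the check_shape_* helpers consult. A brief sketch (the function and variable names here are mine):

#include "singa/utils/blob.h"
#include "singa/blob/math_blob.h"

void transpose_sketch(const singa::Blob<float>& W) {  // W assumed to be 2-D
  singa::Blob<float>* Wt = singa::Transpose(W);  // Mirror() + setTranspose(): no data copy
  // Wt->isTranspose() now differs from W.isTranspose(); GEMM-style shape checks read this flag.
  delete Wt;                                     // caller owns the mirror
}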