SINGA-80 New Blob Level and Address Level Math Operation Interface

----

Add SINGA_GPU macro to compile without GPU functions


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/641eb317
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/641eb317
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/641eb317

Branch: refs/heads/master
Commit: 641eb317ecce8b793f87552840d55e4f8da54d39
Parents: 01d91af
Author: jinyangturbo <[email protected]>
Authored: Fri Nov 6 05:15:46 2015 -0800
Committer: Wei Wang <[email protected]>
Committed: Mon Nov 9 17:04:48 2015 +0800

----------------------------------------------------------------------
 include/singa/blob/math_addr.h |  2 ++
 include/singa/blob/math_blob.h | 12 +++++++-
 include/singa/blob/singa_op.h  | 56 +++++++++++++++++++++++++++++++++++--
 src/blob/math_addr.cc          | 11 +++++---
 src/blob/math_blob.cc          | 16 +++++++++++
 5 files changed, 89 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/641eb317/include/singa/blob/math_addr.h
----------------------------------------------------------------------
diff --git a/include/singa/blob/math_addr.h b/include/singa/blob/math_addr.h
index 7c74201..2a25a29 100644
--- a/include/singa/blob/math_addr.h
+++ b/include/singa/blob/math_addr.h
@@ -81,6 +81,7 @@ void cpu_expand_f(const float * A, const int m, const int n, 
float * B) {
 }
 // expand each element in A into a row of B
 
+#ifdef SINGA_GPU
 void gpu_gemm(const float * A, const float * B,
 const int m, const int n, const int k, const float alpha, const float beta,
 const bool TranA, const bool TranB, float * C);
@@ -126,6 +127,7 @@ void gpu_expand_f(const float * A, const int m, const int 
n, float * B) {
                 }
 }
 // expand each element in A into a row of B
+#endif  // SINGA_GPU  
 
 }  // namespace singa
 #endif  // SINGA_BLOB_MATH_ADDR_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/641eb317/include/singa/blob/math_blob.h
----------------------------------------------------------------------
diff --git a/include/singa/blob/math_blob.h b/include/singa/blob/math_blob.h
index b52cb91..638f9cc 100644
--- a/include/singa/blob/math_blob.h
+++ b/include/singa/blob/math_blob.h
@@ -24,7 +24,7 @@
 
 #include <vector>
 #include "singa/utils/blob.h"
-#include "singa/blob/singa::op.h"
+#include "singa/blob/singa_op.h"
 #include "singa/blob/math_addr.h"
 
 
@@ -218,11 +218,13 @@ void E_Func(XPU xpu, Blob<float> * A, float alpha) {
         int n = get_size(A->shape());
         cpu_e_f<Op>(n, alpha, A->mutable_cpu_data());
     }
+    #ifdef SINGA_GPU
     if (xpu == gpu) {
         // gpu part
         int n = get_size(A->shape());
         gpu_e_f<Op>(n, alpha, A->mutable_gpu_data());
     }
+    #endif  // SINGA_GPU
 }
 
 template<typename Op>
@@ -232,10 +234,12 @@ void E_Func(XPU xpu, const Blob<float> & A, Blob<float> * 
B, float alpha) {
         if (xpu == cpu) {
             cpu_e_f<Op>(n, A.cpu_data(), alpha, B->mutable_cpu_data());
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             // gpu part
             gpu_e_f<Op>(n, A.gpu_data(), alpha, B->mutable_gpu_data());
         }
+        #endif  // SINGA_GPU
     } else {
         // report errors here
     }
@@ -250,11 +254,13 @@ Blob<float> * C, float alpha, float beta) {
             cpu_e_f<Op>(n, A.cpu_data(), B.cpu_data(), alpha, beta,
             C->mutable_cpu_data());
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             // gpu part
             gpu_e_f<Op>(n, A.gpu_data(), B.gpu_data(), alpha, beta,
             C->mutable_gpu_data());
         }
+        #endif  // SINGA_GPU
     } else {
         // report errors here
     }
@@ -391,10 +397,12 @@ void Reduce_F(XPU xpu, const Blob<float> & A, Blob<float> 
* B) {
         if (xpu == cpu) {
             cpu_reduce_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data());
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             // gpu part
             gpu_reduce_f<Op>(A.gpu_data(), m, n, B->mutable_gpu_data());
         }
+        #endif  // SINGA_GPU
     } else {
         // report errors here
     }
@@ -409,10 +417,12 @@ void Expand_F(XPU xpu, const Blob<float> & A, Blob<float> 
* B) {
         if (xpu == cpu) {
             cpu_expand_f<Op>(A.cpu_data(), m, n, B->mutable_cpu_data());
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             // gpu part
             gpu_expand_f<Op>(A.gpu_data(), m, n, B->mutable_gpu_data());
         }
+        #endif  // SINGA_GPU  
     } else {
         // report errors here
     }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/641eb317/include/singa/blob/singa_op.h
----------------------------------------------------------------------
diff --git a/include/singa/blob/singa_op.h b/include/singa/blob/singa_op.h
index abdfd66..3747568 100644
--- a/include/singa/blob/singa_op.h
+++ b/include/singa/blob/singa_op.h
@@ -22,12 +22,15 @@
 #ifndef SINGA_BLOB_SINGA_OP_H_
 #define SINGA_BLOB_SINGA_OP_H_
 
+#ifdef SINGA_GPU
 #include <cuda_runtime.h>
+#endif  // SINGA_GPU
 #include <cmath>
 #include <algorithm>
-// #include "cublas_v2.h"
+#ifdef SINGA_GPU
+#include "cublas_v2.h"
 #include "singa/blob/math_kernel.h"
-
+#endif  // SINGA_GPU
 
 namespace singa {
     enum XPU { cpu, gpu, any};
@@ -37,38 +40,46 @@ struct Set {
     inline static void Map(float alpha, float * a) {
         *a = alpha;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha, float * a, int n) {
         singa::singa_gpu_set_value(a, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Scale {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = alpha * a;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_scale(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Scale_grad {
     inline static void Map(float alpha,  float * output) {
         *output = alpha;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  float * output, int n) {
         singa::singa_gpu_scale_grad(output, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Exp {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = pow(a, alpha);
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_exp(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Exp_grad {
@@ -76,130 +87,156 @@ struct Exp_grad {
         // log is the natrual log based on e
         *b = a * log(alpha);
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_exp_grad(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Gsigmoid {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = 1.0f / (1.0f + expf(-a * alpha));
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_sigmoid(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Gsigmoid_grad {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = alpha * a * (1.0f - a);
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_sigmoid_grad(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Grelu {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = (1 - alpha) * std::max(a, 0.0f) + alpha * a;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_relu(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Grelu_grad {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = a > 0.0f ? 1.0f : alpha;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_relu_grad(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Gtanh {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = tanhf(a * alpha);
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_tanh(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Gtanh_grad {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = alpha * (1.0f - a * a);
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_tanh_grad(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Softplus {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = logf(1 + expf(a));
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_softplus(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Softplus_grad {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = 1.0f / (1.0f + expf(-a));
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_softplus_grad(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Square {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = a * a;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_square(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Square_grad {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = 2 * sqrt(a);
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_square_grad(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Sqrt {
     inline static void Map(float alpha,  const float & a, float * b) {
         *b = sqrt(a);
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_sqrt(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Threshold {
     inline static void Map(float alpha, const float & a, float * b) {
         *b =  a < alpha ? 1.0f : 0.0f;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha,  const float * a,
     float * b, int n) {
         singa::singa_gpu_threshold(a, b, alpha, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Add {
@@ -207,10 +244,12 @@ struct Add {
     const float & b, float * c) {
         *c =  a + b;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha, float beta, const float * a,
     const float * b, float * c, int n) {
         singa::singa_gpu_add(a, b, c, alpha, beta, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Sub {
@@ -218,10 +257,12 @@ struct Sub {
     const float & b, float * c) {
         *c =  a - b;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha, float beta, const float * a,
     const float * b, float * c, int n) {
         singa::singa_gpu_sub(a, b, c, alpha, beta, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Mult {
@@ -229,10 +270,12 @@ struct Mult {
     const float & b, float * c) {
         *c =  a * b;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha, float beta, const float * a,
     const float * b, float * c, int n) {
         singa::singa_gpu_mult(a, b, c, alpha, beta, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Div {
@@ -240,10 +283,12 @@ struct Div {
     const float & b, float * c) {
         *c =  a / b;
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(float alpha, float beta, const float * a,
     const float * b, float * c, int n) {
         singa::singa_gpu_div(a, b, c, alpha, beta, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Sum {
@@ -253,7 +298,7 @@ struct Sum {
                     *b += a[i];
         }
     }
-
+    #ifdef SINGA_GPU
     inline static void CudaMap(const float * a, int n, float * b) {
         float *sum = NULL;
         cudaMalloc(<void**>(&sum), n*sizeof(float));
@@ -263,6 +308,7 @@ struct Sum {
         cudaMemcpyAsync(b, sum, sizeof(float), cudaMemcpyDeviceToDevice);
         cudaFree(sum);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Expand_Div {
@@ -271,9 +317,11 @@ struct Expand_Div {
                     b[i] /= a;
         }
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(const float & a, int n, float * b) {
         singa::singa_gpu_scale(b, b, a, n);
     }
+    #endif  // SINGA_GPU
 };
 
 struct Repmat {
@@ -282,9 +330,11 @@ struct Repmat {
                     b[i] = a;
         }
     }
+    #ifdef SINGA_GPU
     inline static void CudaMap(const float & a, int n, float * b) {
         singa::singa_gpu_set_value(b, a, n);
     }
+    #endif  // SINGA_GPU
 };
 
 };  // namespace op

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/641eb317/src/blob/math_addr.cc
----------------------------------------------------------------------
diff --git a/src/blob/math_addr.cc b/src/blob/math_addr.cc
index 8451957..fb1c42e 100644
--- a/src/blob/math_addr.cc
+++ b/src/blob/math_addr.cc
@@ -23,11 +23,13 @@
 extern "C" {
     #include <cblas.h>
 }
+#ifdef SINGA_GPU
 #include <cuda_runtime.h>
+#endif
 #include "singa/blob/singa_op.h"
-// #include "cublas_v2.h"
-
-
+#ifdef SINGA_GPU
+#include "cublas_v2.h"
+#endif
 
 namespace singa {
 
@@ -69,8 +71,8 @@ float cpu_dot(const float * A, const float * B, const int n) {
     return sum;
 }
 
+#ifdef SINGA_GPU
 // Trick: swap A and B
-//
 void gpu_gemm(const float * A, const float * B, const int m, const int n,
 const int k, const float alpha, const float beta, const bool TranA,
 const bool TranB, float * C) {
@@ -113,5 +115,6 @@ float gpu_dot(const float * A, const float * B, const int 
n) {
     cublasDestroy(handle);
     return result;
 }
+#endif
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/641eb317/src/blob/math_blob.cc
----------------------------------------------------------------------
diff --git a/src/blob/math_blob.cc b/src/blob/math_blob.cc
index ad0b766..083d3e5 100644
--- a/src/blob/math_blob.cc
+++ b/src/blob/math_blob.cc
@@ -20,7 +20,9 @@
 *************************************************************/
 
 #include "singa/blob/math_blob.h"
+#ifdef SINGA_GPU
 #include "singa/blob/math_kernel.h"
+#endif  // SINGA_GPU
 
 namespace singa {
 
@@ -49,11 +51,13 @@ Blob<float> * C, float alpha, float beta) {
             cpu_gemm(A.cpu_data(), B.cpu_data(), m, n, k, alpha, beta,
             TranA, TranB, C->mutable_cpu_data());
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             // gpu part
             gpu_gemm(A.gpu_data(), B.gpu_data(), m, n, k, alpha, beta,
             TranA, TranB, C->mutable_gpu_data());
         }
+        #endif  // SINGA_GPU
     } else {
         // report errors here
     }
@@ -77,11 +81,13 @@ Blob<float> * C) {
             cpu_gemv(A.cpu_data(), B.cpu_data(), m, n, 1, 0, TranA,
             C->mutable_cpu_data());
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             // gpu part
             gpu_gemv(A.gpu_data(), B.gpu_data(), m, n, 1, 0, TranA,
             C->mutable_gpu_data());
         }
+        #endif  // SINGA_GPU
     } else {
         // report errors here
     }
@@ -98,11 +104,13 @@ Blob<float> * C) {
             cpu_gemm(A.cpu_data(), B.cpu_data(), m, n, 1, 1, 0,
             false, false, C->mutable_cpu_data());
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             // gpu part
             gpu_gemm(A.gpu_data(), B.gpu_data(), m, n, 1, 1, 0,
             false, false, C->mutable_gpu_data());
         }
+        #endif  // SINGA_GPU
     } else {
         // report errors here
     }
@@ -117,10 +125,12 @@ float VVdot(XPU xpu, const Blob<float> & A, const 
Blob<float> & B) {
         if (xpu == cpu) {
             res = cpu_dot(A.cpu_data(), B.cpu_data(), n);
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             // gpu part
             res = gpu_dot(A.gpu_data(), B.gpu_data(), n);
         }
+        #endif  // SINGA_GPU
     } else {
         // report errors here
     }
@@ -134,10 +144,12 @@ void AXPY(XPU xpu, const Blob<float> & A, Blob<float> * 
B, float alpha) {
             cpu_axpy(A.cpu_data(), get_size(A.shape()),
             alpha, B->mutable_cpu_data());
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             gpu_axpy(A.gpu_data(), get_size(A.shape()),
             alpha, B->mutable_gpu_data());
         }
+        #endif  // SINGA_GPU
     } else {
         // report errors here
     }
@@ -160,11 +172,13 @@ float alpha, float beta) {
             false, false, B->mutable_cpu_data());
             delete univ;
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             singa_gpu_add_vec_row(B->gpu_data(),
             A.gpu_data(), A.gpu_data(), m, n, n);
             // gpu part
         }
+        #endif  // SINGA_GPU
     } else {
         // report errors here
     }
@@ -183,10 +197,12 @@ float alpha, float beta) {
             false, false, B->mutable_cpu_data());
             delete univ;
         }
+        #ifdef SINGA_GPU
         if (xpu == gpu) {
             singa_gpu_sum_col(A.gpu_data(), B->gpu_data(), m, n, n);
             // gpu part
         }
+        #endif  // SINGA_GPU
     } else {
         // report errors here
     }

Reply via email to