Repository: incubator-singa
Updated Branches:
  refs/heads/dev 7333517b4 -> 790b7b4cd


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/src/python/swig/model_layer.i
----------------------------------------------------------------------
diff --git a/src/python/swig/model_layer.i b/src/python/swig/model_layer.i
index ee7c319..873ebc9 100644
--- a/src/python/swig/model_layer.i
+++ b/src/python/swig/model_layer.i
@@ -25,6 +25,8 @@
 %include "std_vector.i"
 %include "std_string.i"
 %include "std_pair.i"
+%include "std_shared_ptr.i"
+
 
 %{
 #include "singa/model/layer.h"
@@ -37,13 +39,14 @@ using singa::Device;
 using singa::LayerConf;
 %}
 
+%shared_ptr(singa::Layer)
+
 namespace std {
   %template(strVector) vector<string>;
-  %template(paramVector) vector<ParamSpec>;
-  %template(tensorVector) vector<Tensor>;
-  %template(tensorPtrVector) vector<Tensor*>;
-  %template(ttvecPair) pair<Tensor, vector<Tensor>>;
-  %template(tvecPair) pair<vector<Tensor>, vector<Tensor>>;
+  %template(paramVector) vector<singa::ParamSpec>;
+  %template(tensorVector) vector<singa::Tensor>;
+  %template(ttvecPair) pair<singa::Tensor, vector<singa::Tensor>>;
+  %template(tvecPair) pair<vector<singa::Tensor>, vector<singa::Tensor>>;
 }
 
 
@@ -52,36 +55,23 @@ namespace singa {
   class Layer {
     public:
       Layer();
-      void Setup(const std::vector<size_t>&, const string&);
-      void Setup(const std::vector<vector<size_t>>&, const string&);
-
-      std::string ToProtoStr() const;
-      const std::vector<ParamSpec> param_specs();
-      const ParamSpec& param_specs(size_t i);
-      const std::vector<Tensor*> param_values();
-      Tensor* param_value(size_t i);
-      const std::vector<std::string> param_names();
-      const std::string& param_name(size_t i);
-      const std::string name() const;
-
-      /* virtual functions */
-      virtual const std::string layer_type() const;
-      virtual void Setup(const std::vector<size_t>&,
-                         const LayerConf&);
-      virtual void Setup(const std::vector<std::vector<size_t>>&,
-                         const LayerConf&);
+//      virtual void Setup(const std::vector<vector<size_t>>&, const string&);
+      virtual void Setup(const std::vector<size_t>& in_sample_shape,
+                         const std::string& proto_str);
+      const std::vector<Tensor> param_values();
+      virtual const std::vector<size_t> GetOutputSampleShape() const;
       virtual void ToDevice(std::shared_ptr<Device> device);
       virtual void AsType(DataType dtype);
-      virtual void ToProto(LayerConf* conf) const;
-
-      virtual const Tensor
-      Forward(int flag, const Tensor& input);
-      virtual const std::vector<Tensor>
-      Forward(int flag, const std::vector<Tensor>& inputs);
-      virtual const std::pair<Tensor, std::vector<Tensor>>
-      Backward(int flag, const Tensor& grad);
+      virtual const Tensor Forward(int flag, const Tensor& input);
+      virtual const std::vector<Tensor> Forward(
+          int flag, const std::vector<Tensor>& inputs);
+      virtual const std::pair<Tensor, std::vector<Tensor>> Backward(
+          int flag, const Tensor& grad);
       virtual const std::pair<std::vector<Tensor>, std::vector<Tensor>>
       Backward(int flag, const vector<Tensor>& grads);
+
   };
+  std::shared_ptr<Layer> CreateLayer(const std::string& type);
+  const std::vector<std::string> GetRegisteredLayers();
 }
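
With %shared_ptr(singa::Layer) in place, the wrapper returns Layer objects
managed by std::shared_ptr, and CreateLayer/GetRegisteredLayers become
callable from Python. A minimal sketch, assuming the generated module is
importable as singa.singa_wrap (the name used elsewhere under src/python);
the layer identifiers are those registered on the C++ side, e.g. the ones
listed in test_layer.cc further below:

    from singa import singa_wrap as singa

    # identifiers of all registered layer implementations
    print(singa.GetRegisteredLayers())
    # returns a std::shared_ptr<singa::Layer>, held as a Python proxy
    conv = singa.CreateLayer('Convolution')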
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/src/python/swig/model_loss.i
----------------------------------------------------------------------
diff --git a/src/python/swig/model_loss.i b/src/python/swig/model_loss.i
new file mode 100644
index 0000000..864ad88
--- /dev/null
+++ b/src/python/swig/model_loss.i
@@ -0,0 +1,62 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+/*interface file for swig */
+
+%module model_loss
+%include "std_string.i"
+%{
+#include "singa/model/loss.h"
+  using singa::Tensor;
+%}
+
+namespace singa {
+class Loss {
+public:
+  Loss() = default;
+  virtual ~Loss() {}
+
+  virtual Tensor Forward(int flag, const Tensor &prediction,
+                         const Tensor &target) = 0;
+
+  float Evaluate(int flag, const Tensor &prediction, const Tensor &target);
+
+  /// Compute the gradients of the loss values w.r.t. the prediction.
+  virtual Tensor Backward() = 0;
+};
+
+class MSE : public Loss {
+public:
+  Tensor Forward(int flag, const Tensor &prediction, const Tensor &target)
+      override;
+
+  Tensor Backward() override;
+};
+
+class SoftmaxCrossEntropy : public Loss {
+public:
+  Tensor Forward(int flag, const Tensor &prediction, const Tensor &target)
+      override;
+
+  Tensor Backward() override;
+};
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/src/python/swig/model_metric.i
----------------------------------------------------------------------
diff --git a/src/python/swig/model_metric.i b/src/python/swig/model_metric.i
new file mode 100644
index 0000000..9d93cd0
--- /dev/null
+++ b/src/python/swig/model_metric.i
@@ -0,0 +1,43 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+/*interface file for swig */
+
+%module model_metric
+%{
+#include "singa/model/metric.h"
+using singa::Tensor;
+%}
+
+namespace singa {
+class Metric {
+ public:
+  Metric() = default;
+  virtual ~Metric() {}
+  virtual Tensor Forward(const Tensor& prediction, const Tensor& target) = 0;
+  float Evaluate(const Tensor& prediction, const Tensor& target);
+};
+class Accuracy : public Metric {
+ public:
+  Tensor Forward(const Tensor& prediction, const Tensor& target);
+};
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/src/python/swig/model_optimizer.i
----------------------------------------------------------------------
diff --git a/src/python/swig/model_optimizer.i b/src/python/swig/model_optimizer.i
index ee60f54..78b30b8 100644
--- a/src/python/swig/model_optimizer.i
+++ b/src/python/swig/model_optimizer.i
@@ -47,7 +47,7 @@ class Optimizer {
   virtual ~Optimizer() = default;
   void Setup(const std::string& str);
   virtual void Apply(int step, float lr, const std::string& name,
-    const Tensor& grad, Tensor* value) = 0;
+    const Tensor& grad, Tensor& value) = 0;
 };
 inline std::shared_ptr<Optimizer> CreateOptimizer(const std::string& type);
 
@@ -55,7 +55,7 @@ class Constraint {
  public:
   Constraint() = default;
   void Setup(const std::string& conf_str);
-  void Apply(int step, Tensor* grad, Tensor* value);
+  void Apply(int step, Tensor& grad, Tensor& value);
 };
 
 inline std::shared_ptr<Constraint> CreateConstraint(const std::string& type);
@@ -64,7 +64,7 @@ class Regularizer {
  public:
   Regularizer() = default;
   void Setup(const std::string& conf_str);
-  void Apply(int step, Tensor* grad, Tensor* value);
+  void Apply(int step, Tensor& grad, Tensor& value);
 };
 inline std::shared_ptr<Regularizer> CreateRegularizer(const std::string& type);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/src/python/swig/singa.i
----------------------------------------------------------------------
diff --git a/src/python/swig/singa.i b/src/python/swig/singa.i
index dbf621a..3f12569 100644
--- a/src/python/swig/singa.i
+++ b/src/python/swig/singa.i
@@ -26,3 +26,5 @@
 %include "core_device.i"
 %include "model_layer.i"
 %include "model_optimizer.i"
+%include "model_loss.i"
+%include "model_metric.i"
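
Including model_loss.i and model_metric.i here exposes the Loss and Metric
wrappers through the same combined module. A minimal usage sketch, assuming
the module is importable as singa.singa_wrap and that prediction and target
are placeholders for wrapped singa Tensor objects created elsewhere:

    from singa import singa_wrap as singa

    loss = singa.SoftmaxCrossEntropy()
    l = loss.Forward(0, prediction, target)     # per-sample loss values
    err = loss.Evaluate(0, prediction, target)  # averaged scalar loss
    grad = loss.Backward()                      # gradient w.r.t. prediction

    metric = singa.Accuracy()
    acc = metric.Evaluate(prediction, target)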

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/src/python/tensor.py
----------------------------------------------------------------------
diff --git a/src/python/tensor.py b/src/python/tensor.py
deleted file mode 100644
index 099e706..0000000
--- a/src/python/tensor.py
+++ /dev/null
@@ -1,496 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# =============================================================================
-"""
-This script includes Tensor class and its methods for python users
-to call singa::Tensor and its methods
-"""
-
-import numpy as np
-from proto.core_pb2 import *
-from . import singa_wrap as singa
-
-
-class Tensor(object):
-    ''' Class and member functions for singa::Tensor
-    '''
-
-    def __init__(self, shape=None, device=None, dtype=kFloat32):
-        ''' shape = (tuple)
-        '''
-        if shape is None:
-            # call constructor of singa::Tensor
-            self.singa_tensor = singa.Tensor()
-            return
-        else:
-            assert type(shape) == tuple, 'shape should be tuple'
-            vs = _tuple_to_vector(shape)
-            if device is None:
-                self.singa_tensor = singa.Tensor(vs, dtype)
-            else:
-                self.singa_tensor = singa.Tensor(vs, device, dtype)
-            self.tuple_shape = shape
-            self.device = device
-            self.dtype = dtype
-
-    def copy_from_numpy(self, np_array, offset=0):
-        ''' this method stores the values of numpy array into tensor data
-            from the position of offset
-        '''
-        assert np_array.size == self.size(), 'tensor shape should be the same'
-        if not np_array.ndim == 1:
-            np_array = np_array.flatten()
-        dt = np_array.dtype
-        if dt == np.float32:
-            self.singa_tensor.floatCopyDataFromHostPtr(np_array, offset)
-        else:
-            print 'Not implemented yet for ', dt
-
-    def data_type(self):
-        return self.singa_tensor.data_type()
-
-    def shape(self, axis=None):
-        if axis is None:
-            return self.singa_tensor.shape()
-        else:
-            return self.singa_tensor.shape(axis)
-
-    def ndim(self):
-        return self.singa_tensor.nDim()
-
-    def is_transpose(self):
-        return self.singa_tensor.transpose()
-
-    def size(self):
-        return self.singa_tensor.Size()
-
-    def memsize(self):
-        return self.singa_tensor.MemSize()
-
-    def reshape(self, shape):
-        assert product(self.tuple_shape) == product(shape), \
-               'product of shape should be equal'
-        self.tuple_shape = shape
-        self.singa_tensor.Reshape(_tuple_to_vector(shape))
-
-    def reset_like(self, t):
-        self.singa_tensor.ResetLike(t.singa_tensor)
-
-    def as_type(self, dtype):
-        self.singa_tensor.AsType(dtype)
-
-    def to_device(self, device):
-        self.singa_tensor.ToDevice(device)
-
-    def to_host(self):
-        self.singa_tensor.ToHost()
-
-    def nrm2(self):
-        return self.singa_tensor.L2()
-
-    def set_value(self, x):
-        if type(x) == float:
-            self.singa_tensor.floatSetValue(x)
-
-    def copy_data(self, t):
-        self.singa_tensor.CopyData(t.singa_tensor)
-
-    def clone(self):
-        ''' it does deep copy
-            call singa::Tensor::Clone()
-        '''
-        return _call_singa_func(self.singa_tensor.Clone)
-
-    def transpose(self):
-        ''' shallow copy, negate the transpose field
-            call singa::Tensor::T()
-        '''
-        return _call_singa_func(self.singa_tensor.T)
-
-    def copy(self):
-        ''' shallow copy
-            call copy constructor of singa::Tensor
-        '''
-        return _call_singa_func(singa.Tensor, self.singa_tensor)
-
-    def deepcopy(self):
-        ''' deep copy
-            call singa::Tensor::Clone()
-        '''
-        return self.clone()
-
-    def bernoulli(self, p):
-        if type(p) == float:
-            singa.floatBernoulli(p, self.singa_tensor)
-
-    def gaussian(self, mean, std):
-        if type(mean) == float:
-            singa.floatGaussian(mean, std, self.singa_tensor)
-
-    def uniform(self, low, high):
-        if type(low) == float:
-            singa.floatUniform(low, high, self.singa_tensor)
-
-    def add_column(self, v):
-        singa.AddColumn(v.singa_tensor, self.singa_tensor)
-
-    def add_row(self, v):
-        singa.AddRow(v.singa_tensor, self.singa_tensor)
-
-    def div_column(self, v):
-        singa.DivColumn(v.singa_tensor, self.singa_tensor)
-
-    def div_row(self, v):
-        singa.DivRow(v.singa_tensor, self.singa_tensor)
-
-    def mult_column(self, v):
-        singa.MultColumn(v.singa_tensor, self.singa_tensor)
-
-    def mult_row(self, v):
-        singa.MultRow(v.singa_tensor, self.singa_tensor)
-
-    '''
-    python operators (+=, -=, *=, /=) for singa::Tensor unary operators
-    '''
-    def __iadd__(self, x):
-        if type(x) == Tensor:
-            self.singa_tensor += x.singa_tensor
-        else:
-            self.singa_tensor += x
-        return self
-
-    def __isub__(self, x):
-        if type(x) == Tensor:
-            self.singa_tensor -= x.singa_tensor
-        else:
-            self.singa_tensor -= x
-        return self
-
-    def __imul__(self, x):
-        if type(x) == Tensor:
-            self.singa_tensor *= x.singa_tensor
-        else:
-            self.singa_tensor *= x
-        return self
-
-    def __idiv__(self, x):
-        if type(x) == Tensor:
-            self.singa_tensor /= x.singa_tensor
-        else:
-            self.singa_tensor /= x
-        return self
-
-    '''
-    python operators (+, -, *, /, <, <=, >, >=) for singa binary operators
-    '''
-    def __add__(self, rhs):
-        if isinstance(rhs, Tensor):
-            return _call_singa_func(singa.Add_TT,
-                                    self.singa_tensor, rhs.singa_tensor)
-        else:
-            return _call_singa_func(singa.Add_Tf,
-                                    self.singa_tensor, rhs)
-
-    def __sub__(self, rhs):
-        if isinstance(rhs, Tensor):
-            return _call_singa_func(singa.Sub_TT,
-                                    self.singa_tensor, rhs.singa_tensor)
-        else:
-            return _call_singa_func(singa.Sub_Tf,
-                                    self.singa_tensor, rhs)
-
-    def __mul__(self, rhs):
-        if isinstance(rhs, Tensor):
-            return _call_singa_func(singa.EltwiseMul_TT,
-                                    self.singa_tensor, rhs.singa_tensor)
-        else:
-            return _call_singa_func(singa.EltwiseMul_Tf,
-                                    self.singa_tensor, rhs)
-
-    def __div__(self, rhs):
-        if isinstance(rhs, Tensor):
-            return _call_singa_func(singa.Div_TT,
-                                    self.singa_tensor, rhs.singa_tensor)
-        else:
-            return _call_singa_func(singa.Div_Tf,
-                                    self.singa_tensor, rhs)
-
-    def __lt__(self, rhs):
-        return _call_singa_func(singa.LT_Tf, self.singa_tensor, rhs)
-
-    def __le__(self, rhs):
-        return _call_singa_func(singa.LE_Tf, self.singa_tensor, rhs)
-
-    def __gt__(self, rhs):
-        return _call_singa_func(singa.GT_Tf, self.singa_tensor, rhs)
-
-    def __ge__(self, rhs):
-        return _call_singa_func(singa.GE_Tf, self.singa_tensor, rhs)
-
-
-''' python functions for global functions in Tensor.h
-'''
-
-
-def product(shape):
-    return reduce(lambda x, y: x * y, shape)
-
-
-def sizeof(dtype):
-    return singa.SizeOf(dtype)
-
-
-def reshape(t, s):
-    return _call_singa_func(singa.Reshape, t.singa_tensor, s)
-
-
-def copy_data_to_from(dst, src, size, dst_offset=0, src_offset=0):
-    singa.CopyDataToFrom(dst.singa_tensor, src.singa_tensor, size,
-                         dst_offset, src_offset)
-
-
-def from_numpy(np_array):
-    ret = Tensor(np_array.shape)
-    ret.copy_from_numpy(np_array)
-    return ret
-
-
-def to_numpy(t):
-    ''' this method gets the values of tensor data and
-        returns it as numpy array
-        TODO(wangwei) clone t to host
-    '''
-    if t.dtype == kFloat32:
-        np_array = t.singa_tensor.floatGetValue(int(t.size()))
-    else:
-        print 'Not implemented yet for ', t.dtype
-    return np_array.reshape(t.tuple_shape)
-
-
-def abs(t):
-    return _call_singa_func(singa.Abs, t.singa_tensor)
-
-
-def exp(t):
-    return _call_singa_func(singa.Exp, t.singa_tensor)
-
-
-def log(t):
-    return _call_singa_func(singa.Log, t.singa_tensor)
-
-
-def relu(t):
-    return _call_singa_func(singa.ReLU, t.singa_tensor)
-
-
-def sigmoid(t):
-    return _call_singa_func(singa.Sigmoid, t.singa_tensor)
-
-
-def square(t):
-    return _call_singa_func(singa.Square, t.singa_tensor)
-
-
-def tanh(t):
-    return _call_singa_func(singa.Tanh, t.singa_tensor)
-
-
-def sum(t, axis=None):
-    if axis is None:
-        return singa.floatSum(t.singa_tensor)
-    else:
-        return _call_singa_func(singa.Sum, t.singa_tensor, axis)
-
-
-def pow(t, x, out=None):
-    if out is None:
-        if isinstance(x, Tensor):
-            return _call_singa_func(singa.Pow, t.singa_tensor, x.singa_tensor)
-        else:
-            return _call_singa_func(singa.Pow_f, t.singa_tensor, x)
-    else:
-        if isinstance(x, Tensor):
-            singa.Pow(t.singa_tensor, x.singa_tensor, out.singa_tensor)
-        else:
-            singa.Pow_f_out(t.singa_tensor, x, out.singa_tensor)
-        return out
-
-
-def average(t, axis=0):
-    return _call_singa_func(singa.Average, t.singa_tensor, axis)
-
-
-def softmax(t, out=None):
-    if out is None:
-        return _call_singa_func(singa.SoftMax, t.singa_tensor)
-    else:
-        singa.SoftMax(t.singa_tensor, out.singa_tensor)
-        return out
-
-
-def lt(t, x):
-    return t < x
-
-
-def le(t, x):
-    return t <= x
-
-
-def gt(t, x):
-    return t > x
-
-
-def ge(t, x):
-    return t >= x
-
-
-def add(lhs, rhs, ret=None):
-    if ret is None:
-        # call Tensor.__add__()
-        return lhs + rhs
-    else:
-        if isinstance(rhs, Tensor):
-            singa.Add(lhs.singa_tensor, rhs.singa_tensor, ret.singa_tensor)
-        else:
-            singa.Add_Tf_out(lhs.singa_tensor, rhs, ret.singa_tensor)
-        return ret
-
-
-def sub(lhs, rhs, ret=None):
-    if ret is None:
-        # call Tensor.__sub__()
-        return lhs - rhs
-    else:
-        if isinstance(rhs, Tensor):
-            singa.Sub(lhs.singa_tensor, rhs.singa_tensor, ret.singa_tensor)
-        else:
-            singa.Sub_Tf_out(lhs.singa_tensor, rhs, ret.singa_tensor)
-        return ret
-
-
-def eltwise_mult(lhs, rhs, ret=None):
-    if ret is None:
-        # call Tensor.__mul__()
-        return lhs * rhs
-    else:
-        if isinstance(rhs, Tensor):
-            singa.EltwiseMult(lhs.singa_tensor, rhs.singa_tensor,
-                              ret.singa_tensor)
-        else:
-            singa.EltwiseMult_Tf_out(lhs.singa_tensor, rhs,
-                                     ret.singa_tensor)
-        return ret
-
-
-def mult(A, B, C=None, alpha=1.0, beta=0.0):
-    '''
-    This function returns C = alpha * A * B + beta * C
-    '''
-    if C is None:
-        return _call_singa_func(singa.Mult, A.singa_tensor, B.singa_tensor)
-    else:
-        singa.floatMult(alpha, A.singa_tensor, B.singa_tensor,
-                        beta, C.singa_tensor)
-        return C
-
-
-def div(lhs, rhs, ret=None):
-    if ret is None:
-        # call Tensor.__div__()
-        return lhs / rhs
-    else:
-        if isinstance(rhs, Tensor):
-            singa.Div(lhs.singa_tensor, rhs.singa_tensor, ret.singa_tensor)
-        else:
-            singa.Div_Tf_out(lhs.singa_tensor, rhs, ret.singa_tensor)
-        return ret
-
-
-def axpy(alpha, x, y):
-    if type(alpha) == float:
-        singa.floatAxpy(alpha, x.singa_tensor, y.singa_tensor)
-    return y
-
-
-def bernoulli(p, t):
-    if type(p) == float:
-        singa.floatBernoulli(p, t.singa_tensor)
-    return t
-
-
-def gaussian(mean, std, t):
-    if type(mean) == float:
-        singa.floatGaussian(mean, std, t.singa_tensor)
-    return t
-
-
-def uniform(low, high, t):
-    if type(low) == float:
-        singa.floatUniform(low, high, t.singa_tensor)
-    return t
-
-
-def add_column(alpha, v, beta, M):
-    singa.floatAddColumn(alpha, beta, v.singa_tensor, M.singa_tensor)
-    return M
-
-
-def add_row(alpha, v, beta, M):
-    singa.floatAddRow(alpha, beta, v.singa_tensor, M.singa_tensor)
-    return M
-
-
-def sum_columns(M):
-    assert M.ndim() == 2, 'M.nDim() is supposed to be 2'
-    nb_col = M.shape(0)
-    ret = Tensor((nb_col, 1))
-    singa.SumColumns(M.singa_tensor, ret.singa_tensor)
-    return ret
-
-
-def sum_rows(M):
-    assert M.ndim() == 2, 'M.nDim() is supposed to be 2'
-    nb_row = M.shape(1)
-    ret = Tensor((1, nb_row))
-    singa.SumRows(M.singa_tensor, ret.singa_tensor)
-    return ret
-
-
-''' private functions, internally used
-'''
-
-
-def _tuple_to_vector(tshape):
-    ''' this function converts tuple to std::vector<int>
-    '''
-    vs = singa.Shape(len(tshape))
-    for i in range(len(tshape)):
-        vs[i] = tshape[i]
-    return vs
-
-
-def _call_singa_func(_singa_func, *args):
-    ''' this function calls singa global functions that returns Tensor
-        and create new python Tensor instance
-        e.g., Tensor [singa_func](args...)
-    '''
-    new_t = Tensor()
-    new_t.singa_tensor = _singa_func(*args)
-    new_t.tuple_shape = new_t.singa_tensor.shape()
-    new_t.device = new_t.singa_tensor.device()
-    new_t.dtype = new_t.singa_tensor.data_type()
-    return new_t

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/test/python/test_layer.py
----------------------------------------------------------------------
diff --git a/test/python/test_layer.py b/test/python/test_layer.py
new file mode 100644
index 0000000..7e1059e
--- /dev/null
+++ b/test/python/test_layer.py
@@ -0,0 +1,194 @@
+import sys
+import os
+import unittest
+import numpy as np
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
+
+from singa import layer
+from singa import device
+from singa import tensor
+from singa.proto import model_pb2
+
+
+def _tuple_to_string(t):
+    lt = [str(x) for x in t]
+    return '(' + ', '.join(lt) + ')'
+
+
+class TestPythonLayer(unittest.TestCase):
+
+    def check_shape(self, actual, expect):
+        self.assertEqual(actual, expect, 'shape mismatch, actual shape is %s'
+                         ' expected is %s' % (_tuple_to_string(actual),
+                                              _tuple_to_string(expect))
+                         )
+
+    def setUp(self):
+        self.w = {'init': 'Xavier', 'regularizer': 1e-4}
+        self.b = {'init': 'Constant', 'value': 0}
+        self.sample_shape = None
+
+    def test_conv2D_shape(self):
+        in_sample_shape = (3, 224, 224)
+        conv = layer.Conv2D('conv', 64, 3, 1, W_specs=self.w, b_specs=self.b,
+                            input_sample_shape=in_sample_shape)
+        out_sample_shape = conv.get_output_sample_shape()
+        self.check_shape(out_sample_shape, (64, 224, 224))
+
+    def test_conv2D_forward_backward(self):
+        in_sample_shape = (1, 3, 3)
+        conv = layer.Conv2D('conv', 1, 3, 2, W_specs=self.w, b_specs=self.b,
+                            pad=1, input_sample_shape=in_sample_shape)
+        cuda = device.create_cuda_gpu()
+        conv.to_device(cuda)
+        params = conv.param_values()
+
+        raw_x = np.arange(9, dtype=np.float32) + 1
+        x = tensor.from_numpy(raw_x)
+        x.reshape((1, 1, 3, 3))
+        w = np.array([1, 1, 0, 0, 0, -1, 0, 1, 0], dtype=np.float32)
+        params[0].copy_from_numpy(w)
+        params[1].set_value(1.0)
+
+        x.to_device(cuda)
+        y = conv.forward(model_pb2.kTrain, x)
+        y.to_host()
+        npy = tensor.to_numpy(y).flatten()
+
+        self.assertAlmostEqual(3.0, npy[0])
+        self.assertAlmostEqual(7.0, npy[1])
+        self.assertAlmostEqual(-3.0, npy[2])
+        self.assertAlmostEqual(12.0, npy[3])
+
+        dy = np.asarray([0.1, 0.2, 0.3, 0.4], dtype=np.float32).reshape(y.shape)
+        grad = tensor.from_numpy(dy)
+        grad.to_device(cuda)
+        (dx, [dw, db]) = conv.backward(model_pb2.kTrain, grad)
+        dx.to_host()
+        dw.to_host()
+        dx = tensor.to_numpy(dx).flatten()
+        dw = tensor.to_numpy(dw).flatten()
+        dy = dy.flatten()
+        self.assertAlmostEquals(dy[0] * w[4], dx[0])
+        self.assertAlmostEquals(dy[0] * w[5] + dy[1] * w[3], dx[1])
+        self.assertAlmostEquals(dy[1] * w[4], dx[2])
+        self.assertAlmostEquals(dy[0] * w[7] + dy[2] * w[1], dx[3])
+        self.assertAlmostEquals(
+            dy[0] *
+            w[8] +
+            dy[1] *
+            w[6] +
+            dy[2] *
+            w[2] +
+            dy[3] *
+            w[0],
+            dx[4])
+        self.assertAlmostEquals(dy[1] * w[7] + dy[3] * w[1], dx[5])
+        self.assertAlmostEquals(dy[2] * w[4], dx[6])
+        self.assertAlmostEquals(dy[2] * w[5] + dy[3] * w[3], dx[7])
+        self.assertAlmostEquals(dy[3] * w[4], dx[8])
+
+        self.assertAlmostEquals(dy[3] * raw_x[4], dw[0])
+        self.assertAlmostEquals(dy[3] * raw_x[5] + dy[2] * raw_x[3], dw[1])
+        self.assertAlmostEquals(dy[2] * raw_x[4], dw[2])
+        self.assertAlmostEquals(dy[1] * raw_x[1] + dy[3] * raw_x[7], dw[3])
+        self.assertAlmostEquals(
+            dy[0] *
+            raw_x[0] +
+            dy[1] *
+            raw_x[2] +
+            dy[2] *
+            raw_x[6] +
+            dy[3] *
+            raw_x[8],
+            dw[4], 5)
+        self.assertAlmostEquals(dy[0] * raw_x[1] + dy[2] * raw_x[7], dw[5])
+        self.assertAlmostEquals(dy[1] * raw_x[4], dw[6])
+        self.assertAlmostEquals(dy[0] * raw_x[3] + dy[1] * raw_x[5], dw[7])
+        self.assertAlmostEquals(dy[0] * raw_x[4], dw[8])
+
+    def test_conv1D(self):
+        in_sample_shape = (224,)
+        conv = layer.Conv1D('conv', 64, 3, 1, W_specs=self.w, b_specs=self.b,
+                            pad=1, input_sample_shape=in_sample_shape)
+        out_sample_shape = conv.get_output_sample_shape()
+        self.check_shape(out_sample_shape, (64, 224,))
+
+    def test_max_pooling2D(self):
+        in_sample_shape = (64, 224, 224)
+        pooling = layer.MaxPooling2D('pool', 3, 2,
+                                     input_sample_shape=in_sample_shape)
+        out_sample_shape = pooling.get_output_sample_shape()
+        self.check_shape(out_sample_shape, (64, 112, 112))
+
+    def test_max_pooling1D(self):
+        in_sample_shape = (224,)
+        pooling = layer.MaxPooling1D('pool', 3, 2,
+                                     input_sample_shape=in_sample_shape)
+        out_sample_shape = pooling.get_output_sample_shape()
+        self.check_shape(out_sample_shape, (112,))
+
+    def test_avg_pooling2D(self):
+        in_sample_shape = (64, 224, 224)
+        pooling = layer.AvgPooling2D('pool', 3, 2,
+                                     input_sample_shape=in_sample_shape)
+        out_sample_shape = pooling.get_output_sample_shape()
+        self.check_shape(out_sample_shape, (64, 112, 112))
+
+    def test_avg_pooling1D(self):
+        in_sample_shape = (224,)
+        pooling = layer.AvgPooling1D('pool', 3, 2,
+                                     input_sample_shape=in_sample_shape)
+        out_sample_shape = pooling.get_output_sample_shape()
+        self.check_shape(out_sample_shape, (112,))
+
+    def test_batch_normalization(self):
+        in_sample_shape = (3, 224, 224)
+        bn = layer.BatchNormalization('bn', input_sample_shape=in_sample_shape)
+        out_sample_shape = bn.get_output_sample_shape()
+        self.check_shape(out_sample_shape, in_sample_shape)
+
+    def test_lrn(self):
+        in_sample_shape = (3, 224, 224)
+        lrn = layer.LRN('lrn', input_sample_shape=in_sample_shape)
+        out_sample_shape = lrn.get_output_sample_shape()
+        self.check_shape(out_sample_shape, in_sample_shape)
+
+    def test_dense(self):
+        dense = layer.Dense('ip', 32, input_sample_shape=(64,))
+        out_sample_shape = dense.get_output_sample_shape()
+        self.check_shape(out_sample_shape, (32,))
+
+    def test_dropout(self):
+        input_sample_shape = (64, 1, 12)
+        dropout = layer.Dropout('drop', input_sample_shape=input_sample_shape)
+        out_sample_shape = dropout.get_output_sample_shape()
+        self.check_shape(out_sample_shape, input_sample_shape)
+
+    def test_activation(self):
+        input_sample_shape = (64, 1, 12)
+        act = layer.Activation('act', input_sample_shape=input_sample_shape)
+        out_sample_shape = act.get_output_sample_shape()
+        self.check_shape(out_sample_shape, input_sample_shape)
+
+    def test_softmax(self):
+        input_sample_shape = (12,)
+        softmax = layer.Softmax('soft', input_sample_shape=input_sample_shape)
+        out_sample_shape = softmax.get_output_sample_shape()
+        self.check_shape(out_sample_shape, input_sample_shape)
+
+    def test_flatten(self):
+        input_sample_shape = (64, 1, 12)
+        flatten = layer.Flatten('flat', input_sample_shape=input_sample_shape)
+        out_sample_shape = flatten.get_output_sample_shape()
+        self.check_shape(out_sample_shape, (64 * 1 * 12, ))
+
+        flatten = layer.Flatten('flat', axis=2,
+                                input_sample_shape=input_sample_shape)
+        out_sample_shape = flatten.get_output_sample_shape()
+        self.check_shape(out_sample_shape, (12,))
+
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/test/python/test_optimizer.py
----------------------------------------------------------------------
diff --git a/test/python/test_optimizer.py b/test/python/test_optimizer.py
index fa062c8..afdf337 100644
--- a/test/python/test_optimizer.py
+++ b/test/python/test_optimizer.py
@@ -26,7 +26,7 @@ import singa.tensor as tensor
 import singa.optimizer as opt
 import singa.device as device
 
-cuda = device.Platform.create_cuda_gpu()
+cuda = device.create_cuda_gpu()
 
 
 class TestOptimizer(unittest.TestCase):

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/test/python/test_tensor.py
----------------------------------------------------------------------
diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py
index c999705..4d8b940 100644
--- a/test/python/test_tensor.py
+++ b/test/python/test_tensor.py
@@ -15,148 +15,121 @@
 # specific language governing permissions and limitations
 # under the License.
 # =============================================================================
+
 import sys
 import os
 import math
 import unittest
-import numpy as np
 
-sys.path.append(os.path.join(os.path.dirname(__file__),'../../build/python'))
+sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
 
-from singa.tensor import *
-from singa.device import *
 
-from singa.proto.core_pb2 import *
+from singa import tensor
+from singa.proto import core_pb2
 
 
 class TestTensorMethods(unittest.TestCase):
 
     def setUp(self):
-        self.shape = (3, 2)
-        self.t = Tensor(self.shape)
-        self.s = Tensor(self.shape)
+        self.shape = (2, 3)
+        self.t = tensor.Tensor(self.shape)
+        self.s = tensor.Tensor(self.shape)
 
     def test_tensor_fields(self):
         t = self.t
         shape = self.shape
-        self.assertTupleEqual(t.shape(), shape)
-        self.assertEqual(t.shape(0), shape[0])
-        self.assertEqual(t.shape(1), shape[1])
-        self.assertEqual(product(shape), 3*2)
+        self.assertTupleEqual(t.shape, shape)
+        self.assertEqual(t.shape[0], shape[0])
+        self.assertEqual(t.shape[1], shape[1])
+        self.assertEqual(tensor.product(shape), 2*3)
         self.assertEqual(t.ndim(), 2)
-        self.assertEqual(t.size(), 3*2)
-        self.assertEqual(t.memsize(), 3*2*sizeof(kFloat32))
+        self.assertEqual(t.size(), 2*3)
+        self.assertEqual(t.memsize(), 2*3*tensor.sizeof(core_pb2.kFloat32))
         self.assertFalse(t.is_transpose())
 
     def test_unary_operators(self):
         t = self.t
-        arr = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=np.float32)
-        t.copy_from_numpy(arr)
-        npary = to_numpy(t)
-        self.assertAlmostEqual(npary[0, 0], arr[0, 0])
-        self.assertAlmostEqual(npary[0, 1], arr[0, 1])
-        self.assertAlmostEqual(npary[2, 1], arr[2, 1])
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 0], 0.0)
         t += 1.23
-        npary = to_numpy(t)
-        self.assertAlmostEqual(npary[0, 0], 1.0+1.23)
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 0], 1.23)
         t -= 0.23
-        npary = to_numpy(t)
-        self.assertAlmostEqual(npary[0, 0], 2.23-0.23)
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 0], 1.23-0.23)
         t *= 2.5
-        npary = to_numpy(t)
-        self.assertAlmostEqual(npary[0, 0], (2.23-0.23)*2.5)
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 0], (1.23-0.23)*2.5)
         t /= 2
-        npary = to_numpy(t)
-        self.assertAlmostEqual(npary[0, 0], (2.23-0.23)*2.5/2)
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 0], (1.23-0.23)*2.5/2)
 
     def test_binary_operators(self):
         t = self.t
-        arr = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=np.float32)
-        t.copy_from_numpy(arr)
+        t += 3.2
         s = self.s
-        arr = np.array([[4.0, 3.0], [3.0, 2.0], [2.0, 1.0]], dtype=np.float32)
-        s.copy_from_numpy(arr)
+        s += 2.1
         a = t + s
-        self.assertAlmostEqual(to_numpy(a)[0, 0], 1.0+4.0)
+        self.assertAlmostEqual(tensor.to_numpy(a)[0, 0], 3.2+2.1, 5)
         a = t - s
-        self.assertAlmostEqual(to_numpy(a)[0, 0], 1.0-4.0)
+        self.assertAlmostEqual(tensor.to_numpy(a)[0, 0], 3.2-2.1, 5)
         a = t * s
-        self.assertAlmostEqual(to_numpy(a)[0, 0], 1.0*4.0)
+        self.assertAlmostEqual(tensor.to_numpy(a)[0, 0], 3.2*2.1, 5)
+        ''' not implemented yet
         a = t / s
-        self.assertAlmostEqual(to_numpy(a)[0, 0], 1.0/4.0)
+        self.assertAlmostEqual(tensor.to_numpy(a)[0,0], 3.2/2.1, 5)
+        '''
 
     def test_comparison_operators(self):
         t = self.t
-        t.set_value(3.45)
+        t += 3.45
         a = t < 3.45
-        self.assertEqual(to_numpy(a)[0, 0], 0)
+        self.assertEqual(tensor.to_numpy(a)[0, 0], 0)
         a = t <= 3.45
-        self.assertEqual(to_numpy(a)[0, 0], 1)
+        self.assertEqual(tensor.to_numpy(a)[0, 0], 1)
         a = t > 3.45
-        self.assertEqual(to_numpy(a)[0, 0], 0)
+        self.assertEqual(tensor.to_numpy(a)[0, 0], 0)
         a = t >= 3.45
-        self.assertEqual(to_numpy(a)[0, 0], 1)
-        a = lt(t, 3.45)
-        self.assertEqual(to_numpy(a)[0, 0], 0)
-        a = le(t, 3.45)
-        self.assertEqual(to_numpy(a)[0, 0], 1)
-        a = gt(t, 3.45)
-        self.assertEqual(to_numpy(a)[0, 0], 0)
-        a = ge(t, 3.45)
-        self.assertEqual(to_numpy(a)[0, 0], 1)
-
-    def test_tensor_manipulation(self):
-        t = self.t
-        arr = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
-        t.copy_from_numpy(arr)
-        s = Tensor((3, 1))
-        arr = np.array([7, 8, 9], dtype=np.float32)
-        s.copy_from_numpy(arr)
-        t.add_column(s)
-        self.assertEqual(to_numpy(t)[0, 0], 1+7)
-        self.assertEqual(to_numpy(t)[1, 0], 3+8)
-        self.assertEqual(to_numpy(t)[1, 1], 4+8)
-
-        arr = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
-        t.copy_from_numpy(arr)
-        add_column(2, s, 3, t)
-        self.assertEqual(to_numpy(t)[0, 0], 3*1+2*7)
-        self.assertEqual(to_numpy(t)[1, 0], 3*3+2*8)
-        self.assertEqual(to_numpy(t)[1, 1], 3*4+2*8)
-
-    def test_random_operations(self):
-        # TODO(chonho)
-        pass
+        self.assertEqual(tensor.to_numpy(a)[0, 0], 1)
+        a = tensor.lt(t, 3.45)
+        self.assertEqual(tensor.to_numpy(a)[0, 0], 0)
+        a = tensor.le(t, 3.45)
+        self.assertEqual(tensor.to_numpy(a)[0, 0], 1)
+        a = tensor.gt(t, 3.45)
+        self.assertEqual(tensor.to_numpy(a)[0, 0], 0)
+        a = tensor.ge(t, 3.45)
+        self.assertEqual(tensor.to_numpy(a)[0, 0], 1)
 
     def test_tensor_copy(self):
-        t = Tensor((2, 3))
-        t.set_value(1.23)
-        self.assertAlmostEqual(to_numpy(t)[0, 0], 1.23)
+        t = tensor.Tensor((2, 3))
+        t += 1.23
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 0], 1.23)
         tc = t.copy()
         tdc = t.deepcopy()
-        self.assertAlmostEqual(to_numpy(tc)[0, 0], 1.23)
-        self.assertAlmostEqual(to_numpy(tdc)[0, 0], 1.23)
+        self.assertAlmostEqual(tensor.to_numpy(tc)[0, 0], 1.23)
+        self.assertAlmostEqual(tensor.to_numpy(tdc)[0, 0], 1.23)
         t += 1.23
-        self.assertAlmostEqual(to_numpy(t)[0, 0], 2.46)
-        self.assertAlmostEqual(to_numpy(tc)[0, 0], 2.46)
-        self.assertAlmostEqual(to_numpy(tdc)[0, 0], 1.23)
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 0], 2.46)
+        self.assertAlmostEqual(tensor.to_numpy(tc)[0, 0], 2.46)
+        self.assertAlmostEqual(tensor.to_numpy(tdc)[0, 0], 1.23)
 
     def test_copy_data(self):
         t = self.t
-        t.set_value(1.23)
+        t += 1.23
         s = self.s
-        s.set_value(5.43)
-        self.assertAlmostEqual(to_numpy(t)[0, 0], 1.23)
-        copy_data_to_from(t, s, 2)
-        self.assertAlmostEqual(to_numpy(t)[0, 0], 5.43, 5)
-        self.assertAlmostEqual(to_numpy(t)[0, 1], 5.43, 5)
-        self.assertAlmostEqual(to_numpy(t)[1, 0], 1.23, 5)
+        s += 5.43
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 0], 1.23)
+        tensor.copy_data_to_from(t, s, 2)
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 0], 5.43, 5)
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 1], 5.43, 5)
+        self.assertAlmostEqual(tensor.to_numpy(t)[0, 2], 1.23)
 
     def test_global_method(self):
         t = self.t
-        t.set_value(12.34)
-        a = log(t)
-        self.assertAlmostEqual(to_numpy(a)[0, 0], math.log(12.34), 5)
+        t += 12.34
+        a = tensor.log(t)
+        self.assertAlmostEqual(tensor.to_numpy(a)[0, 0], math.log(12.34))
+
+    def test_random(self):
+        x = tensor.Tensor((1000,))
+        x.gaussian(1, 0.01)
+        self.assertAlmostEqual(tensor.average(x), 1, 3)
+
 
 if __name__ == '__main__':
     unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/test/singa/test_accuracy.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_accuracy.cc b/test/singa/test_accuracy.cc
index 4ff14c0..5d337fb 100644
--- a/test/singa/test_accuracy.cc
+++ b/test/singa/test_accuracy.cc
@@ -29,7 +29,7 @@ TEST(Accuracy, Compute) {
   const float pdat[6] = {0.1, 0.3, 0.6, 0.3, 0.2, 0.5};
   const int tdat[2] = {1, 2};  // one wrong, one correct
   p.CopyDataFromHostPtr(pdat, sizeof(pdat) / sizeof(float));
-  t.CopyDataFromHostPtr(tdat, sizeof(pdat) / sizeof(float));
+  t.CopyDataFromHostPtr(tdat, sizeof(tdat) / sizeof(int));
   float a = acc.Evaluate(p, t);
   EXPECT_FLOAT_EQ(a, 0.5f);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/test/singa/test_adagrad.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_adagrad.cc b/test/singa/test_adagrad.cc
index e8cd062..f12ec68 100644
--- a/test/singa/test_adagrad.cc
+++ b/test/singa/test_adagrad.cc
@@ -36,7 +36,7 @@ TEST(AdaGrad, ApplyCPU) {
 
   singa::OptimizerConf conf;
   adagrad.Setup(conf);
-  adagrad.Apply(0, lr, "xx", grad, &value);
+  adagrad.Apply(0, lr, "xx", grad, value);
 
   singa::Tensor v1 = value.Clone();
   const float* newv1 = v1.data<float>();
@@ -47,7 +47,7 @@ TEST(AdaGrad, ApplyCPU) {
                 1e-5);
 
   grad.CopyDataFromHostPtr(g, 4);
-  adagrad.Apply(1, lr, "xx", grad, &value);
+  adagrad.Apply(1, lr, "xx", grad, value);
   singa::Tensor v2 = value.Clone();
   const float* newv2 = v2.data<float>();
   for (int i = 0; i < 4; ++i) history[i] += g[i] * g[i];
@@ -71,7 +71,7 @@ TEST(AdaGrad, ApplyCUDA) {
 
   singa::OptimizerConf conf;
   adagrad.Setup(conf);
-  adagrad.Apply(0, lr, "xx", grad, &value);
+  adagrad.Apply(0, lr, "xx", grad, value);
 
   singa::Tensor v1 = value.Clone();
   v1.ToHost();
@@ -83,7 +83,7 @@ TEST(AdaGrad, ApplyCUDA) {
                 1e-5);
 
   grad.CopyDataFromHostPtr(g, 4);
-  adagrad.Apply(1, lr, "xx", grad, &value);
+  adagrad.Apply(1, lr, "xx", grad, value);
   singa::Tensor v2 = value.Clone();
   v2.ToHost();
   const float* newv2 = v2.data<float>();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/test/singa/test_initializer.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_initializer.cc b/test/singa/test_initializer.cc
index 4631af2..74a30bb 100644
--- a/test/singa/test_initializer.cc
+++ b/test/singa/test_initializer.cc
@@ -26,7 +26,7 @@ TEST(Initializer, Constant) {
   singa::FillerConf conf;
   conf.set_value(3.1f);
   x.Setup(conf);
-  x.Fill(&t);
+  x.Fill(t);
   const float* xPtr = t.data<float>();
   for (size_t i = 0; i < n; i++)
     EXPECT_FLOAT_EQ(xPtr[i], 3.1f);
@@ -41,7 +41,7 @@ TEST(Initializer, Gaussian) {
   conf.set_mean(0.11f);
   conf.set_std(0.01f);
   x.Setup(conf);
-  x.Fill(&t);
+  x.Fill(t);
   const float* xPtr = t.data<float>();
   float mean = 0.0f, std = 0.0f;
   for (size_t i = 0; i < n; i++)
@@ -64,7 +64,7 @@ TEST(Initializer, ConstantCUDA) {
   singa::FillerConf conf;
   conf.set_value(3.1f);
   x.Setup(conf);
-  x.Fill(&t);
+  x.Fill(t);
   t.ToHost();
   const float* xPtr = t.data<float>();
   for (size_t i = 0; i < n; i++)
@@ -73,7 +73,7 @@ TEST(Initializer, ConstantCUDA) {
 
   singa::init::Constant y(-0.1f);
   singa::Tensor s(singa::Shape{n}, dev);
-  y.Fill(&s);
+  y.Fill(s);
   s.ToHost();
   const float* sPtr = s.data<float>();
   for (size_t i = 0; i < n; i++)
@@ -90,7 +90,7 @@ TEST(Initializer, GaussianCUDA) {
   conf.set_mean(0.11f);
   conf.set_std(0.01f);
   x.Setup(conf);
-  x.Fill(&t);
+  x.Fill(t);
   t.ToHost();
   const float* tPtr = t.data<float>();
   float mean = 0.0f, std = 0.0f;
@@ -107,7 +107,7 @@ TEST(Initializer, GaussianCUDA) {
 
   singa::init::Gaussian y(1.5f, 0.1f);
   singa::Tensor s(singa::Shape{n}, dev);
-  y.Fill(&s);
+  y.Fill(s);
   s.ToHost();
   const float* sPtr = s.data<float>();
   for (size_t i = 0; i < n; i++)
@@ -126,7 +126,7 @@ TEST(Initializer, XavierCUDA) {
   auto dev = std::make_shared<singa::CudaGPU>();
   size_t m = 30, n=40;
   singa::Tensor t(singa::Shape{m, n}, dev);
-  x.Fill(&t);
+  x.Fill(t);
   t.ToHost();
   const float* xPtr = t.data<float>();
   float mean = 0.0f;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/test/singa/test_layer.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_layer.cc b/test/singa/test_layer.cc
index 7306f39..4071762 100644
--- a/test/singa/test_layer.cc
+++ b/test/singa/test_layer.cc
@@ -1,6 +1,29 @@
-#include "singa/proto/core.pb.h"
 #include "gtest/gtest.h"
+#include "singa/model/layer.h"
+#include "singa/singa_config.h"
 
+TEST(Layer, CreateLayer) {
+  std::vector<std::string> types{
+      "Convolution", "Dense", "Dropout", "Activation", "BatchNorm",
+      "Flatten",     "LRN",   "Pooling", "PReLU",      "Softmax"};
+  for (auto type : types) {
+    auto layer = singa::CreateLayer(type);
+    EXPECT_EQ(layer->layer_type(), type);
+  }
+}
 
-TEST(TestProto, CopyMsgSameFields) {
+#ifdef USE_CUDNN
+TEST(Layer, CreateCudnnLayer) {
+  std::vector<std::string> types{
+      "CudnnConvolution", "CudnnActivation",
+      "CudnnBatchNorm",   "Flatten",      "CudnnLRN",
+      "CudnnPooling",     "PReLU",        "CudnnSoftmax"};
+#if CUDNN_VERSION_MAJOR >= 5
+  types.push_back("CudnnDropout");
+#endif
+  for (auto type : types) {
+    auto layer = singa::CreateLayer(type);
+    EXPECT_EQ(layer->layer_type(), type);
+  }
 }
+#endif

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/test/singa/test_nesterov.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_nesterov.cc b/test/singa/test_nesterov.cc
index 73f69f4..7c76784 100644
--- a/test/singa/test_nesterov.cc
+++ b/test/singa/test_nesterov.cc
@@ -35,7 +35,7 @@ TEST(Nesterov, ApplyCPU) {
   value.CopyDataFromHostPtr(v, 4);
   grad.CopyDataFromHostPtr(g, 4);
 
-  nesterov.Apply(0, lr, "xx", grad, &value);
+  nesterov.Apply(0, lr, "xx", grad, value);
 
   singa::Tensor v1 = value.Clone();
   const float* newv1 = v1.data<float>();
@@ -47,7 +47,7 @@ TEST(Nesterov, ApplyCPU) {
   for (int i = 0; i < 4; ++i) EXPECT_FLOAT_EQ(newv1[i], v[i] - tmp[i]);
 
   grad.CopyDataFromHostPtr(g, 4);
-  nesterov.Apply(1, lr, "xx", grad, &value);
+  nesterov.Apply(1, lr, "xx", grad, value);
   singa::Tensor v2 = value.Clone();
   const float* newv2 = v2.data<float>();
   for (int i = 0; i < 4; ++i) {
@@ -73,7 +73,7 @@ TEST(Nesterov, ApplyCUDA) {
   value.CopyDataFromHostPtr(v, 4);
   grad.CopyDataFromHostPtr(g, 4);
 
-  nesterov.Apply(0, lr, "xx", grad, &value);
+  nesterov.Apply(0, lr, "xx", grad, value);
 
   singa::Tensor v1 = value.Clone();
   v1.ToHost();
@@ -86,7 +86,7 @@ TEST(Nesterov, ApplyCUDA) {
   for (int i = 0; i < 4; ++i) EXPECT_FLOAT_EQ(newv1[i], v[i] - tmp[i]);
 
   grad.CopyDataFromHostPtr(g, 4);
-  nesterov.Apply(1, lr, "xx", grad, &value);
+  nesterov.Apply(1, lr, "xx", grad, value);
   singa::Tensor v2 = value.Clone();
   v2.ToHost();
   const float* newv2 = v2.data<float>();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/test/singa/test_rmsprop.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_rmsprop.cc b/test/singa/test_rmsprop.cc
index 18de9c3..d259592 100644
--- a/test/singa/test_rmsprop.cc
+++ b/test/singa/test_rmsprop.cc
@@ -39,7 +39,7 @@ TEST(RMSProp, ApplyCPU) {
   grad.CopyDataFromHostPtr(g, 4);
 
   rmsprop.Setup(conf);
-  rmsprop.Apply(0, lr, "xx", grad, &value);
+  rmsprop.Apply(0, lr, "xx", grad, value);
 
   singa::Tensor v1 = value.Clone();
   const float* newv1 = v1.data<float>();
@@ -50,7 +50,7 @@ TEST(RMSProp, ApplyCPU) {
                 1e-5);
 
   grad.CopyDataFromHostPtr(g, 4);
-  rmsprop.Apply(1, lr, "xx", grad, &value);
+  rmsprop.Apply(1, lr, "xx", grad, value);
   singa::Tensor v2 = value.Clone();
   const float* newv2 = v2.data<float>();
   for (int i = 0; i < 4; ++i)
@@ -79,7 +79,7 @@ TEST(RMSProp, ApplyCUDA) {
   grad.CopyDataFromHostPtr(g, 4);
 
   rmsprop.Setup(conf);
-  rmsprop.Apply(0, lr, "xx", grad, &value);
+  rmsprop.Apply(0, lr, "xx", grad, value);
 
   singa::Tensor v1 = value.Clone();
   v1.ToHost();
@@ -91,7 +91,7 @@ TEST(RMSProp, ApplyCUDA) {
                 1e-5);
 
   grad.CopyDataFromHostPtr(g, 4);
-  rmsprop.Apply(1, lr, "xx", grad, &value);
+  rmsprop.Apply(1, lr, "xx", grad, value);
   singa::Tensor v2 = value.Clone();
   v2.ToHost();
   const float* newv2 = v2.data<float>();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/790b7b4c/test/singa/test_sgd.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_sgd.cc b/test/singa/test_sgd.cc
index 8ea95b3..e6ed9bf 100644
--- a/test/singa/test_sgd.cc
+++ b/test/singa/test_sgd.cc
@@ -33,7 +33,7 @@ TEST(SGD, ApplyWithoutMomentum) {
   grad.CopyDataFromHostPtr(g, 4);
 
   float lr = 0.1f;
-  sgd.Apply(0, lr, "xx", grad, &value);
+  sgd.Apply(0, lr, "xx", grad, value);
 
   singa::Tensor v1 = value.Clone();
   const float* newv1 = v1.data<float>();
@@ -44,7 +44,7 @@ TEST(SGD, ApplyWithoutMomentum) {
 
   lr /= 2;
   grad.CopyDataFromHostPtr(g, 4);
-  sgd.Apply(1, lr, "xx", grad, &value);
+  sgd.Apply(1, lr, "xx", grad, value);
   singa::Tensor v2 = value.Clone();
   const float* newv2 = v2.data<float>();
   for (int i = 0; i < 4; i++) {
@@ -65,7 +65,7 @@ TEST(SGD, ApplyWithMomentum) {
   value.CopyDataFromHostPtr(v, 4);
   grad.CopyDataFromHostPtr(g, 4);
 
-  sgd.Apply(0, lr, "xx", grad, &value);
+  sgd.Apply(0, lr, "xx", grad, value);
 
   singa::Tensor v1 = value.Clone();
   const float* newv1 = v1.data<float>();
@@ -74,7 +74,7 @@ TEST(SGD, ApplyWithMomentum) {
   }
 
   grad.CopyDataFromHostPtr(g, 4);
-  sgd.Apply(1, lr, "xx", grad, &value);
+  sgd.Apply(1, lr, "xx", grad, value);
   singa::Tensor v2 = value.Clone();
   const float* newv2 = v2.data<float>();
   for (int i = 0; i < 4; i++) {
@@ -94,7 +94,7 @@ TEST(SGD, ApplyWithoutMomentumCuda) {
   grad.CopyDataFromHostPtr(g, 4);
 
   float lr = 0.1f;
-  sgd.Apply(0, lr, "xx", grad, &value);
+  sgd.Apply(0, lr, "xx", grad, value);
 
   singa::Tensor v1 = value.Clone();
   v1.ToHost();
@@ -106,7 +106,7 @@ TEST(SGD, ApplyWithoutMomentumCuda) {
 
   lr /= 2;
   grad.CopyDataFromHostPtr(g, 4);
-  sgd.Apply(1, lr, "xx", grad, &value);
+  sgd.Apply(1, lr, "xx", grad, value);
   singa::Tensor v2 = value.Clone();
   v2.ToHost();
   const float* newv2 = v2.data<float>();
@@ -129,7 +129,7 @@ TEST(SGD, ApplyWithMomentumCuda) {
   value.CopyDataFromHostPtr(v, 4);
   grad.CopyDataFromHostPtr(g, 4);
 
-  sgd.Apply(0, lr, "xx", grad, &value);
+  sgd.Apply(0, lr, "xx", grad, value);
 
   singa::Tensor v1 = value.Clone();
   v1.ToHost();
@@ -139,7 +139,7 @@ TEST(SGD, ApplyWithMomentumCuda) {
   }
 
   grad.CopyDataFromHostPtr(g, 4);
-  sgd.Apply(1, lr, "xx", grad, &value);
+  sgd.Apply(1, lr, "xx", grad, value);
   singa::Tensor v2 = value.Clone();
   v2.ToHost();
   const float* newv2 = v2.data<float>();
