Merge the commits that debug the gradient-averaging error with the commits for
documentation.

Conflicts:
        examples/cifar10/alexnet.py
        src/python/singa/layer.py
        src/python/singa/optimizer.py


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/5d20d353
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/5d20d353
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/5d20d353

Branch: refs/heads/master
Commit: 5d20d353bd09f2bd758f27f5e1851af7ae8d4123
Parents: 5db7eb6 6d4539e
Author: Wei Wang <[email protected]>
Authored: Mon Aug 15 20:31:21 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Mon Aug 15 20:31:21 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/alexnet.cc   | 11 +++-----
 examples/cifar10/alexnet.py   | 53 +++++++-------------------------------
 examples/cifar10/train.py     | 19 +++++++-------
 src/model/feed_forward_net.cc |  6 ++---
 src/model/optimizer/sgd.cc    |  4 +--
 src/python/singa/layer.py     | 30 +++++++++++++++++----
 src/python/singa/net.py       |  8 +++++-
 src/python/singa/optimizer.py | 29 +++++++++++----------
 src/python/singa/tensor.py    |  8 +++---
 9 files changed, 80 insertions(+), 88 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5d20d353/examples/cifar10/alexnet.cc
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5d20d353/examples/cifar10/alexnet.py
----------------------------------------------------------------------
diff --cc examples/cifar10/alexnet.py
index 17b6a89,dae129f..02437b3
--- a/examples/cifar10/alexnet.py
+++ b/examples/cifar10/alexnet.py
@@@ -35,54 -36,20 +35,21 @@@ def create_net(use_cpu=False)
      W0_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.0001}
      W1_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.01}
      W2_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.01, 'decay_mult': 250}
-     b_specs = {'init': 'constant', 'value': 0, 'lt_mult': 2}
-     net.add(
-         layer.Conv2D(
-             'conv1',
-             32,
-             5,
-             1,
-             W_specs=W0_specs.copy(),
-             b_specs=b_specs.copy(),
-             pad=2,
-             input_sample_shape=(
-                 3,
-                 32,
-                 32,
-             )))
++
+     b_specs = {'init': 'constant', 'value': 0, 'lr_mult': 2, 'decay_mult': 0}
+     net.add(layer.Conv2D('conv1', 32, 5, 1, W_specs=W0_specs.copy(), b_specs=b_specs.copy(), pad=2, input_sample_shape=(3,32,32,)))
      net.add(layer.MaxPooling2D('pool1', 3, 2, pad=1))
      net.add(layer.Activation('relu1'))
-     net.add(layer.LRN(name='lrn1'))
-     net.add(
-         layer.Conv2D(
-             'conv2',
-             32,
-             5,
-             1,
-             W_specs=W1_specs.copy(),
-             b_specs=b_specs.copy(),
-          pad=2))
+     net.add(layer.LRN(name='lrn1', size=3, alpha=5e-5))
+     net.add(layer.Conv2D('conv2', 32, 5, 1, W_specs=W1_specs.copy(), b_specs=b_specs.copy(), pad=2))
      net.add(layer.Activation('relu2'))
-     net.add(layer.MaxPooling2D('pool2', 3, 2,  pad=1))
-     net.add(layer.LRN('lrn2'))
-     net.add(
-         layer.Conv2D(
-             'conv3',
-             64,
-             5,
-             1,
-             W_specs=W1_specs.copy(),
-             b_specs=b_specs.copy(),
-          pad=2))
+     net.add(layer.AvgPooling2D('pool2', 3, 2,  pad=1))
+     net.add(layer.LRN('lrn2', size=3, alpha=5e-5))
+     net.add(layer.Conv2D('conv3', 64, 5, 1, W_specs=W1_specs.copy(), b_specs=b_specs.copy(), pad=2))
      net.add(layer.Activation('relu3'))
-     net.add(layer.MaxPooling2D('pool3', 3, 2, pad=1))
+     net.add(layer.AvgPooling2D('pool3', 3, 2, pad=1))
      net.add(layer.Flatten('flat'))
-     net.add(
-         layer.Dense(
-             'dense',
-             10,
-             W_specs=W2_specs.copy(),
-          b_specs=b_specs.copy()))
 -    net.add(layer.Dense('dense', 10, W_specs=W2_specs.copy(), b_specs=b_specs.copy()))
++    net.add(layer.Dense( 'dense', 10, W_specs=W2_specs.copy(), b_specs=b_specs.copy()))
      for (p, specs) in zip(net.param_values(), net.param_specs()):
          filler = specs.filler
          if filler.type == 'gaussian':

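For quick verification of the single-line constructor style that this hunk settles on, the two layers with non-default specs can be built on their own. The sketch below is illustrative only: it reuses the specs dicts from the hunk above and assumes an installed build where the singa.layer module imports and the default layer engine is available.

    from singa import layer

    W0_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.0001}
    W2_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.01, 'decay_mult': 250}
    b_specs = {'init': 'constant', 'value': 0, 'lr_mult': 2, 'decay_mult': 0}

    # conv1 receives the input sample shape (channels, height, width) directly,
    # so its parameters can be set up without the rest of the net.
    conv1 = layer.Conv2D('conv1', 32, 5, 1, W_specs=W0_specs.copy(),
                         b_specs=b_specs.copy(), pad=2,
                         input_sample_shape=(3, 32, 32))

    # dense maps the flattened features onto the 10 CIFAR-10 classes.
    dense = layer.Dense('dense', 10, W_specs=W2_specs.copy(),
                        b_specs=b_specs.copy())
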
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5d20d353/src/python/singa/layer.py
----------------------------------------------------------------------
diff --cc src/python/singa/layer.py
index b0fdb5e,1e9caeb..a9f3826
--- a/src/python/singa/layer.py
+++ b/src/python/singa/layer.py
@@@ -473,28 -388,29 +473,47 @@@ class LRN(Layer)
  
  
  class Dense(Layer):
 +    """Apply linear/affine transformation, also called inner-product or
 +    fully connected layer.
  
 +    Args:
 +        num_output (int): output feature length.
 +        use_bias (bool): add a bias vector or not to the transformed feature
 +        W_specs (dict): specs for the weight matrix
 +            'name' for parameter name
 +            'lr_mult' for learning rate multiplier
 +            'decay_mult' for weight decay multiplier
 +            'init' for init method, which could be 'gaussian', 'uniform',
 +            'xavier' and ''
 +            'std', 'mean', 'high', 'low' for corresponding init methods
 +            'clamp' for gradient constraint, value is scalar
 +            'regularizer' for regularization, currently support 'l2'
 +        b_specs (dict): specs for the bias vector, same fields as W_specs.
 +        W_transpose (bool): if true, output=x*W.T+b;
 +        input_sample_shape (tuple): input feature length
 +    """
      def __init__(self, name, num_output, use_bias=True,
                   W_specs=None, b_specs=None,
-                  W_transpose=True, input_sample_shape=None):
+                  W_transpose=False, input_sample_shape=None):
+         """Apply linear/affine transformation, also called inner-product or
+         fully connected layer.
+ 
+         Args:
+             num_output (int): output feature length.
+             use_bias (bool): add a bias vector or not to the transformed feature
+             W_specs (dict): specs for the weight matrix
+                 'name' for parameter name
+                 'lr_mult' for learning rate multiplier
+                 'decay_mult' for weight decay multiplier
+                 'init' for init method, which could be 'gaussian', 'uniform',
+                 'xavier' and ''
+                 'std', 'mean', 'high', 'low' for corresponding init methods
+                 'clamp' for gradient constraint, value is scalar
+                 'regularizer' for regularization, currently support 'l2'
+             b_specs (dict): specs for the bias vector, same fields as W_specs.
+             W_transpose (bool): if true, output=x*W.T+b;
+             input_sample_shape (tuple): input feature length
+         """
          super(Dense, self).__init__(name)
          conf = self.conf.dense_conf
          conf.num_output = num_output
@@@ -508,15 -424,12 +527,15 @@@
              W_specs['name'] = name + '_weight'
          if 'name' not in b_specs:
              b_specs['name'] = name + '_bias'
-         self.conf.param.extend([_construct_param_specs_from_dict(W_specs)])
-         self.param_specs.append(_construct_param_specs_from_dict(W_specs))
-         self.conf.param.extend([_construct_param_specs_from_dict(b_specs)])
-         self.param_specs.append(_construct_param_specs_from_dict(b_specs))
+         wspecs = _construct_param_specs_from_dict(W_specs)
+         bspecs = _construct_param_specs_from_dict(b_specs)
+         self.conf.param.extend([wspecs, bspecs])
+         self.param_specs.extend([wspecs, bspecs])
          # dense layer is transparent to engine.
 -        self.layer = _create_layer('singa', 'Dense')
 +        if engine == 'cudnn':
 +            self.layer = _create_layer('singacuda', 'Dense')
 +        else:
 +            self.layer = _create_layer(engine, 'Dense')
          if input_sample_shape is not None:
              self.setup(input_sample_shape)
  

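The docstring added for Dense enumerates the keys accepted in W_specs and b_specs and changes the default of W_transpose from True to False, i.e. the layer now computes y = x*W + b unless W_transpose is set. A minimal sketch of constructing the layer with such dicts follows; the spec values and the input feature length (100) are illustrative assumptions, and it presumes a build where singa.layer imports and the default engine is available.

    from singa import layer

    ip = layer.Dense('ip1', 10,
                     W_specs={'init': 'xavier', 'lr_mult': 1, 'decay_mult': 1},
                     b_specs={'init': 'constant', 'value': 0, 'lr_mult': 2},
                     input_sample_shape=(100,))

    # The ParamSpec protos built from the two dicts are kept on the layer
    # (see the wspecs/bspecs change above) and later drive the optimizer.
    for spec in ip.param_specs:
        print(spec.name, spec.lr_mult, spec.decay_mult)
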
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5d20d353/src/python/singa/optimizer.py
----------------------------------------------------------------------
diff --cc src/python/singa/optimizer.py
index 14cf3c0,32f03d4..74e6ade
--- a/src/python/singa/optimizer.py
+++ b/src/python/singa/optimizer.py
@@@ -107,23 -100,24 +107,25 @@@ class Optimizer(object)
  
          Args:
              name (str): parameter name
 -            specs (ParamSpec): protobuf obj
 -        """
 -        assert type(specs) == model_pb2.ParamSpec, \
 +            specs (ParamSpec): protobuf obj, including regularizer and
 +                constraint, multipliers for learning rate and weight decay.
- 
 +        '''
 +        assert isinstance(specs, model_pb2.ParamSpec), \
              'specs should be model_pb2.ParamSpec instance'
          if specs.HasField('regularizer'):
              self.regularizers[name] = CppRegularizer(specs.regularizer)
+         elif specs.decay_mult != 1:
+             self.regularizers[name] = L2Regularizer(
+                 specs.decay_mult * self.regularizer.coefficient)
+ 
          if specs.HasField('constraint'):
              self.constraints[name] = CppConstraint(specs.constraint)
+ 
          if specs.lr_mult != 1:
              self.learning_rate_multiplier[name] = specs.lr_mult
-         if specs.decay_mult != 1:
-             self.decay_multiplier[name] = specs.decay_mult
  
 -    def apply_regularizer_constraint(self, value, grad, name=None, step=None):
 -        """Apply regularization and constraint if available.
 +    def apply_regularizer_constraint(self, epoch, value, grad, name=None):
 +        '''Apply regularization and constraint if available.
  
          If there are both global regularizer (constraint) and param specific
          regularizer (constraint), it would use the param specific one.
@@@ -189,32 -184,24 +191,27 @@@
  
  
  class SGD(Optimizer):
 +    '''The vanilla Stochastic Gradient Descent algorithm with momentum.
  
 -    def __init__(self, lr=None, momentum=None, decay=None):
 -        """The vallina Stochasitc Gradient Descent algorithm.
 +    See the base Optimizer for all arguments.
 +    '''
  
 -        See the base Optimizer for all arguments.
 -        """
 -        super(SGD, self).__init__(lr, momentum, decay)
 +    def __init__(self, lr=None, momentum=None, weight_decay=None, lr_gen=None,
 +                 regularizer=None, constraint=None):
-         super(
-             SGD,
-             self).__init__(
-             lr,
-             momentum,
-             weight_decay,
-             lr_gen,
-             regularizer,
-          constraint)
++        super(SGD, self).__init__(lr, momentum, weight_decay, lr_gen,
++                                  regularizer, constraint)
          conf = model_pb2.OptimizerConf()
 -        if momentum is not None:
 -            conf.momentum = momentum
 +        if self.momentum is not None:
 +            conf.momentum = self.momentum
 +        conf.type = 'sgd'
          self.opt = singa.CreateOptimizer('SGD')
          self.opt.Setup(conf.SerializeToString())
  
 -    def apply_with_lr(self, step, lr, grad, value, name):
 -        self.apply_regularizer_constraint(value, grad, name, step)
 +    def apply_with_lr(self, epoch, lr, grad, value, name):
 +        self.apply_regularizer_constraint(epoch, value, grad, name)
+         if name is not None and name in self.learning_rate_multiplier:
+             lr = lr * self.learning_rate_multiplier[name]
 -        self.opt.Apply(step, lr, name, grad.singa_tensor, value.singa_tensor)
 +        self.opt.Apply(epoch, lr, name, grad.singa_tensor, value.singa_tensor)
          return value
  
  
@@@ -260,9 -240,11 +257,11 @@@ class AdaGrad(Optimizer)
          self.opt = singa.CreateOptimizer('AdaGrad')
          self.opt.Setup(conf.SerializeToString())
  
 -    def apply_with_lr(self, step, lr, grad, value, name):
 -        grad = self.apply_regularizer_constraint(step, value, grad, name)
 +    def apply_with_lr(self, epoch, lr, grad, value, name):
 +        grad = self.apply_regularizer_constraint(epoch, value, grad, name)
+         if name is not None and name in self.learning_rate_multiplier:
+             lr = lr * self.learning_rate_multiplier[name]
 -        self.opt.Apply(step, lr,  name, grad.singa_tensor, value.singa_tensor)
 +        self.opt.Apply(epoch, lr,  name, grad.singa_tensor, value.singa_tensor)
          return value
  
  
@@@ -286,9 -265,11 +285,11 @@@ class RMSProp(Optimizer)
          self.opt = singa.CreateOptimizer('RMSProp')
          self.opt.Setup(conf.SerializeToString())
  
 -    def apply_with_lr(self, step, lr, grad, value, name):
 -        grad = self.apply_regularizer_constraint(step, value, grad, name)
 +    def apply_with_lr(self, epoch, lr, grad, value, name):
 +        grad = self.apply_regularizer_constraint(epoch, value, grad, name)
+         if name is not None and name in self.learning_rate_multiplier:
+             lr = lr * self.learning_rate_multiplier[name]
 -        self.opt.Apply(step, lr,  name, grad.singa_tensor, value.singa_tensor)
 +        self.opt.Apply(epoch, lr,  name, grad.singa_tensor, value.singa_tensor)
          return value
  
  

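Taken together, the optimizer hunks move the epoch to the first argument of apply_with_lr, scale lr by any registered lr_mult inside apply_with_lr, and turn a parameter's decay_mult into a per-parameter L2Regularizer at registration time. A minimal sketch of that call sequence, assuming the base Optimizer exposes the register(name, specs) method documented above, that Layer.param_values() and tensor.Tensor/set_value are available, and with illustrative shapes and hyper-parameters:

    from singa import layer, optimizer, tensor

    dense = layer.Dense('dense', 10, input_sample_shape=(100,))
    opt = optimizer.SGD(lr=0.01, momentum=0.9, weight_decay=0.0005)

    # Register the parameter specs once; per the register() change above, a
    # decay_mult != 1 installs a per-parameter L2Regularizer and an
    # lr_mult != 1 scales lr inside apply_with_lr.
    for spec in dense.param_specs:
        opt.register(spec.name, spec)

    # One update step for epoch 0; a real training loop would pass the
    # averaged gradient computed by the net instead of this dummy tensor.
    w = dense.param_values()[0]
    gw = tensor.Tensor(w.shape)
    gw.set_value(0.01)
    opt.apply_with_lr(0, 0.01, gw, w, dense.param_specs[0].name)
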
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/5d20d353/src/python/singa/tensor.py
----------------------------------------------------------------------
