SINGA-349 Create layer operations for autograd

1. Fix bugs in the new design of the API.

2. Add a flag to switch between the training and evaluation phases.

3. Add configurable parameter initialization methods (see the usage sketch below).
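
A minimal usage sketch of points 2 and 3 (illustrative only, not part of
this commit): it assumes W_specs/b_specs are accepted as keyword arguments
and take effect, and that a spec dict carries an 'init' key ('gaussian',
'uniform' or 'xavier').

    from singa import tensor
    from singa.layer_ops import Conv2d

    # Configurable initialization replaces the previously hard-coded
    # gaussian(0.0, 0.1); 'mean'/'std' (or 'low'/'high') are optional.
    conv = Conv2d(3, 32, kernel_size=3,
                  W_specs={'init': 'gaussian', 'mean': 0.0, 'std': 0.01},
                  b_specs={'init': 'uniform', 'low': -0.1, 'high': 0.1})

    x = tensor.Tensor((2, 3, 32, 32))
    x.gaussian(0.0, 1.0)

    y = conv(x, flag=True)    # training: forwards with model_pb2.kTrain
    y = conv(x, flag=False)   # evaluation: forwards with model_pb2.kEval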


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b136ac0a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b136ac0a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b136ac0a

Branch: refs/heads/master
Commit: b136ac0a8ce42fa6e5e123874c77039eaf86e556
Parents: 6402a53
Author: xuewanqi <36396136+xuewa...@users.noreply.github.com>
Authored: Mon May 7 15:31:29 2018 +0800
Committer: Wang Wei <dcs...@nus.edu.sg>
Committed: Thu May 17 21:19:07 2018 +0800

----------------------------------------------------------------------
 python/singa/layer_ops.py | 228 +++++++++++++++++++++++++++++++----------
 1 file changed, 172 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b136ac0a/python/singa/layer_ops.py
----------------------------------------------------------------------
diff --git a/python/singa/layer_ops.py b/python/singa/layer_ops.py
index e5ef45f..dcbacf9 100644
--- a/python/singa/layer_ops.py
+++ b/python/singa/layer_ops.py
@@ -3,8 +3,8 @@ from singa import layer
 from singa.proto import model_pb2
 
 
-class Conv2D(tensor.Operation):
-    def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,**kwargs):
+class Conv2d(tensor.Operation):
+    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs):
 
         name='Conv2d'
         border_mode = 'same'
@@ -31,80 +31,148 @@ class Conv2D(tensor.Operation):
             else:
                 allowed_kwargs[kwarg] = kwargs[kwarg]
 
-        '''
-        How to match Keras:
+        self.W_specs=W_specs
+        self.b_specs=b_specs
+
+        if padding == 0:
+            pad = None
+        else:
+            pad = padding
+
+        if dilation != 1 or groups != 1:
+            raise ValueError('Not implemented yet')
 
-        in Keras conv2d, self.kernel record how to generate kernel (shape,initializer,name,regularizer,constraint),
-        it can be interpret to
-        shape -> kernel+input_sample_shape[0](nb_channels)+nb_kernels,
-        initializer, name, regularizer, constraint -> W_specs.
-        '''
         self.PyLayer = layer.Conv2D(name, nb_kernels=out_channels, kernel=kernel_size, stride=stride, border_mode=border_mode,
                  cudnn_prefer=cudnn_prefer, workspace_byte_limit=workspace_byte_limit,
-                 data_format=data_format, use_bias=bias, W_specs=W_specs, b_specs=b_specs,
-                 pad=padding, input_sample_shape=input_sample_shape)
+                 data_format=data_format, use_bias=bias, W_specs=self.W_specs, b_specs=self.b_specs,
+                 pad=pad, input_sample_shape=input_sample_shape)
 
+    def __call__(self, x, flag=True):
+        assert type(flag) is bool, 'flag can only be bool.'
+        if flag:
+            self.flag = model_pb2.kTrain
+        else:
+            self.flag = model_pb2.kEval
 
-    def __call__(self, x):
         if not self.PyLayer.has_setup:
             self.PyLayer.setup(x.shape[1:])
+
         param_data = self.PyLayer.layer.param_values()
         if not hasattr(self, 'w'):
             self.w = tensor.Tensor(data=param_data[0], requires_grad=True, stores_grad=True)
-            self.w.gaussian(0.0, 0.1)  # TODO realize other initialization method according to W_specs
-        
+            if self.W_specs['init'] == 'gaussian':
+                if 'std' not in self.W_specs or 'mean' not in self.W_specs:
+                    self.w.gaussian(0.0, 0.1)
+                else:
+                    self.w.gaussian(self.W_specs['mean'],self.W_specs['std'])
+            elif self.W_specs['init'] == 'uniform':
+                if 'low' not in self.W_specs or 'high' not in self.W_specs:
+                    self.w.uniform(0.0, 1.0)
+                else:
+                    self.w.uniform(self.W_specs['low'],self.W_specs['high'])
+            elif self.W_specs['init'] == 'xavier':
+                pass  # TODO
+
         xs = [x, self.w]
 
         if len(param_data) == 2:
-            self.b = tensor.Tensor(data=param_data[1], requires_grad=True, stores_grad=True)
-            self.b.set_value(0.0)  # TODO realize other initialization method according to b_specs
+            if not hasattr(self, 'b'):
+                self.b = tensor.Tensor(data=param_data[1], requires_grad=True, stores_grad=True)
+                if self.b_specs['init'] == 'gaussian':
+                    if 'std' not in self.b_specs or 'mean' not in self.b_specs:
+                        self.b.gaussian(0.0, 0.1)
+                    else:
+                        self.b.gaussian(self.b_specs['mean'], self.b_specs['std'])
+                elif self.b_specs['init'] == 'uniform':
+                    if 'low' not in self.b_specs or 'high' not in self.b_specs:
+                        self.b.uniform(0.0, 1.0)
+                    else:
+                        self.b.uniform(self.b_specs['low'], self.b_specs['high'])
+                elif self.b_specs['init'] == 'xavier':
+                    pass  # TODO
+                else:
+                    self.b.set_value(0.0)
+
             xs.append(self.b)
 
         xs = tuple(xs)
         return self._do_forward(*xs)
 
-    def forward(self, flag=True,*xs):
-        if flag is True:
-            return self.PyLayer.layer.Forward(4, xs[0])
-        else:
-            return self.PyLayer.layer.Forward(8, xs[0])
+    def forward(self, *xs):
+        return self.PyLayer.layer.Forward(self.flag, xs[0])
 
     def backward(self, dy):
         ret = self.PyLayer.layer.Backward(0, dy)
         return (ret[0],)+ret[1]
 
 
-class MaxPooling2D(tensor.Operation):
-    def __init__(self, name, kernel=3, stride=2, border_mode='same', pad=None,
-                 data_format='NCHW', input_sample_shape=None):
+class MaxPool2d(tensor.Operation):
+    def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=False, **kwargs):
+
+        name = 'MaxPool2d'
+        border_mode = 'same'
+        data_format = 'NCHW'
+        input_sample_shape = None
+
+        allowed_kwargs = {'name': name,
+                          'border_mode': border_mode,
+                          'data_format': data_format,
+                          'input_sample_shape': input_sample_shape
+                          }
+
+        for kwarg in kwargs:
+            if kwarg not in allowed_kwargs:
+                raise TypeError('Keyword argument not understood:', kwarg)
+            else:
+                allowed_kwargs[kwarg] = kwargs[kwarg]
+
+        if padding == 0:
+            pad = None
+        else:
+            pad = padding
+
+        if dilation != 1 or return_indices is not False or ceil_mode is not False:
+            raise ValueError('Not implemented yet')
 
         self.PyLayer = layer.Pooling2D(name, model_pb2.PoolingConf.MAX,
-                                           kernel, stride, border_mode,
+                                           kernel_size, stride, border_mode,
                                            pad, data_format, input_sample_shape)
 
-    def __call__(self, x):
+    def __call__(self, x, flag=True):
+        assert type(flag) is bool, 'flag can only be bool.'
+        if flag:
+            self.flag = model_pb2.kTrain
+        else:
+            self.flag = model_pb2.kEval
+
         if not self.PyLayer.has_setup:
             self.PyLayer.setup(x.shape[1:])
+
         return self._do_forward(x)
 
-    def forward(self, x):
-        return self.PyLayer.layer.Forward(4, x)
+    def forward(self, *xs):
+        return self.PyLayer.layer.Forward(self.flag, xs[0])
 
     def backward(self, dy):
         return self.PyLayer.layer.Backward(0, dy)[0]
 
 
-class Activation(tensor.Operation):
-    def __init__(self,name, mode='relu',input_sample_shape=None):
+class ReLU(tensor.Operation):
+    def __init__(self, name='ReLU', mode='relu',input_sample_shape=None):
         self.PyLayer = layer.Activation(name, mode, input_sample_shape)
 
-    def __call__(self, x):
+    def __call__(self, x, flag=True):
+        assert type(flag) is bool, 'flag can only be bool.'
+        if flag:
+            self.flag = model_pb2.kTrain
+        else:
+            self.flag = model_pb2.kEval
         if not self.PyLayer.has_setup:
             self.PyLayer.setup(x.shape[1:])
         return self._do_forward(x)
 
-    def forward(self, x):
-        return self.PyLayer.layer.Forward(4, x)
+    def forward(self, *xs):
+        return self.PyLayer.layer.Forward(self.flag, xs[0])
 
     def backward(self, dy):
         return self.PyLayer.layer.Backward(0, dy)[0]
@@ -114,58 +182,106 @@ class Flatten(tensor.Operation):
     def __init__(self, name, axis=1, input_sample_shape=None):
         self.PyLayer = layer.Flatten(name, axis, input_sample_shape)
 
-    def __call__(self, x):
+    def __call__(self, x, flag=True):
+        assert type(flag) is bool, 'flag can only be bool.'
+        if flag:
+            self.flag = model_pb2.kTrain
+        else:
+            self.flag = model_pb2.kEval
         if not self.PyLayer.has_setup:
             self.PyLayer.setup(x.shape[1:])
         return self._do_forward(x)
 
-    def forward(self, x):
-        return self.PyLayer.layer.Forward(4, x)
+    def forward(self, *xs):
+        return self.PyLayer.layer.Forward(self.flag, xs[0])
 
     def backward(self, dy):
         return self.PyLayer.layer.Backward(0, dy)[0]
 
 
-class Dense(tensor.Operation):
-    def __init__(self, name, num_output, use_bias=True,
-                     W_specs=None, b_specs=None,
-                     W_transpose=False, input_sample_shape=None):
+class Linear(tensor.Operation):
+    def __init__(self, in_features, out_features, bias=True, **kwargs):
+
+        name = 'Linear'
+        W_transpose=False
+        W_specs = None
+        b_specs = None
+        input_sample_shape = in_features
+
+        allowed_kwargs = {'name': name,
+                          'W_transpose': W_transpose,
+                          'W_specs': W_specs,
+                          'b_specs': b_specs,
+                          'input_sample_shape': input_sample_shape
+                          }
+
+        for kwarg in kwargs:
+            if kwarg not in allowed_kwargs:
+                raise TypeError('Keyword argument not understood:', kwarg)
+            else:
+                allowed_kwargs[kwarg] = kwargs[kwarg]
 
-        self.PyLayer = layer.Dense(name, num_output=num_output, use_bias=use_bias,
-                     W_specs=W_specs, b_specs=b_specs,
+        self.W_specs = W_specs
+        self.b_specs = b_specs
+
+        self.PyLayer = layer.Dense(name, num_output=out_features, use_bias=bias,
+                     W_specs=self.W_specs, b_specs=self.b_specs,
                      W_transpose=W_transpose, input_sample_shape=input_sample_shape)
 
-    def __call__(self, x):
+    def __call__(self, x, flag=True):
+        assert type(flag) is bool, 'flag can only be bool.'
+        if flag:
+            self.flag = model_pb2.kTrain
+        else:
+            self.flag = model_pb2.kEval
+
         if not self.PyLayer.has_setup:
             self.PyLayer.setup(x.shape[1:])
 
         param_data = self.PyLayer.layer.param_values()
-
         if not hasattr(self, 'w'):
             self.w = tensor.Tensor(data=param_data[0], requires_grad=True, stores_grad=True)
-            self.w.gaussian(0.0, 0.1)  # TODO realize other initialization method according to W_specs
+            if self.W_specs['init'] == 'gaussian':
+                if 'std' not in self.W_specs or 'mean' not in self.W_specs:
+                    self.w.gaussian(0.0, 0.1)
+                else:
+                    self.w.gaussian(self.W_specs['mean'],self.W_specs['std'])
+            elif self.W_specs['init'] == 'uniform':
+                if 'low' not in self.W_specs or 'high' not in self.W_specs:
+                    self.w.uniform(0.0, 1.0)
+                else:
+                    self.w.uniform(self.W_specs['low'],self.W_specs['high'])
+            elif self.W_specs['init'] == 'xavier':
+                pass  # TODO
 
         xs = [x, self.w]
 
         if len(param_data) == 2:
-            self.b = tensor.Tensor(data=param_data[1], requires_grad=True, stores_grad=True)
-            self.b.set_value(0.0)  # TODO realize other initialization method according to b_specs
+            if not hasattr(self, 'b'):
+                self.b = tensor.Tensor(data=param_data[1], requires_grad=True, stores_grad=True)
+                if self.b_specs['init'] == 'gaussian':
+                    if 'std' not in self.b_specs or 'mean' not in self.b_specs:
+                        self.b.gaussian(0.0, 0.1)
+                    else:
+                        self.b.gaussian(self.b_specs['mean'], self.b_specs['std'])
+                elif self.b_specs['init'] == 'uniform':
+                    if 'low' not in self.b_specs or 'high' not in self.b_specs:
+                        self.b.uniform(0.0, 1.0)
+                    else:
+                        self.b.uniform(self.b_specs['low'], self.b_specs['high'])
+                elif self.b_specs['init'] == 'xavier':
+                    pass  # TODO
+                else:
+                    self.b.set_value(0.0)
+
             xs.append(self.b)
 
         xs = tuple(xs)
         return self._do_forward(*xs)
 
     def forward(self, *xs):
-        return self.PyLayer.layer.Forward(4, xs[0])
+        return self.PyLayer.layer.Forward(self.flag, xs[0])
 
     def backward(self, dy):
         ret = self.PyLayer.layer.Backward(0, dy)
         return (ret[0],)+ret[1]
-
-
-
-
-
-
-
-
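
For reference, a sketch of composing several of the new operations into one
forward pass (again illustrative, not part of the commit; layer shapes are
assumptions, and each operation maps the boolean flag to kTrain/kEval
internally before calling Forward):

    from singa import tensor
    from singa.layer_ops import Conv2d, MaxPool2d, ReLU

    conv = Conv2d(1, 8, kernel_size=3,
                  W_specs={'init': 'gaussian', 'mean': 0.0, 'std': 0.1})
    pool = MaxPool2d(kernel_size=2, stride=2)
    relu = ReLU()

    x = tensor.Tensor((4, 1, 28, 28))
    x.gaussian(0.0, 1.0)

    # Training pass: every op receives flag=True.
    h = pool(relu(conv(x, True), True), True)

    # Evaluation pass: the same graph with flag=False.
    h = pool(relu(conv(x, False), False), False)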
