Repository: incubator-singa
Updated Branches:
  refs/heads/master 163452e74 -> 060e7dfe1


SINGA-348 Support autograd MLP Example

Fix some bugs and rename some files, functions and variables.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/755eba69
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/755eba69
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/755eba69

Branch: refs/heads/master
Commit: 755eba69a9b010b4eb4ab5c9277fc6487998d55a
Parents: e09dff4
Author: xuewanqi <36396136+xuewa...@users.noreply.github.com>
Authored: Thu Apr 5 18:45:32 2018 +0800
Committer: Wang Wei <dcs...@nus.edu.sg>
Committed: Thu Apr 12 16:59:47 2018 +0800

----------------------------------------------------------------------
 examples/MLP.py        |  82 ++++++++--------
 python/singa/engine.py | 141 +++++++++++++++++++++++++++
 python/singa/tensor.py | 226 ++++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 402 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/755eba69/examples/MLP.py
----------------------------------------------------------------------
diff --git a/examples/MLP.py b/examples/MLP.py
index 54ae1ad..405d998 100644
--- a/examples/MLP.py
+++ b/examples/MLP.py
@@ -4,14 +4,11 @@ from singa import engine
 from singa import singa_wrap as singa
 import numpy as np
 
-def print_singa_tensor(x):
-    np_array = x.GetFloatValue(int(x.Size()))
-    print(np_array.reshape(x.shape()))
-    return
 
-if __name__ =='__main__':
+if __name__ == '__main__':
+
+    # prepare training data in numpy array
 
-    #prepare numpy training data
     # generate the boundary
     f = lambda x: (5 * x + 1)
     bd_x = np.linspace(-1., 1, 200)
@@ -24,18 +21,17 @@ if __name__ =='__main__':
     data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32)
 
     def to_categorical(y, num_classes=None):
-        """Converts a class vector (integers) to binary class matrix.
-
-        E.g. for use with categorical_crossentropy.
+        '''
+        Converts a class vector (integers) to a binary class matrix.
 
-        # Arguments
+        Args:
             y: class vector to be converted into a matrix
                 (integers from 0 to num_classes).
             num_classes: total number of classes.
 
-        # Returns
+        Returns:
             A binary matrix representation of the input.
-        """
+        '''
         y = np.array(y, dtype='int')
         input_shape = y.shape
         if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
@@ -50,59 +46,65 @@ if __name__ =='__main__':
         categorical = np.reshape(categorical, output_shape)
         return categorical
 
-    label=to_categorical(label,2).astype(np.float32)
-    print 'train_data_shape:',data.shape,'train_label_shape:',label.shape
+    label = to_categorical(label, 2).astype(np.float32)
+    print 'train_data_shape:', data.shape, 'train_label_shape:', label.shape
 
-    # send numpy data to singa_tensor
-    tr_data=singa.Tensor((400,2))
+    # send training data (numpy array) to singa_tensor
+    tr_data = singa.Tensor((400, 2))
     tr_data.CopyFloatDataFromHostPtr(data.flatten())
 
-    tr_label=singa.Tensor((400,2))
+    tr_label = singa.Tensor((400, 2))
     tr_label.CopyFloatDataFromHostPtr(label.flatten())
 
-    w_0=singa.Tensor((2,3))
+    w_0 = singa.Tensor((2, 3))
     singa.Gaussian(float(0), float(0.1), w_0)
-    b_0=singa.Tensor((1,3))
+    b_0 = singa.Tensor((1, 3))
     b_0.SetFloatValue(float(0))
 
-    w_1=singa.Tensor((3,2))
+    w_1 = singa.Tensor((3, 2))
     singa.Gaussian(float(0), float(0.1), w_1)
-    b_1=singa.Tensor((1,2))
+    b_1 = singa.Tensor((1, 2))
     b_1.SetFloatValue(float(0))
 
+    # initialize tensor.Tensor using singa_tensor
+    inputs = tensor.Tensor(data=tr_data, requires_grad=False, grad_outlet=False)
+    target = tensor.Tensor(data=tr_label, requires_grad=False, grad_outlet=False)
 
-    # initialize Tensor using singa_tensor
-    inputs=tensor.Tensor(data=tr_data,requires_grad=False,grad_outlet=False)
-    target=tensor.Tensor(data=tr_label,requires_grad=False,grad_outlet=False)
-
-    weight_0=tensor.Tensor(data=w_0,requires_grad=True,grad_outlet=True)
-    bias_0=tensor.Tensor(data=b_0,requires_grad=True,grad_outlet=True)
+    weight_0 = tensor.Tensor(data=w_0, requires_grad=True, grad_outlet=True)
+    bias_0 = tensor.Tensor(data=b_0, requires_grad=True, grad_outlet=True)
 
-    weight_1=tensor.Tensor(data=w_1,requires_grad=True,grad_outlet=True)
-    bias_1=tensor.Tensor(data=b_1,requires_grad=True,grad_outlet=True)
+    weight_1 = tensor.Tensor(data=w_1, requires_grad=True, grad_outlet=True)
+    bias_1 = tensor.Tensor(data=b_1, requires_grad=True, grad_outlet=True)
 
-    def update(lr,param,grad): #param:Tensor grad:singa_tensor
+    def update(lr, param, grad):
+        '''
+        Update the value of a parameter with one SGD step.
+        Args:
+            lr: learning rate
+            param: tensor.Tensor
+            grad: singa_tensor holding the gradient of param
+        '''
         grad *= float(lr)
         assert param.singa_tensor.shape() == grad.shape()
-        param.singa_tensor = singa.__sub__(param.singa_tensor,grad)
+        param.singa_tensor = singa.__sub__(param.singa_tensor, grad)
         return
 
-    lr=0.05
+    # training process
+    lr = 0.05
     for i in range(1001):
-        outputs=tensor.dot(inputs,weight_0)
-        outputs=tensor.add_bias(bias_0,outputs)
-        outputs=tensor.relu(outputs)
+        outputs = tensor.dot(inputs, weight_0)
+        outputs = tensor.add_bias(bias_0, outputs)
+        outputs = tensor.relu(outputs)
         outputs = tensor.dot(outputs, weight_1)
         outputs = tensor.add_bias(bias_1, outputs)
-        outputs=tensor.softmax(outputs)
+        outputs = tensor.softmax(outputs)
 
-        loss=tensor.cross_entropy(outputs,target)
+        loss = tensor.cross_entropy(outputs, target)
 
-        grads=float(1)
+        grads = float(1)
         in_grads = engine.gradients(loss, grads)
 
         for param in in_grads:
-            update(lr,param,in_grads[param])
+            update(lr, param, in_grads[param])
 
         if (i % 100 == 0):
-            print 'training loss = ' ,float(tensor.To_numpy(loss.singa_tensor))
\ No newline at end of file
+            print 'training loss = ', float(tensor.To_Numpy(loss.singa_tensor))
\ No newline at end of file
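
A plain-numpy sketch of what the example's forward pass and loss compute. The
2-3-2 layer sizes, Gaussian(0, 0.1) weight init and the cross-entropy formula
mirror MLP.py above; the random data and the numpy code itself are illustrative
only and do not use the SINGA API.

    import numpy as np

    np.random.seed(0)
    x = np.random.randn(400, 2).astype(np.float32)                  # like tr_data
    t = np.eye(2, dtype=np.float32)[np.random.randint(0, 2, 400)]   # one-hot labels, like tr_label

    w0 = np.random.normal(0, 0.1, (2, 3)).astype(np.float32)
    b0 = np.zeros((1, 3), dtype=np.float32)
    w1 = np.random.normal(0, 0.1, (3, 2)).astype(np.float32)
    b1 = np.zeros((1, 2), dtype=np.float32)

    h = np.maximum(x.dot(w0) + b0, 0)                               # dot + add_bias + relu
    logits = h.dot(w1) + b1                                         # dot + add_bias
    p = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax
    loss = -(t * np.log(p)).sum() / x.shape[0]                      # cross_entropy
    print(loss)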

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/755eba69/python/singa/engine.py
----------------------------------------------------------------------
diff --git a/python/singa/engine.py b/python/singa/engine.py
new file mode 100644
index 0000000..a326ab4
--- /dev/null
+++ b/python/singa/engine.py
@@ -0,0 +1,141 @@
+from collections import Counter
+from singa import singa_wrap as singa
+from singa import tensor
+
+
+class GradientFlowController(object):
+    '''
+    Control the backward gradient flow by running the method run_backward().
+
+    '''
+    def __init__(self):
+        pass
+
+    def dependency_check(self, function):
+        '''
+        Compute how many times each 'previous_function' (an Operation object)
+        influences its next_functions, and through which outputs.
+
+        Args:
+            function: an Operation object which is the termination of the graph
+
+        Returns:
+            dependencies: a dictionary recording dependencies among functions (Operations)
+            seen: a set recording all functions (Operations) observed
+
+        '''
+        dependencies = {}
+        seen = {function}
+        queue = [function]
+        while len(queue) > 0:
+            f = queue.pop()
+            for previous_function, Arg_ID in f.previous_functions:
+                if previous_function not in dependencies:
+                    dependencies[previous_function] = [Counter() for _ in previous_function.output_ids]
+                output_idx = previous_function.output_ids[Arg_ID]
+                dependencies[previous_function][output_idx][f] += 1
+                if previous_function not in seen:
+                    queue.append(previous_function)
+                    seen.add(previous_function)
+        return dependencies, seen
+
+    def dependency_release(self, dependencies, previous_function, function, Arg_ID):
+        '''
+        Release one dependency: when previous_function receives a gradient through the
+        output identified by Arg_ID from function, the corresponding dependency counter
+        is decremented by one.
+
+        '''
+        deps = dependencies[previous_function]
+        output_idx = previous_function.output_ids[Arg_ID]
+        output_deps = deps[output_idx]
+        output_deps[function] -= 1
+        if output_deps[function] == 0:
+            del output_deps[function]
+        return output_idx
+
+    def is_ready_for_backward(self, dependencies, function):
+        '''
+        Check if a function(Operation) is ready for backward.
+
+        Return: True or False
+
+        '''
+        for deps in dependencies[function]:
+            if len(deps) > 0:
+                return False
+        return True
+
+    def run_backward(self, Tensor, grad):
+        '''
+        Run the autograd process.
+
+        Args:
+            Tensor: the object tensor to optimize, usually the loss
+            grad: received gradients
+
+        Return:
+            gradients: a dictionary recording the gradients
+
+        '''
+        ready = [(Tensor.creator, (grad,))]
+        not_ready = {}
+
+        dependencies, seen = self.dependency_check(Tensor.creator)
+
+        while len(ready) > 0:
+            function, grad = ready.pop()
+            gradient_inputs = function._do_backward(*grad)
+            for (previous_function, Arg_ID), gradient_input in zip(function.previous_functions, gradient_inputs):
+                if not previous_function.requires_grad:
+                    continue
+
+                output_index = self.dependency_release(dependencies, previous_function, function, Arg_ID)
+                is_ready = self.is_ready_for_backward(dependencies, previous_function)
+
+                if is_ready:
+                    if previous_function in not_ready:
+                        previous_functions_gradients = not_ready[previous_function]
+                        if not previous_functions_gradients[output_index]:
+                            previous_functions_gradients[output_index] = gradient_input
+                        else:
+                            previous_functions_gradients[output_index] = \
+                                singa.__add__(previous_functions_gradients[output_index], gradient_input)
+                        del not_ready[previous_function]
+                    else:
+                        assert output_index == 0
+                        previous_functions_gradients = (gradient_input,)
+                    ready.append((previous_function, previous_functions_gradients))
+                else:
+                    if previous_function in not_ready:
+                        previous_functions_gradients = not_ready[previous_function]
+                    else:
+                        previous_functions_gradients = [None for _ in previous_function.output_ids]
+
+                    if not previous_functions_gradients[output_index]:
+                        previous_functions_gradients[output_index] = gradient_input
+                    else:
+                        previous_functions_gradients[output_index] = \
+                            singa.__add__(previous_functions_gradients[output_index], gradient_input)
+
+                    not_ready[previous_function] = previous_functions_gradients
+
+        gradients = {}
+        for f in seen:
+            if isinstance(f, tensor.Initializer):
+                if f.Tensor.grad_outlet is True:
+                    gradients[f.Tensor] = f.grads
+                    f.grads = f.init.Clone()
+        return gradients
+
+
+def gradients(Tensor, out_gradient):
+    '''
+    Compute gradients of Tensor.
+
+    '''
+    Controller = GradientFlowController()
+    return Controller.run_backward(Tensor, out_gradient)
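
A minimal sketch of how GradientFlowController.dependency_check counts, for each
output of an upstream Operation, how many downstream Operations still owe it a
gradient (the counts that run_backward later releases). The Dummy class and the
integer Arg_IDs are illustrative stand-ins for Operation objects and id(tensor)
values; the sketch assumes the module above is importable as singa.engine.

    from singa import engine

    class Dummy(object):
        '''Stand-in for an Operation: only the attributes used by dependency_check.'''
        def __init__(self, previous_functions, output_ids):
            self.previous_functions = previous_functions   # list of (prev_op, Arg_ID)
            self.output_ids = output_ids                    # {Arg_ID: output index}
            self.requires_grad = True

    XA, XB = 1, 2                       # arbitrary Arg_IDs standing in for id(tensor)
    a = Dummy([], {XA: 0})              # a source, e.g. an Initializer
    b = Dummy([(a, XA)], {XB: 0})       # consumes a's output
    c = Dummy([(a, XA), (b, XB)], {})   # consumes a's and b's outputs (terminal node)

    controller = engine.GradientFlowController()
    deps, seen = controller.dependency_check(c)
    print(sum(deps[a][0].values()))     # 2: a's output is awaited by b and c
    print(sum(deps[b][0].values()))     # 1: b's output is awaited by c only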

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/755eba69/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 2fcadb4..df29cf5 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -69,6 +69,7 @@ int32 = core_pb2.kInt
 float32 = core_pb2.kFloat32
 
 
+
 class Tensor(object):
     '''Create a Py Tensor, which wraps a swig converted Tensor from CPP Tensor
 
@@ -80,13 +81,17 @@ class Tensor(object):
         device: a swig converted Device instance using the device moduel . If it
             is None, then the default host device would be used.
         dtype: data type. currently, most operations only accept kFloat32.
-    '''
+        data: a singa_tensor recording input data.
+        creator: an Operation object which generates this tensor.
+        requires_grad: a bool recording whether the creator of the tensor requires gradients.
+        grad_outlet: a bool recording whether the tensor is an outlet for gradients.
 
-    def __init__(self, shape=None, device=None, dtype=core_pb2.kFloat32):
+    '''
+    def __init__(self, shape=None, device=None, dtype=core_pb2.kFloat32, data=None, creator=None, requires_grad=True,
+                 grad_outlet=False):
         if shape is None:
             # call constructor of singa::Tensor
             self.singa_tensor = singa.Tensor()
-            return
         else:
             assert isinstance(shape, tuple), 'shape should be tuple'
             if device is None:
@@ -94,9 +99,17 @@ class Tensor(object):
                 self.singa_tensor = singa.Tensor(list(shape), device, dtype)
             else:
                 self.singa_tensor = singa.Tensor(list(shape), device, dtype)
-        self.shape = shape
-        self.dtype = dtype
-        self.device = device
+        if data is not None:
+            self.singa_tensor = data
+            if creator is None:
+                creator = Initializer(self, requires_grad)
+
+        self.shape = tuple(self.singa_tensor.shape())
+        self.device = self.singa_tensor.device()
+        self.dtype = self.singa_tensor.data_type()
+
+        self.creator = creator
+        self.grad_outlet = grad_outlet
 
     def ndim(self):
         '''
@@ -384,7 +397,7 @@ class Tensor(object):
         if isinstance(x, Tensor):
             self.singa_tensor /= x.singa_tensor
         else:
-            self.__imul__(1/float(x))
+            self.singa_tensor /= float(x)
         return self
 
     '''
@@ -1102,3 +1115,202 @@ def _call_singa_func(_singa_func, *args):
     new_t.device = new_t.singa_tensor.device()
     new_t.dtype = new_t.singa_tensor.data_type()
     return new_t
+
+
+def copy_from_numpy(singa_tensor, np_array):
+    '''
+    Copy the data from the numpy array.
+    '''
+    assert np_array.size == singa_tensor.Size(), 'tensor shape should be the same'
+    if not np_array.ndim == 1:
+        np_array = np_array.flatten()
+    dt = np_array.dtype
+    if dt == np.float32:
+        singa_tensor.CopyFloatDataFromHostPtr(np_array)
+    elif dt == np.int or dt == np.int32:
+        singa_tensor.CopyIntDataFromHostPtr(np_array)
+    else:
+        print('Not implemented yet for ', dt)
+
+
+class Operation(object):
+    '''
+    Wrap normal functions such as dot to realize autograd.
+
+    '''
+    def __init__(self, **operation_params):
+        pass
+
+    def __call__(self, *input):
+        return self._do_forward(*input)
+
+    def _do_forward(self, *input):
+        unpacked_input = tuple(arg.singa_tensor for arg in input)
+        raw_output = self.forward(*unpacked_input)
+        if not isinstance(raw_output, tuple):
+            raw_output = (raw_output,)
+        self.needs_input_grad = tuple(arg.creator.requires_grad for arg in input)
+        self.requires_grad = any(self.needs_input_grad)
+        output = tuple(Tensor(data=data, creator=self) for data in raw_output)
+        self.previous_functions = [(arg.creator, id(arg)) for arg in input]
+        self.output_ids = {id(var): i for i, var in enumerate(output)}
+        return output
+
+    def _do_backward(self, grad_output):
+        grad_input = self.backward(grad_output)
+        if not isinstance(grad_input, tuple):
+            grad_input = (grad_input,)
+        return grad_input
+
+    def forward(self, *input):
+        raise NotImplementedError
+
+    def backward(self, *grad_output):
+        raise NotImplementedError
+
+
+class Initializer(Operation):
+    '''
+    For a Tensor without a creator, an Initializer can act as its creator.
+    It is commonly used for feeding training data or initializing parameters such as weights and biases.
+
+    '''
+    def __init__(self, Tensor, requires_grad):
+        self.Tensor = Tensor
+        self.output_ids = {id(Tensor): 0}
+        self.previous_functions = []
+        self.requires_grad = requires_grad
+        shape = self.Tensor.singa_tensor.shape()
+        self.init = singa.Tensor(list(shape))
+        copy_from_numpy(self.init, np.zeros(shape=shape, dtype=np.float32))
+        self.grads = self.init.Clone()
+
+    def _do_forward(self):
+        raise NotImplementedError
+
+    def _do_backward(self, *dy):
+        assert len(dy) == 1
+        self.grads = singa.__add__(self.grads, dy[0])
+        return tuple()
+
+
+class ReLU(Operation):
+    def forward(self, x):
+        '''
+        forward function for ReLU Operation.
+
+        '''
+        self.input = (x,)
+        return singa.ReLU(x)
+
+    def backward(self, dy):
+        '''
+        backward function for ReLU Operation.
+        '''
+        dx = singa.GTFloat(self.input[0], 0.0)
+        return singa.__mul__(dy, dx)
+def relu(x):
+    return ReLU()(x)[0]
+
+
+class Dot(Operation):
+    def forward(self, x, w):
+        '''
+        forward function for Dot Operation.
+
+        '''
+        self.input = (x, w)
+        return singa.Mult(x, w)
+
+    def backward(self, dy):
+        '''
+        backward function for Dot Operation.
+
+        '''
+        return singa.Mult(dy, self.input[1].T()), singa.Mult(self.input[0].T(), dy)
+def dot(x, w):
+    return Dot()(x, w)[0]
+
+
+class Add_Bias(Operation):
+    def forward(self, b, x):
+        '''
+        forward function for Add_Bias Operation.
+
+        '''
+        singa.AddRow(b, x)
+        return x
+
+    def backward(self, dy):
+        '''
+        backward function for Add_Bias Operation.
+
+        '''
+        return singa.Sum(dy, 0), dy
+def add_bias(b, x):
+    return Add_Bias()(b, x)[0]
+
+
+class SoftMax(Operation):
+    def forward(self, x):
+        '''
+        forward function for SoftMax Operation.
+
+        '''
+        self.output = (singa.SoftMax(x),)
+        return self.output[0]
+
+    def backward(self, dy):
+        '''
+        backward function for SoftMax Operation.
+
+        '''
+        # calculations are done in numpy
+        grad = To_Numpy(dy)
+        output = To_Numpy(self.output[0])
+        out_1 = np.einsum('ki,ki->ki', grad, output)
+        medium_out = np.einsum('ki,kj->kij', output, output)
+        out_2 = np.einsum('kij,kj->ki', medium_out, grad)
+        out = out_1 - out_2
+        out_singa = singa.Tensor(out_1.shape)
+        out_singa.CopyFloatDataFromHostPtr(out.flatten())
+        return out_singa
+def softmax(x):
+    return SoftMax()(x)[0]
+
+
+class Cross_Entropy(Operation):
+    def forward(self, pred, target):
+        '''
+        forward function for Cross_Entropy Operation.
+
+        '''
+        loss = singa.Tensor((1,))
+        loss.SetFloatValue(-singa.SumAsFloat(singa.__mul__(target, singa.Log(pred)))/pred.shape()[0])
+        self.input = (pred, target)
+        return loss
+
+    def backward(self, dy):
+        '''
+        backward function for Cross_Entropy Operation.
+
+        '''
+        dx = singa.__div__(self.input[1], self.input[0])
+        dx *= -1.0 / self.input[0].shape()[0]
+        if not isinstance(dy, singa.Tensor):
+            # dtype of dy: float
+            dx *= dy
+            return dx
+        else:
+            pass  # TODO
+def cross_entropy(y, t):
+    return Cross_Entropy()(y, t)[0]
+
+
+def To_Numpy(x):
+    '''
+    To be used in SoftMax Operation.
+    Convert a singa_tensor to a numpy array.
+    '''
+    np_array = x.GetFloatValue(int(x.Size()))
+    return np_array.reshape(x.shape())
\ No newline at end of file
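
A plain-numpy check (illustrative only) that the einsum expressions in
SoftMax.backward above compute the usual row-wise softmax Jacobian-vector
product dx_ki = s_ki * (dy_ki - sum_j s_kj * dy_kj), where s is the softmax
output saved in forward:

    import numpy as np

    np.random.seed(1)
    s = np.random.rand(4, 3).astype(np.float32)
    s /= s.sum(axis=1, keepdims=True)        # rows of s act like softmax outputs
    dy = np.random.randn(4, 3).astype(np.float32)

    # einsum route, as in SoftMax.backward
    out_1 = np.einsum('ki,ki->ki', dy, s)
    medium_out = np.einsum('ki,kj->kij', s, s)
    out_2 = np.einsum('kij,kj->ki', medium_out, dy)
    dx_einsum = out_1 - out_2

    # per-row Jacobian route: J_k = diag(s_k) - s_k s_k^T, dx_k = J_k^T dy_k
    dx_jacobian = np.stack([(np.diag(sk) - np.outer(sk, sk)).T.dot(dyk)
                            for sk, dyk in zip(s, dy)])

    print(np.allclose(dx_einsum, dx_jacobian, atol=1e-5))   # True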
