Repository: incubator-singa
Updated Branches:
  refs/heads/master 22d98bb72 -> f35d217c9


SINGA-264 Extend the FeedForwardNet to accept multiple inputs

Extend FeedForwardNet to support multiple input and output tensors.
The input argument x of train(x, y), forward(flag, x), predict(x) and
evaluate(x, y) can be a single tensor or a dictionary: layer name -> a
single tensor or a tensor list, where each key is the name of the layer
that the corresponding input data is fed into.

The result of out = forward(flag, x, output) is a single tensor or a
dictionary: layer name -> a single tensor or a tensor list, where each key
is the name of a layer whose output values are returned, e.g., when the net
has multiple layers whose out-degree is 0. By setting the argument 'output'
to a list of layer names, the values of those layers are also included in
'out'.

Passed unittests.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f35d217c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f35d217c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f35d217c

Branch: refs/heads/master
Commit: f35d217c9edf57fe193efcae8ab6bb16ec2dcf5a
Parents: 22d98bb
Author: wang wei <[email protected]>
Authored: Mon Nov 21 17:44:30 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Mon Nov 21 20:00:02 2016 +0800

----------------------------------------------------------------------
 python/singa/layer.py   |  59 ++++++++++++++++--
 python/singa/net.py     | 145 ++++++++++++++++++++++++++++++++++---------
 src/api/model_layer.i   |   4 --
 test/python/run.py      |  24 +++++++
 test/python/test_net.py |  77 +++++++++++++++++++++++
 5 files changed, 270 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f35d217c/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/python/singa/layer.py b/python/singa/layer.py
index c7f0ce8..730bea0 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -77,9 +77,8 @@ class Layer(object):
     Args:
         name (str): layer name
     '''
-
     def __init__(self, name, conf=None, **kwargs):
-        if conf == None:
+        if conf is None:
             self.layer = None  # layer converted by swig
             self.name = name  # TODO(wangwei) duplicate with self.conf.name
             self.conf = model_pb2.LayerConf()
@@ -180,7 +179,8 @@ class Layer(object):
         '''Forward propagate through this layer.
 
         Args:
-            flag (int): kTrain or kEval
+            flag: True (kTrain) for training (kEval); False for evaluating;
+                other values for furture use.
             x (Tensor or list<Tensor>): an input tensor if the layer is
                 connected from a single layer; a list of tensors if the layer
                 is connected from multiple layers.
@@ -198,6 +198,11 @@ class Layer(object):
             assert isinstance(x, tensor.Tensor), \
                 'input must be a Tensor or a list of Tensor'
             xs = x.singa_tensor
+        if type(flag) is bool:
+            if flag:
+                flag = model_pb2.kTrain
+            else:
+                flag = model_pb2.kEval
         y = self.layer.Forward(flag, xs)
         if type(y) == list:
             return tensor.from_raw_tensors(y)
@@ -249,6 +254,27 @@ class Layer(object):
         pass
 
 
+class Dummy(Layer):
+    '''A layer that forwards x and backwards dy unchanged (single tensor).'''
+    def __init__(self, name, input_sample_shape=None):
+        super(Dummy, self).__init__(name)
+        self.output_sample_shape = input_sample_shape
+
+    def get_output_sample_shape(self):
+        return self.output_sample_shape
+
+    def setup(self, input_sample_shape):
+        self.output_sample_shape = input_sample_shape
+        self.has_setup = True
+
+    def forward(self, flag, x):
+        '''Return the input x as it is; flag is not used.'''
+        return x
+
+    def backward(self, flag, dy):
+        '''Return (dy, []): dy passes through; there are no param gradients.'''
+        return dy, []
+
 class Conv2D(Layer):
     """Construct a layer for 2D convolution.
 
@@ -695,6 +721,15 @@ class Merge(Layer):
         return self.in_shape
 
     def forward(self, flag, inputs):
+        '''Merge all input tensors by summation.
+
+        Args:
+            flag: not used.
+            inputs (list): a list of tensors
+
+        Returns:
+            A single tensor as the sum of all input tensors
+        '''
         assert len(inputs) > 1, 'There must be multiple input tensors'
         self.num_input = len(inputs)
         output = tensor.Tensor()
@@ -708,6 +743,7 @@ class Merge(Layer):
         assert isinstance(grad, tensor.Tensor), 'The input must be Tensor'
         return [grad] * self.num_input, []  # * self.num_input
 
+
 class Split(Layer):
     '''Replicate the input tensor.
 
@@ -730,6 +766,15 @@ class Split(Layer):
         return self.in_shape
 
     def forward(self, flag, input):
+        '''Replicate the input tensor into multiple tensors.
+
+        Args:
+            flag: not used
+            input: a single input tensor
+
+        Returns:
+            a list of num_output references to the input tensor (not copies)
+        '''
         assert isinstance(input, tensor.Tensor), 'The input must be Tensor'
         outputs = [input] * self.num_output
         return outputs
@@ -795,7 +840,8 @@ class RNN(Layer):
         '''Forward inputs through the RNN.
 
         Args:
-            flag, kTrain or kEval.
+            flag: True(kTrain) for training; False(kEval) for evaluation;
+                others values for future use.
             inputs, <x1, x2,...xn, hx, cx>, where xi is the input tensor for 
the
                 i-th position, its shape is (batch_size, input_feature_length);
                 the batch_size of xi must >= that of xi+1; hx is the initial
@@ -821,6 +867,11 @@ class RNN(Layer):
             assert isinstance(t, tensor.Tensor), \
                 'input must be py Tensor %s' % (type(t))
             tensors.append(t.singa_tensor)
+        if type(flag) is bool:
+            if flag:
+                flag = model_pb2.kTrain
+            else:
+                flag = model_pb2.kEval
         y = self.layer.Forward(flag, tensors)
         return tensor.from_raw_tensors(y)
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f35d217c/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index caf5732..293e97c 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -29,6 +29,7 @@ import cPickle as pickle
 '''For display training information, e.g L1 value of layer data'''
 verbose = False
 
+
 class FeedForwardNet(object):
 
     def __init__(self, loss=None, metric=None):
@@ -72,7 +73,7 @@ class FeedForwardNet(object):
             # print shape
             in_shape = self.src_of_layer[lyr.name][0].get_output_sample_shape()
             lyr.setup(in_shape)
-            print lyr.name, lyr.get_output_sample_shape()
+        print lyr.name, lyr.get_output_sample_shape()
         self.layers.append(lyr)
         return lyr
 
@@ -98,6 +99,19 @@ class FeedForwardNet(object):
         return [spec.name for spec in self.param_specs()]
 
     def train(self, x, y):
+        '''Run BP for one iteration.
+
+        Currently only support nets with a single output layer, and a single
+        loss objective and metric.
+        TODO(wangwei) consider multiple loss objectives and metrics.
+
+        Args:
+            x: input data, a single input Tensor or a dict: layer name -> 
Tensor
+            y: label data, a single input Tensor.
+
+        Returns:
+            gradients of parameters and the loss and metric values.
+        '''
         out = self.forward(kTrain, x)
         l = self.loss.forward(kTrain, out, y)
         if self.metric is not None:
@@ -105,7 +119,16 @@ class FeedForwardNet(object):
         return self.backward(), (l.l1(), m)
 
     def evaluate(self, x, y):
-        """Evaluate the loss and metric of the given data"""
+        '''Evaluate the loss and metric of the given data.
+
+        Currently only support nets with a single output layer, and a single
+        loss objective and metric.
+        TODO(wangwei) consider multiple loss objectives and metrics.
+
+        Args:
+            x: input data, a single input Tensor or a dict: layer name -> 
Tensor
+            y: label data, a single input Tensor.
+        '''
         out = self.forward(kEval, x)
         l = None
         m = None
@@ -118,34 +141,83 @@ class FeedForwardNet(object):
         return l, m
 
     def predict(self, x):
+        '''Forward the input data through each layer to get the values of the
+        output layers.
+
+        Currently only supports nets with a single output layer.
+
+        Args:
+            x: input data, a single input Tensor or a dict: layer name -> Tensor
+
+        Returns:
+            the softmax of the single output tensor, as the prediction result.
+        '''
         xx = self.forward(kEval, x)
         return tensor.softmax(xx)
 
-    def topo_sort(self, cur, src_of_layer, visited=None, order=None):
-        if visited is None:
-            visited = {}
-            for name in src_of_layer.keys():
-                visited[name] = False
-            order = []
-        srcs = src_of_layer[cur.name]
-        for src in srcs:
-            if visited[src.name] is False:
-                visited[src.name] = True
-                self.topo_sort(src, src_of_layer, visited, order)
-        order.append(cur)
-        visited[cur.name] = True
+    def topo_sort(self, layers, src_of_layer):
+        '''Topology sort of layers: every layer follows all its src layers.
+
+        Layers are visited in the given order and each is appended once all
+        of its src layers are ordered, so an already sorted list (e.g., layers
+        fed by the same slice layer, in the order users added them) is kept.
+
+        Args:
+            layers: a list of layers to sort.
+            src_of_layer: a dictionary: layer name -> a list of src layers
+        Returns:
+            A list of ordered layers
+        '''
+        order = []
+        while len(order) < len(layers):
+            num_ordered = len(order)
+            for lyr in layers:
+                if lyr not in order and \
+                        all(src in order for src in src_of_layer[lyr.name]):
+                    order.append(lyr)
+            if len(order) == num_ordered:  # no progress: cycle or unknown src
+                raise ValueError('cannot topologically sort the layers')
         return order
 
-    def forward(self, flag, x):
-        # print x.l1()
+    def forward(self, flag, x, output=[]):
+        '''Forward the input(s) through every layer.
+
+        If a layer has inputs from other layers and from x, the data from x is
+        ordered before the data from other layers, e.g., if layer 1 -> layer 2,
+        and x['layer 2'] has data, then the input of layer 2 is
+        flatten([x['layer 2'], output of layer 1])
+
+        Args:
+            flag: True for training; False for evaluation; could also be
+                model_pb2.kTrain or model_pb2.kEval, or other values for future
+                use.
+            x: a single SINGA tensor or a dictionary: layer name-> singa tensor
+            output(list): a list of layer names whose output would be returned
+                in addition to the default output
+
+        Returns:
+            if there is only one output layer, return its output tensor(s);
+            else return a dictionary: layer name -> output tensor(s)
+        '''
         if self.ordered_layers is None:
-            self.ordered_layers = self.topo_sort(self.layers[-1],
-                                                 self.src_of_layer)
-        inputs = [x]
-        output_of_layer = {}
+            self.ordered_layers = self.topo_sort(self.layers, 
self.src_of_layer)
+        if type(x) is dict:
+            input_of_layer = x
+        else:
+            assert isinstance(x, tensor.Tensor), \
+                'The inputs of a net should be dict or a single tensor'
+            input_of_layer = {self.ordered_layers[0].name: x}
+        output_of_layer = {}  # outputs generated by each layer
+        ret = {}  # outputs to return
         for cur in self.ordered_layers:
+            inputs = []
+            if cur.name in input_of_layer:
+                if type(input_of_layer[cur.name]) is list:
+                    inputs.extend(input_of_layer[cur.name])
+                else:
+                    inputs.append(input_of_layer[cur.name])
             srcs = self.src_of_layer[cur.name]
-            disp_src = cur.name + '<--'
+            disp_src = ''
             for src in srcs:
                 outs = output_of_layer[src.name]
                 if type(outs) == list:
@@ -153,22 +225,33 @@ class FeedForwardNet(object):
                             'the output from layer %s is empty' % src.name
                     inputs.append(outs[0])
                     outs.pop(0)
+                    if len(outs) == 0:
+                        output_of_layer.pop(src.name)
                 else:
                     inputs.append(outs)
                     output_of_layer[cur.name] = []
-                disp_src += '+' + src.name
-                # del output_of_layer[src.name]
-            # print disp_src
+                    output_of_layer.pop(src.name)
             if len(inputs) == 1:
                 inputs = inputs[0]
-            out= cur.forward(flag, inputs)
+            out = cur.forward(flag, inputs)
             if verbose:
-                print '%s: %f' % (cur.name, out.l1())
+                disp_src = '+'.join([src.name for src in srcs])
+                disp_src += '-->' + cur.name
+                if type(out) is list:
+                    print '%s: %s' % (disp_src,
+                            ' '.join([str(o.l1()) for o in out]))
+                else:
+                    print '%s: %f' % (disp_src, out.l1())
             output_of_layer[cur.name] = out
-            inputs = []
+            if cur.name in output:
+                ret[cur.name] = out
             # print lyr.name, x.l1()
         # print output_of_layer
-        return output_of_layer[self.ordered_layers[-1].name]
+        ret.update(output_of_layer)
+        if len(ret) == 1:
+            return ret.values()[0]
+        else:
+            return ret
 
     def backward(self):
         if self.dst_of_layer is None:
@@ -210,7 +293,7 @@ class FeedForwardNet(object):
             ret.extend(pgrad)
         return ret
 
-    def save(self, f, buffer_size = 10, use_pickle=False):
+    def save(self, f, buffer_size=10, use_pickle=False):
         '''Save model parameters using io/snapshot.
 
         Args:
@@ -234,7 +317,7 @@ class FeedForwardNet(object):
                 val.to_host()
                 sp.write(specs.name, val)
 
-    def load(self, f, buffer_size = 10, use_pickle=False):
+    def load(self, f, buffer_size=10, use_pickle=False):
         '''Load model parameters using io/snapshot.
 
         Please refer to the argument description in save().

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f35d217c/src/api/model_layer.i
----------------------------------------------------------------------
diff --git a/src/api/model_layer.i b/src/api/model_layer.i
index 31b2cb6..3878873 100644
--- a/src/api/model_layer.i
+++ b/src/api/model_layer.i
@@ -70,12 +70,8 @@ class Layer {
     virtual void ToDevice(std::shared_ptr<Device> device);
     virtual void AsType(DataType dtype);
     virtual const Tensor Forward(int flag, const Tensor& input);
-    virtual const std::vector<Tensor> Forward(
-        int flag, const std::vector<Tensor>& inputs);
     virtual const std::pair<Tensor, std::vector<Tensor>> Backward(
         int flag, const Tensor& grad);
-    virtual const std::pair<std::vector<Tensor>, std::vector<Tensor>>
-    Backward(int flag, const vector<Tensor>& grads);
 };
 
 std::shared_ptr<Layer> CreateLayer(const std::string& type);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f35d217c/test/python/run.py
----------------------------------------------------------------------
diff --git a/test/python/run.py b/test/python/run.py
new file mode 100644
index 0000000..ae33fbd
--- /dev/null
+++ b/test/python/run.py
@@ -0,0 +1,24 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import unittest
+here = os.path.dirname(os.path.abspath(__file__))  # discover from this dir
+result = unittest.TextTestRunner().run(unittest.TestLoader().discover(here))
+sys.exit(not result.wasSuccessful())  # non-zero exit status on failure

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f35d217c/test/python/test_net.py
----------------------------------------------------------------------
diff --git a/test/python/test_net.py b/test/python/test_net.py
new file mode 100644
index 0000000..53a4f24
--- /dev/null
+++ b/test/python/test_net.py
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import math
+import numpy as np
+
+from singa import net
+from singa import layer
+from singa import tensor
+from singa import loss
+
+layer.engine = 'singacpp'  # presumably selects C++ layer impls; see layer.py
+# net.verbose = True  # uncomment to print the L1 norm of each layer's output
+
+class TestFeedForwardNet(unittest.TestCase):
+
+    def test_single_input_output(self):
+        ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
+        ffn.add(layer.Activation('relu1', input_sample_shape=(2,)))
+        ffn.add(layer.Activation('relu2'))
+        x = np.array([[-1, 1], [1, 1], [-1, -2]], dtype=np.float32)
+        x = tensor.from_numpy(x)
+        y = tensor.Tensor((3,))
+        y.set_value(0)
+        out, _ = ffn.evaluate(x, y)
+        # relu(relu(x)) = [[0, 1], [1, 1], [0, 0]]; out * 3 = sum of -log p(0)
+        expected = -math.log(1.0 / (1 + math.exp(1))) - 2 * math.log(0.5)
+        self.assertAlmostEqual(out * 3, expected, 5)
+
+    def test_mult_inputs(self):
+        ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
+        s1 = ffn.add(layer.Activation('relu1', input_sample_shape=(2,)), [])
+        s2 = ffn.add(layer.Activation('relu2', input_sample_shape=(2,)), [])
+        ffn.add(layer.Merge('merge', input_sample_shape=(2,)), [s1, s2])
+        x1 = tensor.Tensor((2, 2))
+        x1.set_value(1.1)
+        x2 = tensor.Tensor((2, 2))
+        x2.set_value(0.9)
+        out = ffn.forward(False, {'relu1': x1, 'relu2': x2})
+        out = tensor.to_numpy(out)
+        self.assertAlmostEqual(np.average(out), 2)
+
+    def test_mult_outputs(self):
+        ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
+        s1 = ffn.add(layer.Activation('relu1', input_sample_shape=(2,)), [])
+        s2 = ffn.add(layer.Activation('relu2', input_sample_shape=(2,)), [])
+        ffn.add(layer.Merge('merge', input_sample_shape=(2,)), [s1, s2])
+        split = ffn.add(layer.Split('split', 2))
+        ffn.add(layer.Dummy('split1'), split)
+        ffn.add(layer.Dummy('split2'), split)
+        x1 = tensor.Tensor((2, 2))
+        x1.set_value(1.1)
+        x2 = tensor.Tensor((2, 2))
+        x2.set_value(0.9)
+        out = ffn.forward(False, {'relu1': x1, 'relu2': x2})
+        for name in ('split1', 'split2'):  # both leaf layers are returned
+            self.assertAlmostEqual(np.average(tensor.to_numpy(out[name])), 2)
+
+
+if __name__ == '__main__':
+    unittest.main()  # run the tests in this module when executed as a script

Reply via email to