SINGA-237 New documentation files for SINGA v1.0

Added a README file for the cifar-10 examples.
Updated the uniform and gaussian methods in initializer.py to include the
fan_in and fan_out arguments.
Reformatted some Python files.
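For reference, a minimal sketch of the updated initializer API (following the
docstring example added in this commit; the (3, 5) shape and the keyword form
of the arguments are illustrative):

    from singa import tensor
    from singa import initializer

    w = tensor.Tensor((3, 5))
    # dense-layer weight: fan_in = input_feature_length,
    # fan_out = output_feature_length
    initializer.uniform(w, fan_in=3, fan_out=5)  # use both fan_in and fan_out
    initializer.gaussian(w, 3, 0)                # use only fan_in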
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/d3a57cfc
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/d3a57cfc
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/d3a57cfc

Branch: refs/heads/dev
Commit: d3a57cfc2b71abadf992e9f0900a4051da8e4232
Parents: 8cd5530
Author: Wei Wang <[email protected]>
Authored: Sun Aug 14 21:41:16 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Sun Aug 14 21:41:16 2016 +0800

----------------------------------------------------------------------
 doc/docs/examples.rst           |   6 --
 doc/docs/index.rst              |   2 +-
 doc/docs/initializer.rst        |   2 +-
 examples/char-rnn/README.md     |   2 +-
 examples/char-rnn/train.py      | 103 +++++++++++++++++++++--------------
 examples/cifar10/alexnet.py     |  48 +++++++++++++---
 examples/cifar10/predict.py     |  10 ++--
 examples/cifar10/vgg.py         |  12 ++--
 examples/index.rst              |   4 ++
 src/python/singa/initializer.py |  85 ++++++++++++++---------------
 src/python/singa/optimizer.py   |   4 +-
 11 files changed, 157 insertions(+), 121 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/doc/docs/examples.rst
----------------------------------------------------------------------
diff --git a/doc/docs/examples.rst b/doc/docs/examples.rst
deleted file mode 100644
index b0b2af8..0000000
--- a/doc/docs/examples.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-Examples
-========
-
-.. toctree::
-
-   examples/index

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/doc/docs/index.rst
----------------------------------------------------------------------
diff --git a/doc/docs/index.rst b/doc/docs/index.rst
index 2294054..11f0ebb 100644
--- a/doc/docs/index.rst
+++ b/doc/docs/index.rst
@@ -12,4 +12,4 @@ English
    loss
    metric
    optimizer
-   examples
+   examples/index

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/doc/docs/initializer.rst
----------------------------------------------------------------------
diff --git a/doc/docs/initializer.rst b/doc/docs/initializer.rst
index a190702..f334497 100644
--- a/doc/docs/initializer.rst
+++ b/doc/docs/initializer.rst
@@ -5,7 +5,7 @@ Python API
 ----------
 
 .. automodule:: singa.initializer
-   :members:
+   :members: uniform, gaussian
    :member-order: bysource
 
 CPP API

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/examples/char-rnn/README.md
----------------------------------------------------------------------
diff --git a/examples/char-rnn/README.md b/examples/char-rnn/README.md
index d4cfa30..f6e5edc 100644
--- a/examples/char-rnn/README.md
+++ b/examples/char-rnn/README.md
@@ -1,4 +1,4 @@
-# Train Char-RNN using SINGA
+# Train Char-RNN over plain text
 
 Recurrent neural networks (RNN) are widely used for modelling sequential data, e.g., natural language sentences.
 This example describes how to implement an RNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/examples/char-rnn/train.py
----------------------------------------------------------------------
diff --git a/examples/char-rnn/train.py b/examples/char-rnn/train.py
index fb5e71f..1273a57 100644
--- a/examples/char-rnn/train.py
+++ b/examples/char-rnn/train.py
@@ -19,8 +19,6 @@ The model is created following https://github.com/karpathy/char-rnn
 The training file could be any text file, e.g.,
 http://cs.stanford.edu/people/karpathy/char-rnn/
 '''
-import sys
-import os
 import cPickle as pickle
 import numpy as np
 import argparse
@@ -32,12 +30,12 @@ from singa import device
 from singa import tensor
 from singa import optimizer
 from singa import initializer
-from singa.proto import core_pb2
 from singa.proto import model_pb2
 from singa import utils
 
 
 class Data(object):
+
     def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8):
         '''Data object for loading a plain text file.
 
@@ -49,8 +47,8 @@ class Data(object):
         self.raw_data = open(fpath, 'r').read()  # read text file
         chars = list(set(self.raw_data))
         self.vocab_size = len(chars)
-        self.char_to_idx = {ch:i for i, ch in enumerate(chars)}
-        self.idx_to_char = {i:ch for i, ch in enumerate(chars)}
+        self.char_to_idx = {ch: i for i, ch in enumerate(chars)}
+        self.idx_to_char = {i: ch for i, ch in enumerate(chars)}
         data = [self.char_to_idx[c] for c in self.raw_data]
         # seq_length + 1 for the data + label
         nsamples = len(data) / (1 + seq_length)
@@ -69,10 +67,10 @@ class Data(object):
 
 def numpy2tensors(npx, npy, dev):
     '''batch, seq, dim --> seq, batch, dim'''
-    tmpx=np.swapaxes(npx, 0, 1)
-    tmpy=np.swapaxes(npy, 0, 1)
-    inputs=[]
-    labels=[]
+    tmpx = np.swapaxes(npx, 0, 1)
+    tmpy = np.swapaxes(npy, 0, 1)
+    inputs = []
+    labels = []
     for t in range(tmpx.shape[0]):
         x = tensor.from_numpy(tmpx[t])
         y = tensor.from_numpy(tmpy[t])
@@ -99,25 +97,36 @@ def get_lr(epoch):
     return 0.001 / float(1 << (epoch / 50))
 
 
-def train(data, max_epoch, hidden_size =100, seq_length=100, batch_size=16,
-          num_stacks=1, lr=0.001, dropout = 0.5, model_path='model.bin'):
+def train(data, max_epoch, hidden_size=100, seq_length=100, batch_size=16,
+          num_stacks=1, lr=0.001, dropout=0.5, model_path='model.bin'):
     # SGD with L2 gradient normalization
     opt = optimizer.SGD(constraint=optimizer.L2Constraint(5))
     cuda = device.create_cuda_gpu()
-    rnn = layer.LSTM(name='lstm', hidden_size=hidden_size, num_stacks=num_stacks,
-                     dropout=dropout, input_sample_shape=(data.vocab_size,))
+    rnn = layer.LSTM(
+        name='lstm',
+        hidden_size=hidden_size,
+        num_stacks=num_stacks,
+        dropout=dropout,
+        input_sample_shape=(
+            data.vocab_size,
+        ))
     rnn.to_device(cuda)
     print 'created rnn'
     rnn_w = rnn.param_values()[0]
-    initializer.uniform(rnn_w, -0.08, 0.08)  # init all rnn parameters
+    rnn_w.uniform(-0.08, 0.08)  # init all rnn parameters
     print 'rnn weight l1 = %f' % (rnn_w.l1())
-    dense = layer.Dense('dense', data.vocab_size, input_sample_shape=(hidden_size,))
+    dense = layer.Dense(
+        'dense',
+        data.vocab_size,
+        input_sample_shape=(
+            hidden_size,
+        ))
     dense.to_device(cuda)
     dense_w = dense.param_values()[0]
     dense_b = dense.param_values()[1]
     print 'dense w ', dense_w.shape
     print 'dense b ', dense_b.shape
-    initializer.xavier(dense_w)  # init weight matrix using Xavier
+    initializer.uniform(dense_w, dense_w.shape[0], dense_w.shape[1])
    print 'dense weight l1 = %f' % (dense_w.l1())
     dense_b.set_value(0.0)
     print 'dense b l1 = %f' % (dense_b.l1())
@@ -125,18 +134,18 @@ def train(data, max_epoch, hidden_size =100, seq_length=100, batch_size=16,
     g_dense_w = tensor.Tensor(dense_w.shape, cuda)
     g_dense_b = tensor.Tensor(dense_b.shape, cuda)
-    lossfun = loss.SoftmaxCrossEntropy();
+    lossfun = loss.SoftmaxCrossEntropy()
     for epoch in range(max_epoch):
         train_loss = 0
         for b in range(data.num_train_batch):
             batch = data.train_dat[b * batch_size: (b + 1) * batch_size]
             inputs, labels = convert(batch, batch_size, seq_length,
-                             data.vocab_size, cuda)
+                                     data.vocab_size, cuda)
             inputs.append(tensor.Tensor())
             inputs.append(tensor.Tensor())
             outputs = rnn.forward(model_pb2.kTrain, inputs)[0:-2]
-            grads=[]
+            grads = []
             batch_loss = 0
             g_dense_w.set_value(0.0)
             g_dense_b.set_value(0.0)
@@ -149,52 +158,62 @@ def train(data, max_epoch, hidden_size =100, seq_length=100, batch_size=16,
                 grads.append(grad)
                 g_dense_w += gwb[0]
                 g_dense_b += gwb[1]
-            #print output.l1(), act.l1()
-            utils.update_progress(b * 1.0 / data.num_train_batch,
-                    'training loss = %f' % (batch_loss / seq_length))
+            # print output.l1(), act.l1()
+            utils.update_progress(
+                b * 1.0 / data.num_train_batch, 'training loss = %f' %
+                (batch_loss / seq_length))
             train_loss += batch_loss
 
             grads.append(tensor.Tensor())
             grads.append(tensor.Tensor())
-            g_rnn_w=rnn.backward(model_pb2.kTrain, grads)[1][0]
+            g_rnn_w = rnn.backward(model_pb2.kTrain, grads)[1][0]
             dense_w, dense_b = dense.param_values()
             opt.apply_with_lr(epoch, get_lr(epoch), g_rnn_w, rnn_w, 'rnnw')
-            opt.apply_with_lr(epoch, get_lr(epoch), g_dense_w, dense_w, 'dense_w')
-            opt.apply_with_lr(epoch, get_lr(epoch), g_dense_b, dense_b, 'dense_b')
-        print '\nEpoch %d, train loss is %f' % (epoch,
-                train_loss / data.num_train_batch / seq_length)
+            opt.apply_with_lr(
+                epoch, get_lr(epoch),
+                g_dense_w, dense_w, 'dense_w')
+            opt.apply_with_lr(
+                epoch, get_lr(epoch),
+                g_dense_b, dense_b, 'dense_b')
+        print '\nEpoch %d, train loss is %f' % \
+            (epoch, train_loss / data.num_train_batch / seq_length)
+
         eval_loss = 0
         for b in range(data.num_test_batch):
             batch = data.val_dat[b * batch_size: (b + 1) * batch_size]
             inputs, labels = convert(batch, batch_size, seq_length,
-                             data.vocab_size, cuda)
+                                     data.vocab_size, cuda)
             inputs.append(tensor.Tensor())
             inputs.append(tensor.Tensor())
             outputs = rnn.forward(model_pb2.kEval, inputs)[0:-2]
             for output, label in zip(outputs, labels):
                 output = dense.forward(model_pb2.kEval, output)
-                eval_loss += lossfun.forward(model_pb2.kEval, output, label).l1()
-        print 'Epoch %d, evaluation loss is %f' % (epoch,
-                eval_loss / data.num_test_batch / seq_length)
+                eval_loss += lossfun.forward(model_pb2.kEval,
+                                             output, label).l1()
+        print 'Epoch %d, evaluation loss is %f' % \
+            (epoch, eval_loss / data.num_test_batch / seq_length)
 
     # checkpoint the model
     with open(model_path, 'wb') as fd:
         print 'saving model to %s' % model_path
-        d={}
-        for name, w in zip(['rnn_w', 'dense_w', 'dense_b'], [rnn_w, dense_w, dense_b]):
+        d = {}
+        for name, w in zip(
+                ['rnn_w', 'dense_w', 'dense_b'],
+                [rnn_w, dense_w, dense_b]):
            w.to_host()
-            d[name]=tensor.to_numpy(w)
-        d['idx_to_char']=data.idx_to_char
-        d['char_to_idx']=data.char_to_idx
-        d['hidden_size']=hidden_size
-        d['num_stacks']=num_stacks
-        d['dropout']=dropout
+            d[name] = tensor.to_numpy(w)
+        d['idx_to_char'] = data.idx_to_char
+        d['char_to_idx'] = data.char_to_idx
+        d['hidden_size'] = hidden_size
+        d['num_stacks'] = num_stacks
+        d['dropout'] = dropout
         pickle.dump(d, fd)
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Train multi-stack LSTM for '\
-        'modeling character sequence from plain text files')
+    parser = argparse.ArgumentParser(
+        description='Train multi-stack LSTM for '
+        'modeling character sequence from plain text files')
     parser.add_argument('data', type=str, help='training file')
     parser.add_argument('-b', type=int, default=32, help='batch_size')
     parser.add_argument('-l', type=int, default=64, help='sequence length')
@@ -204,4 +223,4 @@ if __name__ == '__main__':
     args = parser.parse_args()
     data = Data(args.data, batch_size=args.b, seq_length=args.l)
     train(data, args.m, hidden_size=args.d, num_stacks=args.s,
-            seq_length=args.l, batch_size=args.b)
+          seq_length=args.l, batch_size=args.b)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/examples/cifar10/alexnet.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/alexnet.py b/examples/cifar10/alexnet.py
index ddad1d5..34da95d 100644
--- a/examples/cifar10/alexnet.py
+++ b/examples/cifar10/alexnet.py
@@ -20,12 +20,8 @@ Following the same setting for hyper-parameters and data pre-processing, the
 final validation accuracy would be about 82%.
 '''
-import sys
-import os
-
 # sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
 from singa import layer
-from singa import initializer
 from singa import metric
 from singa import loss
 from singa import net as ffnet
@@ -40,23 +36,57 @@ def create_net(use_cpu=False):
     W1_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.01}
     W2_specs = {'init': 'gaussian', 'mean': 0, 'std': 0.01, 'decay_mult': 250}
     b_specs = {'init': 'constant', 'value': 0, 'lt_mult': 2}
-    net.add(layer.Conv2D('conv1', 32, 5, 1, W_specs=W0_specs.copy(), b_specs=b_specs.copy(), pad=2, input_sample_shape=(3,32,32,)))
+    net.add(
+        layer.Conv2D(
+            'conv1',
+            32,
+            5,
+            1,
+            W_specs=W0_specs.copy(),
+            b_specs=b_specs.copy(),
+            pad=2,
+            input_sample_shape=(
+                3,
+                32,
+                32,
+            )))
     net.add(layer.MaxPooling2D('pool1', 3, 2, pad=1))
     net.add(layer.Activation('relu1'))
     net.add(layer.LRN(name='lrn1'))
-    net.add(layer.Conv2D('conv2', 32, 5, 1, W_specs=W1_specs.copy(), b_specs=b_specs.copy(), pad=2))
+    net.add(
+        layer.Conv2D(
+            'conv2',
+            32,
+            5,
+            1,
+            W_specs=W1_specs.copy(),
+            b_specs=b_specs.copy(),
+            pad=2))
     net.add(layer.Activation('relu2'))
     net.add(layer.MaxPooling2D('pool2', 3, 2, pad=1))
     net.add(layer.LRN('lrn2'))
-    net.add(layer.Conv2D('conv3', 64, 5, 1, W_specs=W1_specs.copy(), b_specs=b_specs.copy(), pad=2))
+    net.add(
+        layer.Conv2D(
+            'conv3',
+            64,
+            5,
+            1,
+            W_specs=W1_specs.copy(),
+            b_specs=b_specs.copy(),
+            pad=2))
     net.add(layer.Activation('relu3'))
     net.add(layer.MaxPooling2D('pool3', 3, 2, pad=1))
     net.add(layer.Flatten('flat'))
-    net.add(layer.Dense('dense', 10, W_specs=W2_specs.copy(), b_specs=b_specs.copy()))
+    net.add(
+        layer.Dense(
+            'dense',
+            10,
+            W_specs=W2_specs.copy(),
+            b_specs=b_specs.copy()))
     for (p, specs) in zip(net.param_values(), net.param_specs()):
         filler = specs.filler
         if filler.type == 'gaussian':
-            initializer.gaussian(p, filler.mean, filler.std)
+            p.gaussian(filler.mean, filler.std)
         else:
             p.set_value(0)
         print specs.name, filler.type, p.l1()

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/examples/cifar10/predict.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/predict.py b/examples/cifar10/predict.py
index 8a9ea4e..307a610 100644
--- a/examples/cifar10/predict.py
+++ b/examples/cifar10/predict.py
@@ -16,28 +16,26 @@
 # =============================================================================
 import cPickle as pickle
 import numpy as np
-import sys
-import os
-#sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
+# sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
 
 from singa import device
 from singa import tensor
 import net as ffnet
 
 
-def predict(net, images, cuda, topk=5):
+def predict(net, images, dev, topk=5):
     '''Predict the label of each image.
 
     Args:
         net, a pretrained neural net
         images, a batch of images [batch_size, 3, 32, 32], which have been pre-processed
-        cuda, the cuda device
+        dev, the training device
         topk, return the topk labels for each image.
     '''
     x = tensor.from_numpy(images.astype(np.float32))
-    x.to_device(cuda)
+    x.to_device(dev)
     y = net.predict(x)
     y.to_host()
     y = tensor.to_numpy(y)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/examples/cifar10/vgg.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/vgg.py b/examples/cifar10/vgg.py
index 327592f..29a4b40 100644
--- a/examples/cifar10/vgg.py
+++ b/examples/cifar10/vgg.py
@@ -20,11 +20,7 @@ The performance could be improved by tuning some hyper-parameters, including
 learning rate, weight decay, max_epoch, parameter initialization, etc.
 """
-import sys
-import os
-import math
-
-#sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
+# sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
 
 from singa import layer
 from singa import initializer
@@ -86,11 +82,11 @@ def create_net(use_cpu=False):
         elif 'var' in name:
             p.set_value(1.0)
         elif 'gamma' in name:
-            initializer.uniform(p, 0, 1)
+            p.uniform(0, 1)
         elif 'conv' in name:
-            initializer.gaussian(p, 0, math.sqrt(2.0/(9.0 * p.shape[0])))
+            initializer.gaussian(p, 0, 3 * 3 * p.shape[0])
         else:
-            initializer.gaussian(p, 0, 0.02)
+            p.gaussian(0, 0.02)
     else:
         p.set_value(0)
     print name, p.l1()

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/examples/index.rst
----------------------------------------------------------------------
diff --git a/examples/index.rst b/examples/index.rst
index d6faf5d..4bb5b49 100644
--- a/examples/index.rst
+++ b/examples/index.rst
@@ -1,5 +1,9 @@
+Examples
+========
+
 .. toctree::
 
+   cifar10/README
    char-rnn/README
    imagenet/README

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/src/python/singa/initializer.py
----------------------------------------------------------------------
diff --git a/src/python/singa/initializer.py b/src/python/singa/initializer.py
index 277fd2f..fb99663 100644
--- a/src/python/singa/initializer.py
+++ b/src/python/singa/initializer.py
@@ -23,77 +23,68 @@ Example usages::
 
     from singa import initializer
 
     x = tensor.Tensor((3, 5))
-    initializer.xavier(x)
+    initializer.uniform(x, 3, 5)  # use both fan_in and fan_out
+    initializer.uniform(x, 3, 0)  # use only fan_in
 '''
 
 import math
 
-'''
-TODO(wangwei) update the uniform and gaussian initializers
-
 def uniform(t, fan_in=0, fan_out=0):
-    typically, for conv layer weight: fan_in = nb_filter * kh * kw,
-    fan_out = nb_channel * kh * kw
-    for dense layer weight, fan_in = input_feature_length,
-    fan_out = output_feature_length
-    # Ref: [Bengio and Glorot 2010]: Understanding the difficulty of
+    '''Initialize the values of the input tensor following a uniform
+    distribution with specific bounds.
+
+    Args:
+        fan_in(int): for the weight Tensor of a convolution layer,
+            fan_in = nb_channel * kh * kw; for dense layer,
+            fan_in = input_feature_length
+        fan_out(int): for the convolution layer weight Tensor,
+            fan_out = nb_filter * kh * kw; for the weight Tensor of a dense
+            layer, fan_out = output_feature_length
+
+    Ref: [Bengio and Glorot 2010]: Understanding the difficulty of
     training deep feedforward neural networks.
-    assert fan_in >0 or fan_out > 0, \
+    '''
+    assert fan_in > 0 or fan_out > 0, \
         'fan_in and fan_out cannot be 0 at the same time'
-    avg = 1
+    avg = 2
     if fan_in * fan_out == 0:
-        avg = 2
-        x = math.sqrt(3.0f * avg / (fan_in + fan_out))
+        avg = 1
+    x = math.sqrt(3.0 * avg / (fan_in + fan_out))
     t.uniform(-x, x)
 
 
 def gaussian(t, fan_in=0, fan_out=0):
-    typically, for conv layer weight: fan_in = nb_filter * kh * kw,
-    fan_out = nb_channel * kh * kw
-    for dense layer weight, fan_in = input_feature_length,
-    fan_out = output_feature_length
+    '''Initialize the values of the input tensor following a Gaussian
+    distribution with specific std.
+
+    Args:
+        fan_in(int): for the weight Tensor of a convolution layer,
+            fan_in = nb_channel * kh * kw; for dense layer,
+            fan_in = input_feature_length
+        fan_out(int): for the convolution layer weight Tensor,
+            fan_out = nb_filter * kh * kw; for the weight Tensor of a dense
+            layer, fan_out = output_feature_length
 
     Ref: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: Delving Deep into
     Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
-
-    assert fan_in >0 or fan_out > 0, \
+    '''
+    assert fan_in > 0 or fan_out > 0, \
         'fan_in and fan_out cannot be 0 at the same time'
-    avg = 1
+    avg = 2
     if fan_in * fan_out == 0:
-        avg = 2
-        std = math.sqrt(2.0f * avg / (fan_in + fan_out))
+        avg = 1
+    std = math.sqrt(2.0 * avg / (fan_in + fan_out))
     t.gaussian(0, std)
-'''
-
-
-def uniform(t, low=0, high=1):
-    '''Initialize the parameter values following a uniform distribution.
-
-    Args:
-        t (Tensor): the parameter tensor
-        low (float): lower bound
-        high (float): higher bound
-    '''
-    t.uniform(low, high)
-
-
-def gaussian(t, mean=0, std=0.01):
-    '''Initialize the parameter values following a Gaussian distribution.
-
-    Args:
-        t (Tensor): the parameter tensor
-        mean (float): mean of the distribution
-        std (float): standard deviation
-    '''
-    t.gaussian(mean, std)
 
 
 def xavier(t):
     '''Initialize the matrix parameter following a uniform distribution from
     [-sqrt(6/(fan_in + fan_out)), sqrt(6/(fan_in + fan_out))].
 
+    Deprecated. Please use uniform()
+
     Args:
         t (Tensor): the parameter tensor
     '''
@@ -106,6 +97,8 @@ def glorot(t):
     '''Initialize the matrix parameter following a Gaussian distribution with
     mean = 0 and std = sqrt(2.0 / (nb_row + nb_col))
 
+    Deprecated. Please use gaussian()
+
     Args:
         t (Tensor): the parameter tensor
     '''
@@ -118,6 +111,8 @@ def msra(t):
     '''Initialize the matrix parameter following a Gaussian distribution with
     mean = 0, std = math.sqrt(2.0 / nb_row).
 
+    Deprecated. Please use gaussian()
+
     Ref [He, Zhang, Ren and Sun 2015]: Specifically accounts for ReLU
     nonlinearities.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3a57cfc/src/python/singa/optimizer.py
----------------------------------------------------------------------
diff --git a/src/python/singa/optimizer.py b/src/python/singa/optimizer.py
index 5d38997..7c8cc39 100644
--- a/src/python/singa/optimizer.py
+++ b/src/python/singa/optimizer.py
@@ -44,8 +44,8 @@ class Optimizer(object):
 
     1. construct the optimizer
     2. (optional) register each parameter with its specs.
-    3. use the optimizer to update parameter values given parameter
-       gradients and other optional info
+    3. use the optimizer to update parameter values given parameter gradients
+       and other optional info
 
     The subclasses should override the apply_with_lr function to do the real
     parameter update.
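For reference, a minimal sketch of this three-step usage (the tensor shape,
learning rate, and parameter name are illustrative; step 2 is skipped since it
is optional, and apply_with_lr is called the same way as in
examples/char-rnn/train.py above):

    from singa import tensor
    from singa import optimizer

    opt = optimizer.SGD()          # step 1: construct the optimizer
    w = tensor.Tensor((3, 5))      # a parameter tensor
    w.gaussian(0, 0.01)
    g = tensor.Tensor((3, 5))      # its gradient, e.g. computed by backward()
    g.set_value(0.1)
    # step 3: update w in place given g; args are epoch, lr, grad, value, name
    opt.apply_with_lr(0, 0.01, g, w, 'w')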
