SINGA-385 Add new python module for optimizers

Add the base optimizer and SGD (with momentum).
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/117dfcfd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/117dfcfd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/117dfcfd

Branch: refs/heads/master
Commit: 117dfcfd052bb92142a30b59fc173a2ef6480332
Parents: 2b5c3f7
Author: Wang Wei <[email protected]>
Authored: Sat Jul 14 13:07:52 2018 +0800
Committer: Wang Wei <[email protected]>
Committed: Mon Jul 16 10:04:54 2018 +0800

----------------------------------------------------------------------
 examples/autograd/resnet.py | 117 ++++++++++++++++++++++++++++--
 python/singa/autograd.py    |  13 ++++
 python/singa/opt.py         | 152 +++++++++++++++++++++++++++++++++++++++
 python/singa/tensor.py      |  12 ++++
 4 files changed, 287 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/examples/autograd/resnet.py
----------------------------------------------------------------------
diff --git a/examples/autograd/resnet.py b/examples/autograd/resnet.py
index 930d9e0..f1fb9d6 100644
--- a/examples/autograd/resnet.py
+++ b/examples/autograd/resnet.py
@@ -23,6 +23,10 @@
 from singa import autograd
 from singa import tensor
 from singa import device
+from singa import utils
+from singa import optimizer
+
+import numpy as np
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -60,7 +64,7 @@ class BasicBlock(autograd.Layer):
 
         if self.downsample is not None:
             residual = self.downsample(x)
 
-        out += residual
+        out = autograd.add(out, residual)
         out = autograd.relu(out)
 
         return out
@@ -101,7 +105,7 @@ class Bottleneck(autograd.Layer):
         if self.downsample is not None:
             residual = self.downsample(x)
 
-        out += residual
+        out = autograd.add(out, residual)
         out = autograd.relu(out)
 
         return out
@@ -217,10 +221,109 @@ def resnet152(pretrained=False, **kwargs):
     return model
 
 
-if __name__ == '__main__':
+def load_dataset(filepath):
+    print('Loading data file %s' % filepath)
+    with open(filepath, 'rb') as fd:
+        try:
+            cifar10 = pickle.load(fd, encoding='latin1')
+        except TypeError:
+            cifar10 = pickle.load(fd)
+    image = cifar10['data'].astype(dtype=np.uint8)
+    image = image.reshape((-1, 3, 32, 32))
+    label = np.asarray(cifar10['labels'], dtype=np.uint8)
+    label = label.reshape(label.size, 1)
+    return image, label
+
+
+def load_train_data(dir_path, num_batches=5):
+    labels = []
+    batchsize = 10000
+    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
+    for did in range(1, num_batches + 1):
+        fname_train_data = dir_path + "/data_batch_{}".format(did)
+        image, label = load_dataset(fname_train_data)
+        images[(did - 1) * batchsize:did * batchsize] = image
+        labels.extend(label)
+    images = np.array(images, dtype=np.float32)
+    labels = np.array(labels, dtype=np.int32)
+    return images, labels
+
+
+def load_test_data(dir_path):
+    images, labels = load_dataset(dir_path + "/test_batch")
+    return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32)
+
+
+def accuracy(pred, target):
+    y = np.argmax(pred, axis=1)
+    t = np.argmax(target, axis=1)
+    a = y == t
+    return np.array(a, 'int').sum() / float(len(t))
+
+
+def train(data, net, max_epoch, get_lr, weight_decay=1e-5, batch_size=100):
+    print('Start intialization............')
+    dev = device.create_cuda_gpu()
+
+    opt = optimizer.SGD(momentum=0.9, weight_decay=weight_decay)
+
+    tx = tensor.Tensor((batch_size, 3, 32, 32), dev)
+    ty = tensor.Tensor((batch_size,), dev, tensor.int32)
+    train_x, train_y, test_x, test_y = data
+    num_train_batch = train_x.shape[0] // batch_size
+    num_test_batch = test_x.shape[0] // batch_size
+    idx = np.arange(train_x.shape[0], dtype=np.int32)
+    for epoch in range(max_epoch):
+        np.random.shuffle(idx)
+        loss, acc = 0.0, 0.0
+        print('Epoch %d' % epoch)
+        autograd.training = True
+        for b in range(num_train_batch):
+            x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
+            y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
+            tx.copy_from_numpy(x)
+            ty.copy_from_numpy(y)
+            x = net(tx)
+            loss = autograd.softmax_cross_entropy(x, ty)
+            np_loss = tensor.to_numpy(loss)
+            acc += accuracy(tensor.to_numpy(x), y)
+
+            for p, g in autograd.backwards(loss):
+                opt.apply_with_lr(epoch, get_lr(epoch), g, p)
+            # update progress bar
+            utils.update_progress(b * 1.0 / num_train_batch,
+                                  'training loss = %f' % (np_loss[0]))
+
+        loss, acc = 0.0, 0.0
+        autograd.training = True
+        for b in range(num_test_batch):
+            x = test_x[b * batch_size: (b + 1) * batch_size]
+            y = test_y[b * batch_size: (b + 1) * batch_size]
+            tx.copy_from_numpy(x)
+            ty.copy_from_numpy(y)
+            x = net(tx)
+            l = autograd.softmax_cross_entropy(x, ty)
+            loss += tensor.to_numpy(l)[0]
+            acc += accuracy(x, y)
+
+        print('test loss = %f, test accuracy = %f' %
+              ((loss / num_test_batch), (acc / num_test_batch)))
+
+
+def resnet_lr(epoch):
+    if epoch < 81:
+        return 0.1
+    elif epoch < 122:
+        return 0.01
+    else:
+        return 0.001
+
+
+if __name__ == '__main__':
     model = resnet18()
-    x = tensor.Tensor((16, 3, 224, 224), device.create_cuda_gpu())
-    x.set_value(float(0.1))
-    autograd.training = True
-    y = model(x)
+    train_x, train_y = load_train_data()
+    test_x, test_y = load_test_data()
+    mean = np.average(train_x, axis=0)
+    train_x -= mean
+    test_x -= mean
+    train(model, (train_x, train_y, test_x, test_y), 10, resnet_lr)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index c77c174..63e3771 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -347,6 +347,19 @@ def add_bias(x, b, axis=0):
     return AddBias(axis)(x, b)[0]
 
 
+class Add(Operation):
+
+    def forward(self, a, b):
+        return a + b
+
+    def backward(self, dy):
+        return dy, dy
+
+
+def add(a, b):
+    return Add()(a, b)[0]
+
+
 class SoftMax(Operation):
     '''
     Apply SoftMax for each row of the Tensor or each column of the Tensor

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/python/singa/opt.py
----------------------------------------------------------------------
diff --git a/python/singa/opt.py b/python/singa/opt.py
new file mode 100644
index 0000000..bf04b09
--- /dev/null
+++ b/python/singa/opt.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+'''This module includes a set of optimizers for updating model parameters.
+It replaces the old optimizers from optimizer.py'''
+
+from singa import tensor
+
+
+class Optimizer(object):
+    r"""Base optimizer.
+
+    Args:
+        config (Dict): specify the default values of configurable variables.
+    """
+
+    def __init__(self, config):
+        self.config = config
+        self.step = 0
+        self.param2config = {}
+
+    def update(self, param, grad):
+        r"""Update the param values with given gradients.
+
+        Args:
+            param(Tensor): param values to be updated in-place
+            grad(Tensor): param gradients; the values may be updated
+                in this function; do not use it anymore
+        """
+        pass
+
+    def step(self):
+        r"""To increment the step counter"""
+        self.step += 1
+
+    def register(self, param_group, config):
+        for param in param_group:
+            assert param not in self.param2config, 'param is already registered'
+
+            self.param2config[param] = config
+
+    def load(self):
+        pass
+
+    def save(self):
+        pass
+
+
+class SGD(Optimizer):
+    r"""Implements stochastic gradient descent (optionally with momentum).
+
+    Nesterov momentum is based on the formula from
+    `On the importance of initialization and momentum in deep learning`__.
+
+    Args:
+        lr(float): learning rate
+        momentum(float, optional): momentum factor(default: 0)
+        weight_decay(float, optional): weight decay(L2 penalty)(default: 0)
+        dampening(float, optional): dampening for momentum(default: 0)
+        nesterov(bool, optional): enables Nesterov momentum(default: False)
+
+    Example:
+        >>> from singa import opt
+        >>> optimizer = opt.SGD(lr=0.1, momentum=0.9)
+        >>> optimizer.update()
+
+    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
+
+    .. note::
+        The implementation of SGD with Momentum/Nesterov subtly differs from
+        Sutskever et. al. and implementations in some other frameworks.
+
+        Considering the specific case of Momentum, the update can be written as
+
+        .. math::
+            v = \rho * v + g \\
+            p = p - lr * v
+
+        where p, g, v and :math:`\rho` denote the parameters, gradient,
+        velocity, and momentum respectively.
+
+        This is in contrast to Sutskever et. al. and
+        other frameworks which employ an update of the form
+
+        .. math::
+            v = \rho * v + lr * g \\
+            p = p - v
+
+        The Nesterov version is analogously modified.
+    """
+
+    def __init__(self, lr=0.1, momentum=0, dampening=0,
+                 weight_decay=0, nesterov=False):
+        if momentum < 0.0:
+            raise ValueError("Invalid momentum value: {}".format(momentum))
+        if weight_decay < 0.0:
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay))
+
+        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
+                        weight_decay=weight_decay, nesterov=nesterov)
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening")
+        super(SGD, self).__init__(defaults)
+
+    def update(self, param, grad):
+        """Performs a single optimization step.
+
+        Arguments:
+            param(Tensor): param values to be update in-place
+            grad(Tensor): param gradients; the values may be updated
+                in this function; cannot use it anymore
+        """
+        group = self.param2group[param]
+        weight_decay = group['weight_decay']
+        momentum = group['momentum']
+        dampening = group['dampening']
+        nesterov = group['nesterov']
+
+        if weight_decay != 0:
+            grad += param * weight_decay
+        if momentum != 0:
+            param_state = self.state[param]
+            if 'momentum_buffer' not in param_state:
+                buf = param_state[
+                    'momentum_buffer'] = tensor.zeros_like(param)
+                buf *= momentum
+                buf += grad
+            else:
+                buf = param_state['momentum_buffer']
+                buf *= momentum
+                buf += (1 - dampening) * grad
+            if nesterov:
+                grad += momentum * buf
+            else:
+                grad = buf
+        param -= grad * group['lr']

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/117dfcfd/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 46a47b7..441431f 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -602,6 +602,18 @@ def from_raw_tensors(tt):
     return ret
 
 
+def zeros_like(t):
+    ret = Tensor(t.shape, t.device, t.dtype)
+    ret.set_value(float(0))
+    return ret
+
+
+def ones_like(t):
+    ret = Tensor(t.shape, t.device, t.dtype)
+    ret.set_value(float(1))
+    return ret
+
+
 def product(shape):
     return reduce(lambda x, y: x * y, shape)
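
As a companion to the ``.. note::`` block in the new SGD docstring above, here is a minimal NumPy sketch of one step of the update rule that SGD.update encodes (weight decay folded into the gradient, a dampened momentum buffer, optional Nesterov, then p = p - lr * v). The function name sgd_momentum_step and the toy numbers are illustrative only, not code from this commit, and the sketch applies dampening from the very first step, unlike the committed buffer initialization.

    import numpy as np

    def sgd_momentum_step(p, g, buf, lr=0.1, momentum=0.9, dampening=0.0,
                          weight_decay=0.0, nesterov=False):
        # One step of: v = momentum * v + (1 - dampening) * g;  p = p - lr * v
        if weight_decay != 0:
            g = g + weight_decay * p      # L2 penalty folded into the gradient
        if momentum != 0:
            buf = momentum * buf + (1 - dampening) * g
            if nesterov:
                g = g + momentum * buf
            else:
                g = buf
        p = p - lr * g
        return p, buf

    # Toy run: one parameter vector, a constant gradient, three steps.
    p = np.array([1.0, -2.0])
    g = np.array([0.5, 0.5])
    buf = np.zeros_like(p)
    for _ in range(3):
        p, buf = sgd_momentum_step(p, g, buf, lr=0.1, momentum=0.9)
    print(p, buf)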

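The same docstring note contrasts this convention (v = rho * v + g, p = p - lr * v) with the Sutskever et al. form (v = rho * v + lr * g, p = p - v). A small self-contained check with made-up gradients, again only a sketch and not part of the commit, confirms that for a constant learning rate the two conventions trace the same parameter trajectory, since their buffers differ only by the factor lr.

    import numpy as np

    lr, rho = 0.1, 0.9
    grads = [np.array([1.0, -0.5]), np.array([0.2, 0.3]), np.array([-0.4, 0.1])]

    # Convention used by this module: v = rho * v + g;  p = p - lr * v
    p1, v = np.zeros(2), np.zeros(2)
    for g in grads:
        v = rho * v + g
        p1 = p1 - lr * v

    # Sutskever et al. convention: u = rho * u + lr * g;  p = p - u
    p2, u = np.zeros(2), np.zeros(2)
    for g in grads:
        u = rho * u + lr * g
        p2 = p2 - u

    # With a constant lr, u stays equal to lr * v, so both trajectories match.
    assert np.allclose(p1, p2)
    print(p1, p2)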