Repository: incubator-singa Updated Branches: refs/heads/dev db5478efa -> 28678ae83
SINGA-231 Batch-normalized VGG model for cifar-10 In this ticket, we implemented a batch-normalized VGG model for the cifar10 dataset (refer to http://torch.ch/blog/2015/07/30/cifar.html). * +vgg-parallel.cc for parallel training * +vgg.py defining the model in python * fix a bug in the ResetLike() method in tensor.h, which previously did not reset the shape. * fix a bug in local_updater.cc, which could cause a race condition when multiple threads initialize mutexes concurrently. * revise the batch normalization layer to support 2D tensor input Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/bc3b74b3 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/bc3b74b3 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/bc3b74b3 Branch: refs/heads/dev Commit: bc3b74b3662230f867c42344f0600498368f4785 Parents: db5478e Author: WANG Ji <[email protected]> Authored: Sat Aug 6 17:36:28 2016 +0800 Committer: WANG Ji <[email protected]> Committed: Mon Aug 8 11:44:01 2016 +0800 ---------------------------------------------------------------------- examples/cifar10/CMakeLists.txt | 5 + examples/cifar10/train_vgg_cifar10.py | 162 ++++++++++++++ examples/cifar10/vgg-parallel.cc | 333 +++++++++++++++++++++++++++++ examples/cifar10/vgg.py | 52 +++++ src/core/tensor/tensor.cc | 2 +- src/model/layer/batchnorm.cc | 25 ++- src/model/layer/batchnorm.h | 3 +- src/model/layer/cudnn_batchnorm.cc | 31 ++- src/model/updater/local_updater.cc | 1 + src/python/singa/layer.py | 10 +- src/python/singa/net.py | 6 +- 11 files changed, 613 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/examples/cifar10/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/examples/cifar10/CMakeLists.txt b/examples/cifar10/CMakeLists.txt index 92f884c..76c0b73 100644 --- a/examples/cifar10/CMakeLists.txt +++ b/examples/cifar10/CMakeLists.txt @@ -10,4 +10,9 @@ ADD_EXECUTABLE(alexnet-parallel alexnet-parallel.cc) ADD_DEPENDENCIES(alexnet-parallel singa_core singa_model singa_utils) TARGET_LINK_LIBRARIES(alexnet-parallel singa_core singa_utils singa_model protobuf ${SINGA_LIBKER_LIBS}) SET_TARGET_PROPERTIES(alexnet-parallel PROPERTIES LINK_FLAGS "${LINK_FLAGS} -pthread") + +ADD_EXECUTABLE(vgg-parallel vgg-parallel.cc) +ADD_DEPENDENCIES(vgg-parallel singa_core singa_model singa_utils) +TARGET_LINK_LIBRARIES(vgg-parallel singa_core singa_utils singa_model protobuf ${SINGA_LIBKER_LIBS}) +SET_TARGET_PROPERTIES(vgg-parallel PROPERTIES LINK_FLAGS "${LINK_FLAGS} -pthread") ENDIF(USE_CUDNN) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/examples/cifar10/train_vgg_cifar10.py ---------------------------------------------------------------------- diff --git a/examples/cifar10/train_vgg_cifar10.py b/examples/cifar10/train_vgg_cifar10.py new file mode 100644 index 0000000..e9df04e --- /dev/null +++ b/examples/cifar10/train_vgg_cifar10.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +""" CIFAR10 dataset is at https://www.cs.toronto.edu/~kriz/cifar.html. +It includes 5 binary dataset, each contains 10000 images. 1 row (1 image) +includes 1 label & 3072 pixels. 3072 pixels are 3 channels of a 32x32 image +""" + +import cPickle +import numpy as np +import os +import sys +import math + +sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python')) +from singa import initializer +from singa import utils +from singa import optimizer +from singa import device +from singa import tensor +from singa.proto import core_pb2 + +import vgg + + +def load_dataset(filepath): + print 'Loading data file %s' % filepath + with open(filepath, 'rb') as fd: + cifar10 = cPickle.load(fd) + image = cifar10['data'].astype(dtype=np.uint8) + image = image.reshape((-1, 3, 32, 32)) + label = np.asarray(cifar10['labels'], dtype=np.uint8) + label = label.reshape(label.size, 1) + return image, label + + +def load_train_data(dir_path, num_batches=5): + labels = [] + batchsize = 10000 + images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8) + for did in range(1, num_batches + 1): + fname_train_data = dir_path + "/data_batch_{}".format(did) + image, label = load_dataset(fname_train_data) + images[(did - 1) * batchsize:did * batchsize] = image + labels.extend(label) + images = np.array(images, dtype=np.float32) + labels = np.array(labels, dtype=np.int32) + return images, labels + + +def load_test_data(dir_path): + images, labels = load_dataset(dir_path + "/test_batch") + return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32) + + +def get_lr(epoch): + return 0.01 / float(1 << ((epoch / 30))) + #if epoch < 100: + # return 0.01 + #elif epoch < 150: + # return 0.005 + #elif epoch < 200: + # return 0.001 + #elif epoch < 250: + # return 0.0001 + + +def train(data_dir, net, num_epoch=250, batch_size=128): + print 'Creating Device............' + cuda = device.create_cuda_gpus(2)[1] + net.to_device(cuda) + print 'Start intialization............' + opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005) + for (p, name) in zip(net.param_values(), net.param_names()): + print name, p.shape + if len(p.shape) > 1: + if 'mean' in name or 'beta' in name: + p.set_value(0.0) + elif 'var' in name: + p.set_value(1.0) + elif 'gamma' in name: + initializer.uniform(p, 0, 1) + elif 'conv' in name: + initializer.gaussian(p, 0, math.sqrt(2.0/(9.0 * p.shape[0]))) + else: + initializer.gaussian(p, 0, 0.02) + + #stdv = 1.0/math.sqrt(p.shape[1]) + #initializer.uniform(p, -stdv, stdv) + else: + p.set_value(0) + #print specs.name, filler.type, p.l1() + print name, p.l1() + print 'Loading data ..................' 
+ train_x, train_y = load_train_data(data_dir) + test_x, test_y = load_test_data(data_dir) + mean = train_x.mean() + std = train_x.std() + train_x -= mean + test_x -= mean + train_x /= std + test_x /= std + + tx = tensor.Tensor((batch_size, 3, 32, 32), cuda) + ty = tensor.Tensor((batch_size,), cuda, core_pb2.kInt) + num_train_batch = train_x.shape[0] / batch_size + num_test_batch = test_x.shape[0] / batch_size + idx = np.arange(train_x.shape[0], dtype=np.int32) + for epoch in range(num_epoch): + np.random.shuffle(idx) + loss, acc = 0.0, 0.0 + print 'Epoch %d' % epoch + for b in range(num_train_batch): + x = train_x[idx[b * batch_size: (b + 1) * batch_size]] + y = train_y[idx[b * batch_size: (b + 1) * batch_size]] + tx.copy_from_numpy(x) + ty.copy_from_numpy(y) + grads, (l, a) = net.train(tx, ty) + loss += l + acc += a + for (s, p, g) in zip(net.param_specs(), net.param_values(), grads): + opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s.name)) + # update progress bar + utils.update_progress(b * 1.0 / num_train_batch, + 'training loss = %f, accuracy = %f' % (l, a)) + info = '\ntraining loss = %f, training accuracy = %f' \ + % (loss / num_train_batch, acc / num_train_batch) + print info + + loss, acc = 0.0, 0.0 + for b in range(num_test_batch): + x = test_x[b * batch_size: (b + 1) * batch_size] + y = test_y[b * batch_size: (b + 1) * batch_size] + tx.copy_from_numpy(x) + ty.copy_from_numpy(y) + l, a = net.evaluate(tx, ty) + loss += l + acc += a + + print 'test loss = %f, test accuracy = %f' \ + % (loss / num_test_batch, acc / num_test_batch) + net.save('model.bin') # save model params into checkpoint file + +if __name__ == '__main__': + data_dir = 'cifar-10-batches-py' + assert os.path.exists(data_dir), \ + 'Pls download the cifar10 dataset via "download_data.py py"' + net = vgg.create_net() + train(data_dir, net) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/examples/cifar10/vgg-parallel.cc ---------------------------------------------------------------------- diff --git a/examples/cifar10/vgg-parallel.cc b/examples/cifar10/vgg-parallel.cc new file mode 100644 index 0000000..ba308e9 --- /dev/null +++ b/examples/cifar10/vgg-parallel.cc @@ -0,0 +1,333 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ + +#include "cifar10.h" +#include "singa/model/feed_forward_net.h" +#include "singa/model/optimizer.h" +#include "singa/model/updater.h" +#include "singa/model/initializer.h" +#include "singa/model/metric.h" +#include "singa/utils/channel.h" +#include "singa/utils/string.h" +#include "singa/core/memory.h" +#include "../../src/model/layer/cudnn_convolution.h" +#include "../../src/model/layer/cudnn_activation.h" +#include "../../src/model/layer/cudnn_pooling.h" +#include "../../src/model/layer/cudnn_lrn.h" +#include "../../src/model/layer/cudnn_dropout.h" +#include "../../src/model/layer/cudnn_batchnorm.h" +#include "../../src/model/layer/dense.h" +#include "../../src/model/layer/flatten.h" +#include <thread> +#include <memory> +#include <cmath> + +namespace singa { + +const float default_wd = 0.0005f; + +LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride, + int pad, float std = .02f, float bias = .0f) { + LayerConf conf; + conf.set_name(name); + conf.set_type("CudnnConvolution"); + ConvolutionConf *conv = conf.mutable_convolution_conf(); + conv->set_num_output(nb_filter); + conv->add_kernel_size(kernel); + conv->add_stride(stride); + conv->add_pad(pad); + conv->set_bias_term(true); + + ParamSpec *wspec = conf.add_param(); + wspec->set_name(name + "_weight"); + auto wfill = wspec->mutable_filler(); + wfill->set_type("Gaussian"); + wfill->set_std(sqrt(2.0f/(nb_filter*9.0f))); + + ParamSpec *bspec = conf.add_param(); + bspec->set_name(name + "_bias"); + auto bfill = bspec->mutable_filler(); + bfill->set_value(bias); + // bspec->set_lr_mult(2); + // bspec->set_decay_mult(0); + return conf; +} + +LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride, + int pad) { + LayerConf conf; + conf.set_name(name); + conf.set_type("CudnnPooling"); + PoolingConf *pool = conf.mutable_pooling_conf(); + pool->set_kernel_size(kernel); + pool->set_stride(stride); + pool->set_pad(pad); + if (!max_pool) pool->set_pool(PoolingConf_PoolMethod_AVE); + return conf; +} + +LayerConf GenReLUConf(string name) { + LayerConf conf; + conf.set_name(name); + conf.set_type("RELU"); + return conf; +} + +LayerConf GenDenseConf(string name, int num_output, float std, float wd = default_wd) { + LayerConf conf; + conf.set_name(name); + conf.set_type("Dense"); + DenseConf *dense = conf.mutable_dense_conf(); + dense->set_num_output(num_output); + + ParamSpec *wspec = conf.add_param(); + wspec->set_name(name + "_weight"); + wspec->set_decay_mult(wd); + auto wfill = wspec->mutable_filler(); + wfill->set_type("Gaussian"); + wfill->set_std(std); + + ParamSpec *bspec = conf.add_param(); + bspec->set_name(name + "_bias"); + bspec->set_lr_mult(2); + bspec->set_decay_mult(0); + + return conf; +} + +LayerConf GenFlattenConf(string name) { + LayerConf conf; + conf.set_name(name); + conf.set_type("Flatten"); + return conf; +} + +LayerConf GenBatchNormConf(string name) { + LayerConf conf; + conf.set_name(name); + conf.set_type("CudnnBatchNorm"); + ParamSpec *gammaspec = conf.add_param(); + gammaspec->set_name(name + "_gamma"); + auto gammafill = gammaspec->mutable_filler(); + gammafill->set_type("uniform"); + gammafill->set_min(0); + gammafill->set_max(1); + + ParamSpec *betaspec = conf.add_param(); + betaspec->set_name(name + "_beta"); + auto betafill = betaspec->mutable_filler(); + betafill->set_type("constant"); + betafill->set_value(0); + + ParamSpec *meanspec = conf.add_param(); + meanspec->set_name(name + "_mean"); + auto meanfill = 
meanspec->mutable_filler(); + meanfill->set_type("constant"); + meanfill->set_value(0); + + ParamSpec *varspec = conf.add_param(); + varspec->set_name(name + "_var"); + auto varfill = varspec->mutable_filler(); + varfill->set_type("constant"); + varfill->set_value(1); + + return conf; +} + +LayerConf GenDropoutConf(string name, float dropout_ratio) { + LayerConf conf; + conf.set_name(name); + conf.set_type("CudnnDropout"); + DropoutConf *dropout = conf.mutable_dropout_conf(); + dropout->set_dropout_ratio(dropout_ratio); + + return conf; +} + +void ConvBNReLU(FeedForwardNet& net, string name, int nb_filter, Shape* shape = nullptr) { + net.Add(new CudnnConvolution(), GenConvConf(name+"_conv", nb_filter, 3, 1, 1), shape); + net.Add(new CudnnBatchNorm(), GenBatchNormConf(name+"_bn")); + net.Add(new CudnnActivation(), GenReLUConf(name+"_relu")); +} + +FeedForwardNet CreateNet() { + FeedForwardNet net; + Shape s{3, 32, 32}; + ConvBNReLU(net, "conv1_1", 64, &s); + net.Add(new CudnnDropout(), GenDropoutConf("drop1", 0.3)); + ConvBNReLU(net, "conv1_2", 64); + net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 2, 2, 0)); + ConvBNReLU(net, "conv2_1", 128); + net.Add(new CudnnDropout(), GenDropoutConf("drop2", 0.4)); + ConvBNReLU(net, "conv2_2", 128); + net.Add(new CudnnPooling(), GenPoolingConf("pool2", true, 2, 2, 0)); + ConvBNReLU(net, "conv3_1", 256); + net.Add(new CudnnDropout(), GenDropoutConf("drop3_1", 0.4)); + ConvBNReLU(net, "conv3_2", 256); + net.Add(new CudnnDropout(), GenDropoutConf("drop3_2", 0.4)); + ConvBNReLU(net, "conv3_3", 256); + net.Add(new CudnnPooling(), GenPoolingConf("pool3", true, 2, 2, 0)); + ConvBNReLU(net, "conv4_1", 512); + net.Add(new CudnnDropout(), GenDropoutConf("drop4_1", 0.4)); + ConvBNReLU(net, "conv4_2", 512); + net.Add(new CudnnDropout(), GenDropoutConf("drop4_2", 0.4)); + ConvBNReLU(net, "conv4_3", 512); + net.Add(new CudnnPooling(), GenPoolingConf("pool4", true, 2, 2, 0)); + ConvBNReLU(net, "conv5_1", 512); + net.Add(new CudnnDropout(), GenDropoutConf("drop5_1", 0.4)); + ConvBNReLU(net, "conv5_2", 512); + net.Add(new CudnnDropout(), GenDropoutConf("drop5_2", 0.4)); + ConvBNReLU(net, "conv5_3", 512); + net.Add(new CudnnPooling(), GenPoolingConf("pool5", true, 2, 2, 0)); + net.Add(new Flatten(), GenFlattenConf("flat")); + net.Add(new CudnnDropout(), GenDropoutConf("flat_drop", 0.5)); + net.Add(new Dense(), GenDenseConf("ip1", 512, 0.02)); + net.Add(new CudnnBatchNorm(), GenBatchNormConf("ip1_bn")); + net.Add(new CudnnActivation(), GenReLUConf("ip1_relu")); + net.Add(new CudnnDropout(), GenDropoutConf("ip1_drop", 0.5)); + net.Add(new Dense(), GenDenseConf("ip2", 10, 0.02)); + + return net; +} + +void Train(float lr, int num_epoch, string data_dir) { + Cifar10 data(data_dir); + Tensor train_x, train_y, test_x, test_y; + Tensor train_x_1, train_x_2, train_y_1, train_y_2; + { + auto train = data.ReadTrainData(); + size_t nsamples = train.first.shape(0); + auto mtrain = + Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples}); + const Tensor &mean = Average(mtrain, 0); + SubRow(mean, &mtrain); + Tensor std = Square(mtrain); + std = Average(std, 0); + std = Sqrt(std);; + std += 1e-6f; + DivRow(std, &mtrain); + + train_x = Reshape(mtrain, train.first.shape()); + train_y = train.second; + + LOG(INFO) << "Slicing training data..."; + train_x_1.Reshape(Shape{nsamples / 2, train.first.shape(1), + train.first.shape(2), train.first.shape(3)}); + LOG(INFO) << "Copying first data slice..."; + CopyDataToFrom(&train_x_1, train_x, train_x.Size() / 2); + 
train_x_2.Reshape(Shape{nsamples / 2, train.first.shape(1), + train.first.shape(2), train.first.shape(3)}); + LOG(INFO) << "Copying second data slice..."; + CopyDataToFrom(&train_x_2, train_x, train_x.Size() / 2, 0, + train_x.Size() / 2); + train_y_1.Reshape(Shape{nsamples / 2}); + train_y_1.AsType(kInt); + LOG(INFO) << "Copying first label slice..."; + CopyDataToFrom(&train_y_1, train_y, train_y.Size() / 2); + train_y_2.Reshape(Shape{nsamples / 2}); + train_y_2.AsType(kInt); + LOG(INFO) << "Copying second label slice..."; + CopyDataToFrom(&train_y_2, train_y, train_y.Size() / 2, 0, + train_y.Size() / 2); + + auto test = data.ReadTestData(); + nsamples = test.first.shape(0); + auto mtest = + Reshape(test.first, Shape{nsamples, test.first.Size() / nsamples}); + SubRow(mean, &mtest); + DivRow(std, &mtest); + test_x = Reshape(mtest, test.first.shape()); + test_y = test.second; + } + + CHECK_EQ(train_x.shape(0), train_y.shape(0)); + CHECK_EQ(test_x.shape(0), test_y.shape(0)); + LOG(INFO) << "Total Training samples = " << train_y.shape(0) + << ", Total Test samples = " << test_y.shape(0); + CHECK_EQ(train_x_1.shape(0), train_y_1.shape(0)); + LOG(INFO) << "On net 1, Training samples = " << train_y_1.shape(0) + << ", Test samples = " << test_y.shape(0); + CHECK_EQ(train_x_2.shape(0), train_y_2.shape(0)); + LOG(INFO) << "On net 2, Training samples = " << train_y_2.shape(0); + + auto net_1 = CreateNet(); + auto net_2 = CreateNet(); + + SGD sgd; + OptimizerConf opt_conf; + opt_conf.set_momentum(0.9); + auto reg = opt_conf.mutable_regularizer(); + reg->set_coefficient(0.0005); + sgd.Setup(opt_conf); + sgd.SetLearningRateGenerator([lr](int epoch) { + return 0.01f / static_cast<float>(1u << (epoch/30)); + }); + + SoftmaxCrossEntropy loss_1, loss_2; + Accuracy acc_1, acc_2; + /// Create updater aggregating gradient on CPU + std::shared_ptr<Updater> updater = std::make_shared<LocalUpdater>(2, &sgd); + + /// Only need to register parameter once. 
+ net_1.Compile(true, true, updater, &loss_1, &acc_1); + net_2.Compile(true, false, updater, &loss_2, &acc_2); + + MemPoolConf mem_conf; + mem_conf.add_device(0); + mem_conf.add_device(1); + std::shared_ptr<DeviceMemPool> mem_pool(new CnMemPool(mem_conf)); + std::shared_ptr<CudaGPU> cuda_1(new CudaGPU(0, mem_pool)); + std::shared_ptr<CudaGPU> cuda_2(new CudaGPU(1, mem_pool)); + net_1.ToDevice(cuda_1); + net_2.ToDevice(cuda_2); + + train_x_1.ToDevice(cuda_1); + train_y_1.ToDevice(cuda_1); + test_x.ToDevice(cuda_1); + test_y.ToDevice(cuda_1); + train_x_2.ToDevice(cuda_2); + train_y_2.ToDevice(cuda_2); + + LOG(INFO) << "Launching thread..."; + std::thread t1 = + net_1.TrainThread(50, num_epoch, train_x_1, train_y_1, test_x, test_y); + std::thread t2 = net_2.TrainThread(50, num_epoch, train_x_2, train_y_2); + t1.join(); + t2.join(); +} +} + +int main(int argc, char **argv) { + singa::InitChannel(nullptr); + int pos = singa::ArgPos(argc, argv, "-epoch"); + int nEpoch = 1; + if (pos != -1) nEpoch = atoi(argv[pos + 1]); + pos = singa::ArgPos(argc, argv, "-lr"); + float lr = 0.001; + if (pos != -1) lr = atof(argv[pos + 1]); + pos = singa::ArgPos(argc, argv, "-data"); + string data = "cifar-10-batches-bin"; + if (pos != -1) data = argv[pos + 1]; + + LOG(INFO) << "Start training"; + singa::Train(lr, nEpoch, data); + LOG(INFO) << "End training"; +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/examples/cifar10/vgg.py ---------------------------------------------------------------------- diff --git a/examples/cifar10/vgg.py b/examples/cifar10/vgg.py new file mode 100644 index 0000000..8063307 --- /dev/null +++ b/examples/cifar10/vgg.py @@ -0,0 +1,52 @@ +import sys +import os + +sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python')) +from singa import layer +from singa import metric +from singa import loss +from singa import net as ffnet +from singa.proto import core_pb2 + +def ConvBnReLU(net, name, nb_filers, sample_shape=None): + net.add(layer.Conv2D(name + '_1', nb_filers, 3, 1, pad=1, + input_sample_shape=sample_shape)) + net.add(layer.BatchNormalization(name + '_2')) + net.add(layer.Activation(name + '_3')) + +def create_net(): + net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy()) + ConvBnReLU(net, 'conv1_1', 64, (3, 32, 32)) + net.add(layer.Dropout('drop1', 0.3, engine='cudnn')) + ConvBnReLU(net, 'conv1_2', 64) + net.add(layer.MaxPooling2D('pool1', 2, 2, border_mode='valid')) + ConvBnReLU(net, 'conv2_1', 128) + net.add(layer.Dropout('drop2_1', 0.4, engine='cudnn')) + ConvBnReLU(net, 'conv2_2', 128) + net.add(layer.MaxPooling2D('pool2', 2, 2, border_mode='valid')) + ConvBnReLU(net, 'conv3_1', 256) + net.add(layer.Dropout('drop3_1', 0.4, engine='cudnn')) + ConvBnReLU(net, 'conv3_2', 256) + net.add(layer.Dropout('drop3_2', 0.4, engine='cudnn')) + ConvBnReLU(net, 'conv3_3', 256) + net.add(layer.MaxPooling2D('pool3', 2, 2, border_mode='valid')) + ConvBnReLU(net, 'conv4_1', 512) + net.add(layer.Dropout('drop4_1', 0.4, engine='cudnn')) + ConvBnReLU(net, 'conv4_2', 512) + net.add(layer.Dropout('drop4_2', 0.4, engine='cudnn')) + ConvBnReLU(net, 'conv4_3', 512) + net.add(layer.MaxPooling2D('pool4', 2, 2, border_mode='valid')) + ConvBnReLU(net, 'conv5_1', 512) + net.add(layer.Dropout('drop5_1', 0.4, engine='cudnn')) + ConvBnReLU(net, 'conv5_2', 512) + net.add(layer.Dropout('drop5_2', 0.4, engine='cudnn')) + ConvBnReLU(net, 'conv5_3', 512) + net.add(layer.MaxPooling2D('pool5', 2, 2, border_mode='valid')) + net.add(layer.Flatten('flat')) + 
net.add(layer.Dropout('drop_flat', 0.5, engine='cudnn')) + net.add(layer.Dense('ip1', 512)) + net.add(layer.BatchNormalization('batchnorm_ip1')) + net.add(layer.Activation('relu_ip1')) + net.add(layer.Dropout('drop_ip2', 0.5, engine='cudnn')) + net.add(layer.Dense('ip2', 10)) + return net http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index 4972a86..c16bd29 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -80,11 +80,11 @@ void Tensor::ResetLike(const Tensor &in) { if (block_ == nullptr || device_ != in.device_ || MemSize() != in.MemSize()) { if (block_ != nullptr && block_->DecRefCount() == 0) device_->FreeBlock(block_); - shape_ = in.shape_; device_ = in.device_; data_type_ = in.data_type_; block_ = device_->NewBlock(in.MemSize()); } + shape_ = in.shape_; } void Tensor::Reshape(const Shape &shape) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/model/layer/batchnorm.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/batchnorm.cc b/src/model/layer/batchnorm.cc index b6edc9e..6ea9f2a 100644 --- a/src/model/layer/batchnorm.cc +++ b/src/model/layer/batchnorm.cc @@ -27,8 +27,18 @@ void BatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) { out_sample_shape_ = in_sample; factor_ = conf.batchnorm_conf().factor(); channels_ = in_sample.at(0); - height_ = in_sample.at(1); - width_ = in_sample.at(2); + if (in_sample.size() == 3u) + height_ = in_sample.at(1); + else + height_ = 1; + if (in_sample.size() == 3u) + width_ = in_sample.at(2); + else + width_ = 1; + if (in_sample.size() == 1u) + is_2d_ = true; + else + is_2d_ = false; bnScale_.Reshape(Shape{channels_ * height_ * width_}); bnBias_.ResetLike(bnScale_); @@ -92,7 +102,8 @@ const Tensor BatchNorm::Forward(int flag, const Tensor& input) { AddRow(bnBias_, &output); } - output.Reshape(Shape{output.shape(0), channels_, height_, width_}); + if (!is_2d_) + output.Reshape(Shape{output.shape(0), channels_, height_, width_}); return output; } @@ -170,10 +181,16 @@ const std::pair<Tensor, vector<Tensor>> BatchNorm::Backward( SumRows(dy, &dbnBias_); param_grad.push_back(dbnScale_); param_grad.push_back(dbnBias_); + Tensor dummy; + dummy.ResetLike(runningMean_); + dummy.SetValue(.0f); + param_grad.push_back(dummy); + param_grad.push_back(dummy); } else { LOG(ERROR) << "Do not call backward for evaluation phase"; } - dx.Reshape(Shape{dx.shape(0), channels_, height_, width_}); + if (!is_2d_) + dx.Reshape(Shape{dx.shape(0), channels_, height_, width_}); return std::make_pair(dx, param_grad); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/model/layer/batchnorm.h ---------------------------------------------------------------------- diff --git a/src/model/layer/batchnorm.h b/src/model/layer/batchnorm.h index 6ff818b..f3d83ab 100644 --- a/src/model/layer/batchnorm.h +++ b/src/model/layer/batchnorm.h @@ -44,7 +44,7 @@ class BatchNorm : public Layer { /// \copydoc Layer::Backward(int, const Tensor&, const Tensor&); const std::pair<Tensor, vector<Tensor>> Backward( int flag, const Tensor& grad) override; - const std::vector<Tensor> param_values() override { + virtual const std::vector<Tensor> param_values() override { return std::vector<Tensor> { bnScale_, bnBias_, runningMean_, runningVariance_ }; } @@ -77,6 +77,7 @@ class BatchNorm : 
public Layer { protected: float factor_; size_t channels_, height_, width_; + bool is_2d_ = false; Tensor bnScale_, bnBias_; Tensor dbnScale_, dbnBias_; Tensor runningMean_, runningVariance_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/model/layer/cudnn_batchnorm.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_batchnorm.cc b/src/model/layer/cudnn_batchnorm.cc index 9e1e892..461f1b6 100644 --- a/src/model/layer/cudnn_batchnorm.cc +++ b/src/model/layer/cudnn_batchnorm.cc @@ -75,14 +75,20 @@ const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) { auto shape = input.shape(); auto dtype = input.data_type(); Tensor output; + Tensor x; + if(is_2d_) + x = Reshape(input, Shape{shape.at(0), shape.at(1), 1, 1}); + else + x = input; + shape = x.shape(); if (!has_init_cudnn_) InitCudnn(shape, dtype); // TODO(wangji): check device id of input and params - output.ResetLike(input); + output.ResetLike(x); if ((flag & kTrain) == kTrain) { output.device()->Exec( [=](Context* ctx) { - Block *inBlock = input.block(), *outBlock = output.block(), + Block *inBlock = x.block(), *outBlock = output.block(), *saveMeanBlock = resultSaveMean_.block(), *saveVarBlock = resultSaveVariance_.block(), *runningMeanBlock = runningMean_.block(), @@ -110,7 +116,7 @@ const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) { saveMeanBlock->mutable_data(), saveVarBlock->mutable_data())); }, - {input.block(), + {x.block(), bnScale_.block(), bnBias_.block()}, {output.block(), @@ -118,11 +124,11 @@ const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) { runningVariance_.block(), resultSaveMean_.block(), resultSaveVariance_.block()}); - buf_.push(input); + buf_.push(x); } else { output.device()->Exec( [=](Context* ctx) { - Block *inBlock = input.block(), *outBlock = output.block(), + Block *inBlock = x.block(), *outBlock = output.block(), *runningMeanBlock = runningMean_.block(), *runningVarBlock = runningVariance_.block(), *bnScaleBlock = bnScale_.block(), @@ -145,13 +151,15 @@ const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) { runningVarBlock->data(), epsilon)); }, - {input.block(), + {x.block(), bnScale_.block(), bnBias_.block(), runningMean_.block(), runningVariance_.block()}, {output.block()}); } + if (is_2d_) + output.Reshape(Shape{shape.at(0), shape.at(1)}); return output; } @@ -160,13 +168,13 @@ const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward( vector <Tensor> param_grad; Tensor dx; if ((flag & kTrain) == kTrain) { - Tensor input = buf_.top(); + Tensor x = buf_.top(); buf_.pop(); dx.ResetLike(grad); dx.device()->Exec( [=](Context* ctx) { Block *dyblock = grad.block(), *dxblock = dx.block(), - *xblock = input.block(), + *xblock = x.block(), *bnScaleBlock = bnScale_.block(), *dbnScaleBlock = dbnScale_.block(), *dbnBiasBlock = dbnBias_.block(), @@ -208,6 +216,13 @@ const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward( } param_grad.push_back(dbnScale_); param_grad.push_back(dbnBias_); + Tensor dummy; + dummy.ResetLike(dbnScale_); + dummy.SetValue(.0f); + param_grad.push_back(dummy); + param_grad.push_back(dummy); + if (is_2d_) + dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)}); return std::make_pair(dx, param_grad); } } // namespace http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/model/updater/local_updater.cc ---------------------------------------------------------------------- diff --git a/src/model/updater/local_updater.cc 
b/src/model/updater/local_updater.cc index eab4a7c..c3c6793 100644 --- a/src/model/updater/local_updater.cc +++ b/src/model/updater/local_updater.cc @@ -33,6 +33,7 @@ void LocalUpdater::Register(const string& name, const ParamSpec& specs) { } dev_index_[name] = 0; to_updater_finished_[name] = 0; + mtx_[name]; } void LocalUpdater::Apply(int step, const string& name, Tensor& grad, http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/python/singa/layer.py ---------------------------------------------------------------------- diff --git a/src/python/singa/layer.py b/src/python/singa/layer.py index 937a7e1..a443e1a 100644 --- a/src/python/singa/layer.py +++ b/src/python/singa/layer.py @@ -327,10 +327,16 @@ class BatchNormalization(Layer): beta_specs['name'] = name + '_beta' if 'name' not in gamma_specs: gamma_specs['name'] = name + '_gamma' - self.conf.param.extend([_construct_param_specs_from_dict(beta_specs)]) + mean_specs = {'init': 'constant', 'value': 0, 'name': name+'_mean'} + var_specs = {'init': 'constant', 'value': 1, 'name': name+'_var'} self.conf.param.extend([_construct_param_specs_from_dict(gamma_specs)]) - self.param_specs.append(_construct_param_specs_from_dict(beta_specs)) + self.conf.param.extend([_construct_param_specs_from_dict(beta_specs)]) + self.conf.param.extend([_construct_param_specs_from_dict(mean_specs)]) + self.conf.param.extend([_construct_param_specs_from_dict(var_specs)]) self.param_specs.append(_construct_param_specs_from_dict(gamma_specs)) + self.param_specs.append(_construct_param_specs_from_dict(beta_specs)) + self.param_specs.append(_construct_param_specs_from_dict(mean_specs)) + self.param_specs.append(_construct_param_specs_from_dict(var_specs)) _check_engine(engine, ['cudnn']) self.layer = _create_layer(engine, 'BatchNorm') if input_sample_shape is not None: http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bc3b74b3/src/python/singa/net.py ---------------------------------------------------------------------- diff --git a/src/python/singa/net.py b/src/python/singa/net.py index 084db4b..c0ba61d 100644 --- a/src/python/singa/net.py +++ b/src/python/singa/net.py @@ -64,6 +64,9 @@ class FeedForwardNet(object): specs.extend(lyr.param_specs) return specs + def param_names(self): + return [spec.name for spec in self.param_specs()] + def train(self, x, y): out = self.forward(kTrain, x) l = self.loss.forward(kTrain, out, y) @@ -89,9 +92,10 @@ class FeedForwardNet(object): return tensor.softmax(xx) def forward(self, flag, x): + #print x.l1() for lyr in self.layers: x = lyr.forward(flag, x) - # print lyr.name, x.l1() + # print lyr.name, x.l1() return x def backward(self, flag=kTrain):
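For reference, below is a minimal sketch (not part of the commit) of the new 2D batch-normalization path exercised from Python: a BatchNormalization layer fed directly by a Dense layer, as done for 'ip1' in vgg.py above. It assumes the same python API used in the example scripts and a CUDNN build; the net, the layer names, and the 3072-feature input size are illustrative only, and it assumes Dense accepts an input_sample_shape argument like the other layers do.

import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
from singa import layer
from singa import loss
from singa import metric
from singa import net as ffnet

# The Dense layer outputs a 2D (batch, 512) tensor; the revised batch-norm
# layer reshapes such input to 4D internally for cuDNN and reshapes the
# output back to 2D.
net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
net.add(layer.Dense('ip1', 512, input_sample_shape=(3072,)))
net.add(layer.BatchNormalization('ip1_bn'))
net.add(layer.Activation('ip1_relu'))
net.add(layer.Dense('ip2', 10))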
