SINGA-204 Support the training of feed-forward neural nets

Implement the AlexNet model for Cifar10 (https://code.google.com/p/cuda-convnet/).
The test accuracy is still low at 0.72 (it should be about 0.82).
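In brief, the example now builds the net with CreateNet() and configures the
optimizer, loss, and metric inside Train() before calling Compile(). The sketch
below condenses that flow from examples/cifar10/alexnet.cc in this commit; it is
not the full file: CreateNet() and the CIFAR-10 loading code are elided, and the
learning-rate lambda is simplified so every path returns a value.

    // Condensed from examples/cifar10/alexnet.cc (see the full diff below).
    auto net = CreateNet();

    SGD sgd;
    OptimizerConf opt_conf;
    opt_conf.set_momentum(0.9);
    opt_conf.mutable_regularizer()->set_coefficient(0.004);  // global weight decay
    sgd.Setup(opt_conf);
    sgd.SetLearningRateGenerator([](int epoch) {
      // 0.001 for the first 120 epochs, then two 10x decays
      if (epoch <= 120) return 0.001f;
      if (epoch <= 130) return 0.0001f;
      return 0.00001f;
    });

    SoftmaxCrossEntropy loss;
    Accuracy acc;
    net.Compile(true, &sgd, &loss, &acc);  // true = shuffle the training data

    auto cuda = std::make_shared<CudaGPU>();
    net.ToDevice(cuda);
    train_x.ToDevice(cuda); train_y.ToDevice(cuda);
    test_x.ToDevice(cuda);  test_y.ToDevice(cuda);

    // batch size 100, 140 epochs, evaluate on the test set after every epoch
    net.Train(100, 140, train_x, train_y, test_x, test_y);

With the new defaults in main(), the example runs as
./alexnet -epoch 140 -lr 0.001 -data cifar-10-batches-bin
(note that the hard-coded schedule above ignores the -lr value).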
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/71eb059c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/71eb059c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/71eb059c

Branch: refs/heads/dev
Commit: 71eb059cd13ea41e74195c7c115f927aaf143490
Parents: cf1d841
Author: Wei Wang <[email protected]>
Authored: Mon Jun 27 01:21:59 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Mon Jun 27 15:29:05 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/alexnet.cc             | 123 ++++++++++++++++++---------
 examples/cifar10/cifar10.h              |   3 +-
 examples/cifar10/make.sh                |   2 +-
 include/singa/core/tensor.h             |   1 +
 include/singa/model/feed_forward_net.h  |   2 +-
 include/singa/model/initializer.h       |  26 +++++-
 include/singa/model/loss.h              |  12 ++-
 include/singa/model/metric.h            |   2 +-
 include/singa/model/optimizer.h         |  18 ++--
 include/singa/utils/string.h            |  11 +++
 src/core/tensor/math_kernel.cu          |  22 +++++
 src/core/tensor/math_kernel.h           |   2 +
 src/core/tensor/tensor.cc               |  69 ++++++++++-----
 src/core/tensor/tensor_math.h           |   5 ++
 src/core/tensor/tensor_math_cpp.h       |  14 +++
 src/core/tensor/tensor_math_cuda.h      |   9 ++
 src/model/feed_forward_net.cc           | 104 ++++++++++++++--------
 src/model/layer/cudnn_convolution.cc    |   5 +-
 src/model/layer/cudnn_dropout.cc        |   2 +-
 src/model/layer/dense.cc                |   4 +-
 src/model/loss/mse.cc                   |   5 +-
 src/model/loss/softmax_cross_entropy.cc |  11 ++-
 src/model/metric/accuracy.cc            |   1 +
 src/model/optimizer/optimizer.cc        |  30 ++++++-
 src/model/optimizer/sgd.cc              |   1 +
 src/proto/model.proto                   |   5 ++
 test/singa/test_cross_entropy.cc        |   8 +-
 test/singa/test_dense.cc                |   2 +-
 test/singa/test_mse.cc                  |   8 +-
 test/singa/test_tensor_math.cc          |   4 +-
 30 files changed, 370 insertions(+), 141 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/examples/cifar10/alexnet.cc ---------------------------------------------------------------------- diff --git a/examples/cifar10/alexnet.cc b/examples/cifar10/alexnet.cc index 45d8571..d6541a3 100644 --- a/examples/cifar10/alexnet.cc +++ b/examples/cifar10/alexnet.cc @@ -28,12 +28,13 @@ #include "../../src/model/layer/cudnn_convolution.h" #include "../../src/model/layer/cudnn_activation.h" #include "../../src/model/layer/cudnn_pooling.h" +#include "../../src/model/layer/cudnn_lrn.h" #include "../../src/model/layer/dense.h" #include "../../src/model/layer/flatten.h" namespace singa { LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride, - int pad) { + int pad, float std) { LayerConf conf; conf.set_name(name); conf.set_type("CudnnConvolution"); @@ -42,13 +43,23 @@ LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride, conv->add_kernel_size(kernel); conv->add_stride(stride); conv->add_pad(pad); + conv->set_bias_term(true); - FillerConf *weight = conv->mutable_weight_filler(); - weight->set_type("Xavier"); + ParamSpec *wspec = conf.add_param(); + wspec->set_name(name + "_weight"); + auto wfill = wspec->mutable_filler(); + wfill->set_type("Gaussian"); + wfill->set_std(std); + + ParamSpec *bspec = conf.add_param(); + bspec->set_name(name + "_bias"); + bspec->set_lr_mult(2); +// bspec->set_decay_mult(0); return conf; } -LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride, int pad) { +LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride, + int pad) { 
LayerConf conf; conf.set_name(name); conf.set_type("CudnnPooling"); @@ -56,8 +67,7 @@ LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride, int pool->set_kernel_size(kernel); pool->set_stride(stride); pool->set_pad(pad); - if (!max_pool) - pool->set_pool(PoolingConf_PoolMethod_AVE); + if (!max_pool) pool->set_pool(PoolingConf_PoolMethod_AVE); return conf; } @@ -68,21 +78,38 @@ LayerConf GenReLUConf(string name) { return conf; } -LayerConf GenDenseConf(string name, int num_output) { +LayerConf GenDenseConf(string name, int num_output, float std, float wd) { LayerConf conf; conf.set_name(name); conf.set_type("Dense"); DenseConf *dense = conf.mutable_dense_conf(); dense->set_num_output(num_output); - FillerConf *weight = dense->mutable_weight_filler(); - weight->set_type("Xavier"); + FillerConf *bias = dense->mutable_bias_filler(); + + ParamSpec *wspec = conf.add_param(); + wspec->set_name(name + "_weight"); + wspec->set_decay_mult(wd); + + auto wfill = wspec->mutable_filler(); + wfill->set_type("Gaussian"); + wfill->set_std(std); + + ParamSpec *bspec = conf.add_param(); + bspec->set_name(name + "_bias"); + bspec->set_lr_mult(2); + bspec->set_decay_mult(0); + return conf; } -LayerConf GenSoftmaxConf(string name) { +LayerConf GenLRNConf(string name) { LayerConf conf; conf.set_name(name); - conf.set_type("CudnnSoftmax"); + conf.set_type("CudnnLRN"); + LRNConf *lrn = conf.mutable_lrn_conf(); + lrn->set_local_size(3); + lrn->set_alpha(5e-05); + lrn->set_beta(0.75); return conf; } @@ -92,25 +119,25 @@ LayerConf GenFlattenConf(string name) { conf.set_type("Flatten"); return conf; } -FeedForwardNet CreateNet(Optimizer* opt, Loss<Tensor>* loss, Metric<Tensor>* metric) { + +FeedForwardNet CreateNet() { FeedForwardNet net; Shape s{3, 32, 32}; - net.Add(new CudnnConvolution(), GenConvConf("conv1", 32, 5, 1, 2), &s); + net.Add(new CudnnConvolution(), GenConvConf("conv1", 32, 5, 1, 2, 0.0001), + &s); net.Add(new CudnnActivation(), GenReLUConf("relu1")); - net.Add(new CudnnPooling, GenPoolingConf("pool1", true, 3, 2, 0)); - net.Add(new CudnnConvolution(), GenConvConf("conv2", 32, 5, 1, 2)); + net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 3, 2, 1)); + net.Add(new CudnnLRN(), GenLRNConf("lrn1")); + net.Add(new CudnnConvolution(), GenConvConf("conv2", 32, 5, 1, 2, 0.01)); net.Add(new CudnnActivation(), GenReLUConf("relu2")); - net.Add(new CudnnPooling(), GenPoolingConf("pool2", true, 3, 2, 0)); - net.Add(new CudnnConvolution, GenConvConf("conv3", 64, 5, 1, 2)); + net.Add(new CudnnPooling(), GenPoolingConf("pool2", false, 3, 2, 1)); + net.Add(new CudnnLRN(), GenLRNConf("lrn2")); + net.Add(new CudnnConvolution, GenConvConf("conv3", 64, 5, 1, 2, 0.01)); net.Add(new CudnnActivation(), GenReLUConf("relu3")); - net.Add(new CudnnConvolution(), GenConvConf("pool3", true, 3, 2, 0)); + net.Add(new CudnnPooling(), GenPoolingConf("pool3", false, 3, 2, 1)); net.Add(new Flatten(), GenFlattenConf("flat")); - net.Add(new Dense(), GenDenseConf("ip1", 10)); - OptimizerConf opt_conf; - opt_conf.set_momentum(0.9); - opt->Setup(opt_conf); - net.Compile(true, opt, loss, metric); + net.Add(new Dense(), GenDenseConf("ip", 10, 0.01, 250)); return net; } @@ -120,50 +147,62 @@ void Train(float lr, int num_epoch, string data_dir) { { auto train = data.ReadTrainData(); size_t nsamples = train.first.shape(0); - auto matx = Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples}); + auto matx = + Reshape(train.first, Shape{nsamples, train.first.Size() / nsamples}); const auto mean = 
Average(matx, 0); SubRow(mean, &matx); train_x = Reshape(matx, train.first.shape()); train_y = train.second; auto test = data.ReadTestData(); nsamples = test.first.shape(0); - auto maty = Reshape(test.first, Shape{nsamples, test.first.Size() / nsamples}); + auto maty = + Reshape(test.first, Shape{nsamples, test.first.Size() / nsamples}); SubRow(mean, &maty); test_x = Reshape(maty, test.first.shape()); test_y = test.second; } - LOG(ERROR) << "creating net"; + LOG(INFO) << "Training samples = " << train_y.shape(0) + << " Test samples =" << test_y.shape(0); + auto net = CreateNet(); + SGD sgd; + OptimizerConf opt_conf; + opt_conf.set_momentum(0.9); + auto reg = opt_conf.mutable_regularizer(); + reg->set_coefficient(0.004); + sgd.Setup(opt_conf); + sgd.SetLearningRateGenerator([lr](int step) { + if (step <= 120) + return 0.001; + else if (step <= 130) + return 0.0001; + else if (step <= 140) + return 0.00001; + }); SoftmaxCrossEntropy loss; Accuracy acc; - SGD sgd; - sgd.SetLearningRateGenerator([lr](int step) {return lr;}); - auto net = CreateNet(&sgd, &loss, &acc); - + net.Compile(true, &sgd, &loss, &acc); auto cuda = std::make_shared<CudaGPU>(); net.ToDevice(cuda); train_x.ToDevice(cuda); train_y.ToDevice(cuda); - net.Train(50, num_epoch, train_x, train_y); // test_x, test_y); + test_x.ToDevice(cuda); + test_y.ToDevice(cuda); + net.Train(100, num_epoch, train_x, train_y, test_x, test_y); } - - } -int main(int argc, char** argv) { +int main(int argc, char **argv) { singa::InitChannel(nullptr); int pos = singa::ArgPos(argc, argv, "-epoch"); - int nEpoch = 5; - if (pos != -1) - nEpoch = atoi(argv[pos + 1]); + int nEpoch = 140; + if (pos != -1) nEpoch = atoi(argv[pos + 1]); pos = singa::ArgPos(argc, argv, "-lr"); - float lr = 0.01; - if (pos != -1) - lr = atof(argv[pos + 1]); + float lr = 0.001; + if (pos != -1) lr = atof(argv[pos + 1]); pos = singa::ArgPos(argc, argv, "-data"); string data = "cifar-10-batches-bin"; - if (pos != -1) - data = argv[pos + 1]; + if (pos != -1) data = argv[pos + 1]; LOG(INFO) << "Start training"; singa::Train(lr, nEpoch, data); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/examples/cifar10/cifar10.h ---------------------------------------------------------------------- diff --git a/examples/cifar10/cifar10.h b/examples/cifar10/cifar10.h index 261c048..7f10153 100644 --- a/examples/cifar10/cifar10.h +++ b/examples/cifar10/cifar10.h @@ -40,11 +40,12 @@ class Cifar10 { const std::pair<Tensor, Tensor> ReadFile(string file, bool shuffle = false); void ReadImage(std::ifstream* file, int* label, char* buffer); + private: const size_t kImageSize = 32; const size_t kImageVol = 3072; const size_t kBatchSize = 10000; - const size_t kTrainFiles = 1; + const size_t kTrainFiles = 5; string dir_path_; bool normalize_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/examples/cifar10/make.sh ---------------------------------------------------------------------- diff --git a/examples/cifar10/make.sh b/examples/cifar10/make.sh index 17e4b39..5a41612 100755 --- a/examples/cifar10/make.sh +++ b/examples/cifar10/make.sh @@ -1 +1 @@ -g++ -g --std=c++11 alexnet.cc -o alexnet -I../../include -I../../build/include -I/home/wangwei/local/cudnn4/include -I/home/wangwei/local/include -I/usr/local/cuda/include/ -I../../lib/cnmem/include -L../../build/lib/ -lsinga_core -lsinga_model -lsinga_utils -lcudart -lcublas -lcurand -lcudnn -L/usr/local/cuda/lib64 -L/home/wangwei/local/cudnn4/lib64 ../../build/lib/libproto.a -lprotobuf +g++ -g --std=c++11 
alexnet.cc -o alexnet -I../../include -I../../build/include -I/home/wangwei/local/cudnn5/include -I/home/wangwei/local/include -I/usr/local/cuda/include/ -I../../lib/cnmem/include -L../../build/lib/ -lsinga_core -lsinga_model -lsinga_utils -lcudart -lcublas -lcurand -lcudnn -L/home/wangwei/local/cudnn5/lib64 -L/usr/local/cuda/lib64 ../../build/lib/libproto.a -lprotobuf http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/core/tensor.h ---------------------------------------------------------------------- diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index 3b496d9..18aa7ef 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -180,6 +180,7 @@ class Tensor { template <typename SType> Tensor &operator/=(const SType x); + float L1() const; float L2() const; protected: http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/model/feed_forward_net.h ---------------------------------------------------------------------- diff --git a/include/singa/model/feed_forward_net.h b/include/singa/model/feed_forward_net.h index 9beeb7a..1ca417c 100644 --- a/include/singa/model/feed_forward_net.h +++ b/include/singa/model/feed_forward_net.h @@ -72,7 +72,7 @@ class FeedForwardNet { void Train(size_t batchsize, int nb_epoch, const Tensor& x, const Tensor& y, const Tensor& val_x, const Tensor& val_y); /// Train the neural net over one batch of training data. - const std::pair<float, float> TrainOnBatch(const Tensor& x, const Tensor& y); + const std::pair<float, float> TrainOnBatch(int epoch, const Tensor& x, const Tensor& y); /// Evaluate the neural net with given data. /// Returns one tensor for loss values and one tensor for metric values; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/model/initializer.h ---------------------------------------------------------------------- diff --git a/include/singa/model/initializer.h b/include/singa/model/initializer.h index 302fc97..7024f70 100644 --- a/include/singa/model/initializer.h +++ b/include/singa/model/initializer.h @@ -21,8 +21,8 @@ #include <string> #include "singa/core/tensor.h" #include "singa/proto/model.pb.h" +#include "singa/utils/string.h" namespace singa { -namespace init { /// Base class for initializing parameter values. 
using InitializerConf = FillerConf; class Initializer { @@ -40,6 +40,7 @@ class Initializer { virtual void Fill(Tensor* t) = 0; }; +namespace init { class Constant : public Initializer { public: Constant() = default; @@ -76,7 +77,7 @@ public: void Fill(Tensor* t) override { singa::Gaussian(mean_, std_, t); } private: - float mean_ = 0, std_ = 0.01; + float mean_ = 0, std_ = 1; }; /// Ref: [Bengio and Glorot 2010] Understanding the difficulty of training deep @@ -86,6 +87,7 @@ public: void Fill(Tensor* t) override { CHECK_EQ(t->nDim(), 2u); float scale = sqrt(6.0f / (t->shape(0) + t->shape(1))); + LOG(INFO) << "xavier scale " << scale; singa::Uniform(-scale, scale, t); } }; @@ -100,6 +102,26 @@ class MSRA : public Initializer { singa::Gaussian(0.0f, std, t); } }; + } // namespace init + +std::shared_ptr<Initializer> CreateInitializer(const InitializerConf& conf) { + std::shared_ptr<Initializer> init; + if (ToLowerCase(conf.type()) == "constant") { + init = std::make_shared<init::Constant>(); + } else if (ToLowerCase(conf.type()) == "uniform") { + init = std::make_shared<init::Uniform>(); + } else if (ToLowerCase(conf.type()) == "gaussian") { + init = std::make_shared<init::Gaussian>(); + } else if (ToLowerCase(conf.type()) == "xavier") { + init = std::make_shared<init::Xavier>(); + } else if (ToLowerCase(conf.type()) == "msra") { + init = std::make_shared<init::MSRA>(); + } else { + LOG(FATAL) << "Unknown initialization type: " << conf.type(); + } + init->Setup(conf); + return init; +} } // namespace singa #endif // SINGA_MODEL_INITIALIZER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/model/loss.h ---------------------------------------------------------------------- diff --git a/include/singa/model/loss.h b/include/singa/model/loss.h index 41ec701..f400768 100644 --- a/include/singa/model/loss.h +++ b/include/singa/model/loss.h @@ -43,13 +43,13 @@ class Loss { /// Compute the loss values for each sample/instance given the prediction /// and the target. - virtual Tensor Forward(const Tensor& prediction, const T& target) = 0; + virtual Tensor Forward(int flag, const Tensor& prediction, const T& target) = 0; /// Average loss values for all samples in the mini-batch /// It calls Forward() internally. The calling pattern should be /// [Evaluate|Forward] Backward. - float Evaluate(const Tensor& prediction, const T& target) { - Tensor loss = Forward(prediction, target); + float Evaluate(int flag, const Tensor& prediction, const T& target) { + Tensor loss = Forward(flag, prediction, target); loss.ToHost(); return Sum<float>(loss) / (1.0f * loss.Size()); } @@ -68,7 +68,7 @@ class MSE : public Loss<Tensor> { /// and the target, which is 0.5/||prediction-target||^2 /// Users can call Average(const Tensor&) to get the average /// loss value over all samples in the batch. - Tensor Forward(const Tensor& prediction, const Tensor& target) override; + Tensor Forward(int flag, const Tensor& prediction, const Tensor& target) override; /// Compute the gradients of the loss values w.r.t. the prediction, /// which is (prediction-target)/batchsize @@ -90,7 +90,7 @@ class SoftmaxCrossEntropy : public Loss<Tensor> { /// from Softmax(prediction). /// Users can call Average(const Tensor&) to get the average /// loss value over all samples in the batch. - Tensor Forward(const Tensor& prediction, const Tensor& target) override; + Tensor Forward(int flag, const Tensor& prediction, const Tensor& target) override; /// Compute the gradients of the loss values w.r.t. 
the prediction, /// which is: p[idx] - 1 if idx is the truth category's index; else, @@ -106,5 +106,3 @@ class SoftmaxCrossEntropy : public Loss<Tensor> { } // namespace singa #endif // SINGA_MODEL_LOSS_H_ - - http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/model/metric.h ---------------------------------------------------------------------- diff --git a/include/singa/model/metric.h b/include/singa/model/metric.h index d013fa4..b100435 100644 --- a/include/singa/model/metric.h +++ b/include/singa/model/metric.h @@ -48,7 +48,7 @@ class Metric { /// Comptue the metric value averaged over all samples (in a batch) float Evaluate(const Tensor& prediction, const T& target) { - const Tensor& metric = Forward(prediction, target); + const Tensor metric = Forward(prediction, target); return Sum<float>(metric) / (1.0f * metric.Size()); } }; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/model/optimizer.h ---------------------------------------------------------------------- diff --git a/include/singa/model/optimizer.h b/include/singa/model/optimizer.h index a268126..2ec68fe 100644 --- a/include/singa/model/optimizer.h +++ b/include/singa/model/optimizer.h @@ -41,7 +41,7 @@ class Regularizer; class Optimizer { public: Optimizer() = default; - + virtual ~Optimizer(); /// Setup the optimzier using configurations from serialized string (for /// binding languages). void Setup(const string& str) { @@ -51,7 +51,7 @@ class Optimizer { } /// Setup the meta fields of the optimizer - virtual void Setup(const OptimizerConf& conf) {} + virtual void Setup(const OptimizerConf& conf); /// Register the parameter, e.g., create Constraint and Regularizers. /// If there is no constraint or regularizer, then no need to register the /// parameter. @@ -76,15 +76,21 @@ class Optimizer { void SetLearningRateGenerator(function<float(int)> func) { learning_rate_generator_ = func; } - /// Since Optimizer base layer has pure virtual function, a virtual - /// deconstructor is needed. - virtual ~Optimizer() = default; + float GetLearningRate(int step) { + if (learning_rate_generator_) + return learning_rate_generator_(step); + else + return 0; + } protected: function<float(int)> learning_rate_generator_; std::unordered_map<std::string, float> learning_rate_multplier_; + std::unordered_map<std::string, float> weight_decay_multplier_; std::unordered_map<std::string, Constraint*> constraints_; std::unordered_map<std::string, Regularizer*> regularizers_; + Constraint* constraint_ = nullptr; + Regularizer* regularizer_ = nullptr; }; /// Apply constraints for parameters (gradient). @@ -141,7 +147,7 @@ class Regularizer { /// e.g., clip each gradient if it is too large w.r.t the threshold, /// \ref /// https://www.reddit.com/r/MachineLearning/comments/31b6x8/gradient_clipping_rnns/ - void Apply(int step, Tensor* grad, Tensor* value); + void Apply(int step, Tensor* grad, Tensor* value, float scale = 1.0f); /// Apply the regularizer for multiple parameter objects together. 
/// \ref https://github.com/Lasagne/Lasagne/blob/master/lasagne/updates.py void Apply(int step, const vector<Tensor*>& grads, http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/include/singa/utils/string.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/string.h b/include/singa/utils/string.h index cbfb28b..b4c7c24 100644 --- a/include/singa/utils/string.h +++ b/include/singa/utils/string.h @@ -51,6 +51,17 @@ inline int ArgPos(int argc, char** arglist, const char* arg) { return -1; } +template<typename T> +inline std::string VecToStr(const std::vector<T> & in) { + std::string out = "("; + + for (auto x : in) { + out += std::to_string(x) + ", "; + } + out += ")"; + return out; +} + /** * Tokenize a string. * http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/math_kernel.cu ---------------------------------------------------------------------- diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu index 4135ab8..13005af 100644 --- a/src/core/tensor/math_kernel.cu +++ b/src/core/tensor/math_kernel.cu @@ -265,6 +265,19 @@ __global__ void KernelLT(const size_t num, const float *in, const float x, out[idx] = in[idx] < x ? 1.0f : 0.0f; } } + +__global__ void KernelRowMax(const size_t nrow, const size_t ncol, const float *inPtr, + float *outPtr) { + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < nrow; + idx += blockDim.x * gridDim.x) { + int offset = idx * ncol; + float maxval = inPtr[offset]; + for (size_t k = 1; k < ncol; k++) { + maxval = max(maxval, inPtr[offset + k]); + } + outPtr[idx] = maxval; + } +} __global__ void KernelComputeCrossEntropy(const size_t batchsize, const size_t dim, const float *p, const int *t, float *loss) { @@ -286,6 +299,9 @@ __global__ void KernelSoftmaxCrossEntropyBwd(const size_t batchsize, grad[pos] = p[pos] - 1.0f; // TODO(wangwei) Consider p and grad are diff } } + + + // ******************************** // Functions call kernels // ******************************** @@ -421,6 +437,12 @@ void SoftmaxCrossEntropyBwd(size_t batchsize, const size_t dim, const float *p, KernelSoftmaxCrossEntropyBwd <<<ceil(batchsize / CU1DBLOCKF), CU1DBLOCKF>>> (batchsize, dim, p, t, grad); } + +void RowMax(const size_t nrow, const size_t ncol, const float *inPtr, + float *outPtr, cudaStream_t stream) { + KernelRowMax <<<ceil(nrow / CU1DBLOCKF), CU1DBLOCKF>>>(nrow, ncol, inPtr, outPtr); +} + /* void square_grad(int n, const float *in, float *out, cudaStream_t s) { kernel_square_grad <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (in, out, n); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/math_kernel.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h index d4087e5..63b0d82 100644 --- a/src/core/tensor/math_kernel.h +++ b/src/core/tensor/math_kernel.h @@ -98,6 +98,8 @@ void SoftmaxCrossEntropyBwd(const size_t batchsize, const size_t dim, const float *p, const int *t, float *grad, cudaStream_t stream); +void RowMax(const size_t nrow, const size_t ncol, const float *inPtr, + float *outPtr, cudaStream_t stream); } // cuda } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index 898cdc6..b07a23c 100644 --- 
a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -42,7 +42,8 @@ Tensor::Tensor(Shape &&shape, DataType dtype) device_ = defaultDevice; block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_)); } -Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device, DataType dtype) +Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device, + DataType dtype) : data_type_(dtype), device_(device), shape_(shape) { block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_)); } @@ -68,11 +69,10 @@ Tensor::Tensor(Tensor &&in) in.block_ = nullptr; } -void Tensor::SetBlock(Block* block) { +void Tensor::SetBlock(Block *block) { LOG(WARNING) << "Pls avoid using this function, which may have side-effect."; if (block_ != nullptr) - if (block_->DecRefCount()) - device_->FreeBlock(block_); + if (block_->DecRefCount()) device_->FreeBlock(block_); block_ = block; } @@ -118,8 +118,7 @@ void Tensor::ToDevice(std::shared_ptr<Device> dst) { // TODO(wangwei) the comparison is very strict. May compare against device ID? if (device_ != dst) { Tensor tmp(shape_, dst, data_type_); - if (block_ != nullptr && Size()) - tmp.CopyData(*this); + if (block_ != nullptr && Size()) tmp.CopyData(*this); if (block_ != nullptr && block_->DecRefCount() == 0) device_->FreeBlock(block_); block_ = tmp.block_; @@ -132,13 +131,13 @@ void Tensor::ToHost() { ToDevice(device_->host()); } template <typename DType> void Tensor::CopyDataFromHostPtr(const DType *src, const size_t num, - const size_t offset) { + const size_t offset) { CHECK_EQ(sizeof(DType), SizeOf(data_type_)) << "data_type is " << DataType_Name(data_type_) << " user given type is of size " << sizeof(DType); if (src != nullptr) { device_->CopyDataFromHostPtr(block(), src, sizeof(DType) * num, - sizeof(DType) * offset); + sizeof(DType) * offset); } else { LOG(WARNING) << "Copy data from null host ptr"; } @@ -161,8 +160,7 @@ void Tensor::CopyData(const Tensor &src) { } Tensor Tensor::Clone(std::shared_ptr<Device> device) const { - if (device == nullptr) - device = device_; + if (device == nullptr) device = device_; Tensor t(shape_, device_, data_type_); t.transpose_ = transpose_; t.CopyData(*this); @@ -244,8 +242,6 @@ GenUnaryScalarArgMemberFn(operator+=, Add); GenUnaryScalarArgMemberFn(operator*=, EltwiseMult); GenUnaryScalarArgMemberFn(operator/=, Div); - - // ====================Tensor Operations======================================= void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num, const size_t dst_offset, const size_t src_offset) { @@ -336,6 +332,18 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num, } while (0) // =============Element-wise operations==================================== +float Tensor::L1() const { + float nrm = 0.0f; + TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, { + device_->Exec([&nrm, this](Context *ctx) { + DType ret; + Asum<DType, Lang>(this->Size(), this->block(), &ret, ctx); + nrm = TypeCast<DType, float>(ret); + }, {this->block()}, {}); + }); + return nrm / Size(); +} + /// L2 norm, Do not use Nrm2 (name conflict). 
float Tensor::L2() const { float nrm = 0.0f; @@ -346,8 +354,10 @@ float Tensor::L2() const { nrm = TypeCast<DType, float>(ret); }, {this->block()}, {}); }); - return nrm; + return nrm / Size(); } + + template <typename SType> void Tensor::SetValue(const SType x) { CHECK_EQ(sizeof(SType), SizeOf(data_type_)); @@ -525,18 +535,35 @@ Tensor SoftMax(const Tensor &in) { return out; } +Tensor RowMax(const Tensor &in) { + Tensor ret({in.shape(0)}, in.device(), in.data_type()); + TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, { + in.device()->Exec([in, ret](Context *ctx) { + size_t nrow = 1; + if (in.nDim() > 1) nrow = in.shape(0); + size_t ncol = in.Size() / nrow; + RowMax<DType, Lang>(nrow, ncol, in.block(), ret.block(), ctx); + }, {in.block()}, {ret.block()}); + }); + return ret; +} + void SoftMax(const Tensor &in, Tensor *out) { CHECK_LE(in.nDim(), 2u); - Exp(in, out); + out->CopyData(in); size_t nrow = 1, ncol = in.Size(), size = ncol; if (in.nDim() == 2u) { nrow = in.shape(0); ncol = size / nrow; out->Reshape(Shape{nrow, ncol}); } - Tensor sum(Shape{nrow}, in.device(), in.data_type()); - SumColumns(*out, &sum); - DivColumn(sum, out); + Tensor tmp = RowMax(*out); + SubColumn(tmp, out); + Exp(*out, out); + + SumColumns(*out, &tmp); + DivColumn(tmp, out); + out->Reshape(in.shape()); } void AddColumn(const Tensor &v, Tensor *M) { AddColumn(1, 1, v, M); } @@ -582,8 +609,8 @@ void AddRow(const SType alpha, const SType beta, const Tensor &v, Tensor *M) { Mult(alpha, one, vmat, beta, M); } } -template -void AddRow(const float alpha, const float beta, const Tensor &v, Tensor *M); +template void AddRow(const float alpha, const float beta, const Tensor &v, + Tensor *M); /// Divide column 'v' by each column of matrix M; write results into 'out' void DivColumn(const Tensor &v, Tensor *M) { @@ -699,7 +726,7 @@ void MultRow(const Tensor &v, Tensor *M) { }); } -Tensor SliceRows(const Tensor& in, const size_t start, const size_t end) { +Tensor SliceRows(const Tensor &in, const size_t start, const size_t end) { LOG(FATAL) << "Tensor::SliceRows is not implemented"; Tensor ret; /* @@ -788,6 +815,7 @@ void Gaussian(const SType mean, const SType std, Tensor *out) { template void Gaussian<float>(const float mean, const float std, Tensor *out); // ================Blas operations============================================ + template <typename SType> void Axpy(const SType alpha, const Tensor &in, Tensor *out) { TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, { @@ -869,5 +897,4 @@ void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p) { }); } - } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/tensor_math.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h index 57ccb88..7732dd2 100644 --- a/src/core/tensor/tensor_math.h +++ b/src/core/tensor/tensor_math.h @@ -339,6 +339,11 @@ void SoftmaxCrossEntropyBwd(const size_t batchsize, const size_t dim, LOG(FATAL) << "Not Implemented"; } +template <typename DType, typename Lang> +void RowMax(const size_t nrow, const size_t ncol, const Block *in, + const Block *ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} // ************************************** // Matrix functions // ************************************** http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/tensor_math_cpp.h 
---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h index 4717b5f..3e0c8ad 100644 --- a/src/core/tensor/tensor_math_cpp.h +++ b/src/core/tensor/tensor_math_cpp.h @@ -549,6 +549,20 @@ void SoftmaxCrossEntropyBwd<float, lang::Cpp>(const size_t batchsize, } } +template <> +void RowMax<float, lang::Cpp>(const size_t nrow, const size_t ncol, + const Block *in, const Block *out, Context *ctx) { + const float *inPtr = static_cast<const float *>(in->data()); + float *outPtr = static_cast<float *>(out->mutable_data()); + for (size_t r = 0; r < nrow; r++) { + int offset = r * ncol; + float maxval = inPtr[offset]; + for (size_t c = 1; c < ncol; c++) + maxval = std::max(maxval, inPtr[offset + c]); + outPtr[r] = maxval; + } +} + // =========Matrix operations ================================================ /* template <> http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/core/tensor/tensor_math_cuda.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h index 67ee861..43bfa1b 100644 --- a/src/core/tensor/tensor_math_cuda.h +++ b/src/core/tensor/tensor_math_cuda.h @@ -421,6 +421,15 @@ void SoftmaxCrossEntropyBwd<float, lang::Cuda>(const size_t batchsize, cuda::SoftmaxCrossEntropyBwd(batchsize, dim, pPtr, tPtr, gradPtr, ctx->stream); } + +template <> +void RowMax<float, lang::Cuda>(const size_t nrow, const size_t ncol, + const Block* in, const Block* out, + Context* ctx) { + const float* inPtr = static_cast<const float*>(in->data()); + float* outPtr = static_cast<float*>(out->mutable_data()); + cuda::RowMax(nrow, ncol, inPtr, outPtr, ctx->stream); +} } // namespace singa #endif // USE_CUDA http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/feed_forward_net.cc ---------------------------------------------------------------------- diff --git a/src/model/feed_forward_net.cc b/src/model/feed_forward_net.cc index a24d36a..e682918 100644 --- a/src/model/feed_forward_net.cc +++ b/src/model/feed_forward_net.cc @@ -1,22 +1,26 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ #include "singa/model/feed_forward_net.h" +#include "singa/model/initializer.h" #include "singa/utils/logging.h" #include "singa/utils/channel.h" namespace singa { @@ -37,12 +41,15 @@ Layer* FeedForwardNet::Add(const LayerConf& conf, const Shape* sample_shape) { return layer; } -Layer* FeedForwardNet::Add(Layer* layer, const LayerConf& conf, const Shape* sample_shape) { +Layer* FeedForwardNet::Add(Layer* layer, const LayerConf& conf, + const Shape* sample_shape) { + CHECK(conf.has_name()) << "Must set layer name"; if (sample_shape == nullptr) layer->Setup(layers_.back()->GetOutputSampleShape(), conf); else layer->Setup(*sample_shape, conf); Add(layer); + LOG(INFO) << layer->name() << VecToStr(layer->GetOutputSampleShape()); return layer; } @@ -75,12 +82,19 @@ void FeedForwardNet::Compile(bool shuffle, Optimizer* opt, Loss<Tensor>* loss, opt_ = opt; loss_ = loss; metric_ = metric; - // init params and register them to sgd + const auto specs = GetParamSpecs(); + const auto params = GetParamValues(); + CHECK_EQ(specs.size(), params.size()); + for (size_t k = 0; k < specs.size(); k++) { + opt_->Register(specs[k].name(), specs[k]); + auto init = CreateInitializer(specs[k].filler()); + init->Fill(params[k]); + LOG(INFO) << specs[k].name() << " : " << params[k]->L1(); + } } void FeedForwardNet::ToDevice(std::shared_ptr<Device> device) { - for (auto layer: layers_) - layer->ToDevice(device); + for (auto layer : layers_) layer->ToDevice(device); /* opt_->ToDevice(device); loss_->ToDevice(device); @@ -129,7 +143,6 @@ void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x, void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x, const Tensor& y, const Tensor& val_x, const Tensor& val_y) { - InitNetParams(); CHECK_EQ(x.shape(0), y.shape(0)) << "Diff num of sampels in x and y"; int num_extra_samples = x.shape(0) % batchsize; if (num_extra_samples != 0) @@ -137,13 +150,18 @@ void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x, Channel* train_ch = GetChannel("train_perf"); train_ch->EnableDestStderr(true); Channel* val_ch = GetChannel("val_perf"); + val_ch->EnableDestStderr(true); + std::vector<size_t> index; + for (size_t i = 0; i < x.shape(0) / batchsize; i++) index.push_back(i); for (int epoch = 0; epoch < nb_epoch; epoch++) { + if (shuffle_) std::random_shuffle(index.begin(), index.end()); float loss = 0.0f, metric = 0.0f; size_t b = 0; for (; b < x.shape(0) / batchsize; b++) { - const Tensor bx = CopyRows(x, b * batchsize, b * batchsize + batchsize); - const Tensor by = CopyRows(y, b * batchsize, b * batchsize + batchsize); - const auto ret = TrainOnBatch(bx, by); + size_t idx = index[b]; + const Tensor bx = CopyRows(x, idx * batchsize, (idx + 1) * batchsize); + const Tensor by = CopyRows(y, idx * batchsize, (idx + 1) * batchsize); + const auto ret = TrainOnBatch(epoch, bx, by); loss += ret.first; metric += ret.second; } @@ -151,7 +169,8 @@ void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x, metric /= b; train_ch->Send("Epoch " + 
std::to_string(epoch) + ", training loss = " + std::to_string(loss) + ", accuracy = " + - std::to_string(metric)); + std::to_string(metric) + ", lr = " + + std::to_string(opt_->GetLearningRate(epoch))); if (val_x.Size() && val_y.Size()) { const auto val_perf = Evaluate(val_x, val_y, batchsize); val_ch->Send("Epoch " + std::to_string(epoch) + ", val loss = " + @@ -162,22 +181,28 @@ void FeedForwardNet::Train(size_t batchsize, int nb_epoch, const Tensor& x, } } -const std::pair<float, float> FeedForwardNet::TrainOnBatch(const Tensor& x, +const std::pair<float, float> FeedForwardNet::TrainOnBatch(int epoch, + const Tensor& x, const Tensor& y) { int flag = kTrain; const Tensor fea = Forward(flag, x); - float loss = loss_->Evaluate(fea, y); + float loss = loss_->Evaluate(flag, fea, y); float metric = metric_->Evaluate(fea, y); const Tensor grad = loss_->Backward(); - const auto grads = Backward(kTrain, grad); + auto grads = Backward(kTrain, grad / static_cast<float>(x.shape(0))); + auto names = GetParamNames(); + auto values = GetParamValues(); + for (size_t k = 0; k < grads.size(); k++) { + opt_->Apply(epoch, names[k], &grads[k], values.at(k)); + } return std::make_pair(loss, metric); } const Tensor FeedForwardNet::Forward(int flag, const Tensor& data) { Tensor input = data, output; for (auto layer : layers_) { -// LOG(INFO) << layer->name(); output = layer->Forward(flag, input); + // LOG(INFO) << layer->name() << ": " << output.L2(); input = output; } return output; @@ -185,13 +210,22 @@ const Tensor FeedForwardNet::Forward(int flag, const Tensor& data) { const vector<Tensor> FeedForwardNet::Backward(int flag, const Tensor& grad) { vector<Tensor> param_grads; + std::stack<Tensor> buf; Tensor tmp = grad; for (int i = layers_.size() - 1; i >= 0; i--) { - // LOG(INFO) << layers_.at(i)->name(); + // LOG(INFO) << layers_.at(i)->name() << " : " << tmp.L2(); auto ret = layers_.at(i)->Backward(flag, tmp); tmp = ret.first; - if (ret.second.size()) - for (const auto x : ret.second) param_grads.push_back(x); + if (ret.second.size()) { + for (int k = ret.second.size() - 1; k >= 0; k--) { + buf.push(ret.second[k]); + // LOG(INFO) << " " << buf.top().L2(); + } + } + } + while (!buf.empty()) { + param_grads.push_back(buf.top()); + buf.pop(); } return param_grads; } @@ -230,8 +264,8 @@ std::pair<Tensor, Tensor> FeedForwardNet::EvaluateOnBatch(const Tensor& x, int flag = kEval; const Tensor fea = Forward(flag, x); const Tensor m = metric_->Forward(fea, y); - const Tensor l = loss_->Forward(fea, y); - return std::make_pair(m, l); + const Tensor l = loss_->Forward(flag, fea, y); + return std::make_pair(l, m); } const Tensor FeedForwardNet::Predict(const Tensor& x, size_t batchsize) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/layer/cudnn_convolution.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc index eb507b2..3dca28a 100644 --- a/src/model/layer/cudnn_convolution.cc +++ b/src/model/layer/cudnn_convolution.cc @@ -72,8 +72,8 @@ void CudnnConvolution::InitCudnn(const Tensor &input) { num_filters_, conv_height_, conv_width_)); if (bias_term_) CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW, - GetCudnnDataType(dtype), 1, 1, - num_filters_, 1)); + GetCudnnDataType(dtype), 1, 1, 1, + num_filters_)); CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_, stride_h_, stride_w_, 1, 1, CUDNN_CROSS_CORRELATION)); @@ -244,6 +244,7 @@ 
const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward( }, {grad.block(), weight_.block()}, {dx.block(), workspace_.block()}); param_grad.push_back(dw); param_grad.push_back(db); + LOG(INFO) << "bias nrm " << db.L1(); return std::make_pair(dx, param_grad); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/layer/cudnn_dropout.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_dropout.cc b/src/model/layer/cudnn_dropout.cc index f9b9dbf..ab83226 100644 --- a/src/model/layer/cudnn_dropout.cc +++ b/src/model/layer/cudnn_dropout.cc @@ -108,7 +108,7 @@ const std::pair<Tensor, vector<Tensor>> CudnnDropout::Backward( } void CudnnDropout::ToDevice(std::shared_ptr<Device> device) { Dropout::ToDevice(device); - state.ToDevice(device); + state_.ToDevice(device); } } // namespace singa #endif // CUDNN_VERSION_MAJOR>=5 http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/layer/dense.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/dense.cc b/src/model/layer/dense.cc index c6a9f8a..338409c 100644 --- a/src/model/layer/dense.cc +++ b/src/model/layer/dense.cc @@ -41,13 +41,15 @@ void Dense::Setup(const Shape& in_sample, const LayerConf &conf) { bias_.Reshape(Shape{hdim_}); param_values_.push_back(&weight_); param_values_.push_back(&bias_); + for (auto specs: conf.param()) + param_specs_.push_back(specs); } /// \copydoc Layer::Forward(int flag, const Tensor&) const Tensor Dense::Forward(int flag, const Tensor &input) { CHECK(buf_.empty()); Tensor output; - CHECK_EQ(input.nDim(), 2); + CHECK_EQ(input.nDim(), 2u); if (transpose_) // use the transposed version of weight_ for computing output = Mult(input, weight_); else http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/loss/mse.cc ---------------------------------------------------------------------- diff --git a/src/model/loss/mse.cc b/src/model/loss/mse.cc index a4bbb72..6e19059 100644 --- a/src/model/loss/mse.cc +++ b/src/model/loss/mse.cc @@ -20,7 +20,7 @@ namespace singa { -Tensor MSE::Forward(const Tensor& prediction, const Tensor& target) { +Tensor MSE::Forward(int flag, const Tensor& prediction, const Tensor& target) { CHECK(buf_.empty()) << "Do not call Forward successively for more than twice." << " The calling pattern is [Forward|Evaluate] Backward"; Tensor t = prediction - target; @@ -28,7 +28,8 @@ Tensor MSE::Forward(const Tensor& prediction, const Tensor& target) { if (t.nDim() > 1) batchsize = t.shape().at(0); size_t dim = t.Size() / batchsize; t.Reshape(Shape{batchsize, dim}); - buf_.push(t); + if (kTrain & flag) + buf_.push(t); // TODO(wangwei) use CastType for operator/ return Sum(Square(t), 1) * 0.5f; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/loss/softmax_cross_entropy.cc ---------------------------------------------------------------------- diff --git a/src/model/loss/softmax_cross_entropy.cc b/src/model/loss/softmax_cross_entropy.cc index bed3348..3411fbe 100644 --- a/src/model/loss/softmax_cross_entropy.cc +++ b/src/model/loss/softmax_cross_entropy.cc @@ -21,7 +21,7 @@ namespace singa { -Tensor SoftmaxCrossEntropy::Forward(const Tensor& prediction, +Tensor SoftmaxCrossEntropy::Forward(int flag, const Tensor& prediction, const Tensor& target) { CHECK(buf_.empty()) << "Do not call Forward successively for more than twice." 
<< " The calling pattern is [Forward|Evaluate] Backward"; @@ -30,13 +30,17 @@ Tensor SoftmaxCrossEntropy::Forward(const Tensor& prediction, size_t dim = prediction.Size() / batchsize; const Tensor& input = Reshape(prediction, Shape{batchsize, dim}); Tensor prob = SoftMax(input); + // LOG(INFO) << "prob: " << prob.L2(); // buffer intermediate data - buf_.push(prob); - buf_.push(target); + if (flag & kTrain) { + buf_.push(prob); + buf_.push(target); + } Tensor loss(Shape{batchsize}, prob.device(), prob.data_type()); ComputeCrossEntropy(prob, target, &loss); + return loss; } @@ -50,4 +54,3 @@ Tensor SoftmaxCrossEntropy::Backward() { } } // namespace singa - http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/metric/accuracy.cc ---------------------------------------------------------------------- diff --git a/src/model/metric/accuracy.cc b/src/model/metric/accuracy.cc index 1b667b1..ffda938 100644 --- a/src/model/metric/accuracy.cc +++ b/src/model/metric/accuracy.cc @@ -30,6 +30,7 @@ Tensor Accuracy::Match(const Tensor& predict, const vector<int>& target) { // TODO(wangwei) CloneToDevice(host); const float* prob = prediction.data<float>(); float* score = new float[batchsize]; + memset(score, 0, batchsize * sizeof(float)); for (size_t b = 0; b < batchsize; b++) { vector<std::pair<float, int>> prob_class; for (size_t c = 0; c < nb_classes; c++) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/optimizer/optimizer.cc ---------------------------------------------------------------------- diff --git a/src/model/optimizer/optimizer.cc b/src/model/optimizer/optimizer.cc index c9e7a72..9be47c8 100644 --- a/src/model/optimizer/optimizer.cc +++ b/src/model/optimizer/optimizer.cc @@ -21,6 +21,17 @@ namespace singa { +Optimizer::~Optimizer() { + for (auto entry : regularizers_) delete entry.second; + for (auto entry : constraints_) delete entry.second; + if (constraint_ != nullptr) delete constraint_; + if (regularizer_ != nullptr) delete regularizer_; +} +void Optimizer::Setup(const OptimizerConf& conf) { + if (conf.has_regularizer()) + regularizer_ = new Regularizer(conf.regularizer()); + if (conf.has_constraint()) constraint_ = new Constraint(conf.constraint()); +} void Optimizer::Register(const string& name, const ParamSpec& specs) { if (specs.has_constraint()) { CHECK(constraints_.find(name) == constraints_.end()) @@ -32,6 +43,11 @@ void Optimizer::Register(const string& name, const ParamSpec& specs) { << "Parameter with name = " << name << " has already registered"; regularizers_[name] = new Regularizer(specs.regularizer()); } + if (specs.has_decay_mult()) { + CHECK(weight_decay_multplier_.find(name) == weight_decay_multplier_.end()) + << "Parameter with name = " << name << " has already registered"; + weight_decay_multplier_[name] = specs.decay_mult(); + } if (specs.has_lr_mult()) { CHECK(learning_rate_multplier_.find(name) == learning_rate_multplier_.end()) << "Parameter with name = " << name << " has already registered"; @@ -47,10 +63,18 @@ void Optimizer::Register(const string& name, const ParamSpec& specs) { void Optimizer::Apply(int step, const string& name, Tensor* grad, Tensor* param) { // TODO(wangwei) need to consider the order of constraint and regularizer - if (regularizers_.find(name) != regularizers_.end()) + if (regularizers_.find(name) != regularizers_.end()) { regularizers_.at(name)->Apply(step, param, grad); + } else if (regularizer_ != nullptr) { + float scale = 1.0f; + if (weight_decay_multplier_.find(name) != 
weight_decay_multplier_.end()) + scale = weight_decay_multplier_.at(name); + regularizer_->Apply(step, param, grad, scale); + } if (constraints_.find(name) != constraints_.end()) constraints_.at(name)->Apply(step, param, grad); + else if (constraint_ != nullptr) + constraint_->Apply(step, param, grad); float lr = learning_rate_generator_(step); if (learning_rate_multplier_.find(name) != learning_rate_multplier_.end()) lr *= learning_rate_multplier_.at(name); @@ -62,9 +86,9 @@ void Regularizer::Setup(const RegularizerConf& conf) { coefficient_ = conf.coefficient(); } -void Regularizer::Apply(int step, Tensor* value, Tensor* grad) { +void Regularizer::Apply(int step, Tensor* value, Tensor* grad, float scale) { if (type_ == "L2" || type_ == "l2") { - (*grad) -= (*value) * coefficient_; + Axpy(coefficient_ * scale, *value, grad); } else { CHECK(type_ == "NotSet") << "Unknown regularizer type = " << type_; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/model/optimizer/sgd.cc ---------------------------------------------------------------------- diff --git a/src/model/optimizer/sgd.cc b/src/model/optimizer/sgd.cc index a5c66a1..71071ff 100644 --- a/src/model/optimizer/sgd.cc +++ b/src/model/optimizer/sgd.cc @@ -22,6 +22,7 @@ namespace singa { void SGD::Setup(const OptimizerConf& conf) { + Optimizer::Setup(conf); if (conf.has_momentum()) { float m = conf.momentum(); SetMomentumGenerator([m](int step) { return m; }); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/src/proto/model.proto ---------------------------------------------------------------------- diff --git a/src/proto/model.proto b/src/proto/model.proto index c06deec..b1318d9 100644 --- a/src/proto/model.proto +++ b/src/proto/model.proto @@ -89,6 +89,11 @@ message OptimizerConf { // delta is used to avoid dividing zero optional float delta = 6 [default = 1e-8]; + + // global regularizer lower priority than ParamSpec regularizer + optional RegularizerConf regularizer = 10; + // global constraint lower priority than ParamSpec constraint + optional ConstraintConf constraint = 11; } message ConstraintConf { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/test/singa/test_cross_entropy.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_cross_entropy.cc b/test/singa/test_cross_entropy.cc index d73591f..c7fa2fb 100644 --- a/test/singa/test_cross_entropy.cc +++ b/test/singa/test_cross_entropy.cc @@ -44,7 +44,7 @@ TEST_F(TestSoftmaxCrossEntropy, CppForward) { t.CopyDataFromHostPtr(tdat, 2); singa::SoftmaxCrossEntropy cross_entropy; - const Tensor& loss = cross_entropy.Forward(p, t); + const Tensor& loss = cross_entropy.Forward(singa::kEval, p, t); auto ldat = loss.data<float>(); const float result_test = -log(0.25); @@ -58,7 +58,7 @@ TEST_F(TestSoftmaxCrossEntropy, CppBackward) { t.CopyDataFromHostPtr(tdat, 2); singa::SoftmaxCrossEntropy cross_entropy; - cross_entropy.Forward(p, t); + cross_entropy.Forward(singa::kTrain, p, t); const Tensor& grad = cross_entropy.Backward(); auto gdat = grad.data<float>(); @@ -82,7 +82,7 @@ TEST_F(TestSoftmaxCrossEntropy, CudaForward) { p.CopyDataFromHostPtr(pdat, 8); t.CopyDataFromHostPtr(tdat, 2); - Tensor loss = cross_entropy.Forward(p, t); + Tensor loss = cross_entropy.Forward(singa::kEval, p, t); loss.ToHost(); auto ldat = loss.data<float>(); @@ -99,7 +99,7 @@ TEST_F(TestSoftmaxCrossEntropy, CudaBackward) { p.CopyDataFromHostPtr(pdat, 8); t.CopyDataFromHostPtr(tdat, 2); - 
cross_entropy.Forward(p, t); + cross_entropy.Forward(singa::kTrain, p, t); Tensor grad = cross_entropy.Backward(); grad.ToHost(); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/test/singa/test_dense.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_dense.cc b/test/singa/test_dense.cc index 363fb6e..e80384f 100644 --- a/test/singa/test_dense.cc +++ b/test/singa/test_dense.cc @@ -207,7 +207,7 @@ TEST(Dense, BackwardCuda) { singa::Tensor grad(singa::Shape{batchsize, hdim}, cuda); grad.CopyDataFromHostPtr(dy, batchsize * hdim); - const auto ret = dense.Backward(singa::kTrain, grad); + auto ret = dense.Backward(singa::kTrain, grad); singa::Tensor in_grad = ret.first; singa::Tensor dweight = ret.second.at(0); singa::Tensor dbias = ret.second.at(1); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/test/singa/test_mse.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_mse.cc b/test/singa/test_mse.cc index 788652f..640caf4 100644 --- a/test/singa/test_mse.cc +++ b/test/singa/test_mse.cc @@ -42,7 +42,7 @@ class TestMSE : public ::testing::Test { #ifdef USE_CBLAS TEST_F(TestMSE, CppForward) { singa::MSE mse; - const Tensor& loss = mse.Forward(p, t); + const Tensor& loss = mse.Forward(singa::kEval, p, t); auto ldat = loss.data<float>(); for (size_t i = 0, k = 0; i < loss.Size(); i++) { @@ -57,7 +57,7 @@ TEST_F(TestMSE, CppForward) { TEST_F(TestMSE, CppBackward) { singa::MSE mse; - mse.Forward(p, t); + mse.Forward(singa::kTrain, p, t); const Tensor& grad = mse.Backward(); auto gdat = grad.data<float>(); @@ -72,7 +72,7 @@ TEST_F(TestMSE, CudaForward) { auto dev = std::make_shared<singa::CudaGPU>(); p.ToDevice(dev); t.ToDevice(dev); - Tensor loss = mse->Forward(p, t); + Tensor loss = mse->Forward(singa::kEval, p, t); loss.ToHost(); auto ldat = loss.data<float>(); @@ -94,7 +94,7 @@ TEST_F(TestMSE, CudaBackward) { auto dev = std::make_shared<singa::CudaGPU>(); p.ToDevice(dev); t.ToDevice(dev); - mse.Forward(p, t); + mse.Forward(singa::kTrain, p, t); Tensor grad = mse.Backward(); grad.ToHost(); auto gdat = grad.data<float>(); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/71eb059c/test/singa/test_tensor_math.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc index f8d0351..2a0df0d 100644 --- a/test/singa/test_tensor_math.cc +++ b/test/singa/test_tensor_math.cc @@ -346,7 +346,7 @@ TEST_F(TestTensorMath, L2Cpp) { float l2 = a.L2(); float target = 0.0f; for (size_t i = 0; i < a.Size(); i++) target += dat1[i] * dat1[i]; - EXPECT_FLOAT_EQ(l2, sqrt(target)); + EXPECT_FLOAT_EQ(l2, sqrt(target) / a.Size()); } TEST_F(TestTensorMath, MultCpp) { const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -514,7 +514,7 @@ TEST_F(TestTensorMath, L2Cuda) { float l2 = t.L2(); float target = 0.0f; for (size_t i = 0; i < t.Size(); i++) target += dat1[i] * dat1[i]; - EXPECT_FLOAT_EQ(l2, sqrt(target)); + EXPECT_FLOAT_EQ(l2, sqrt(target) / t.Size()); } TEST_F(TestTensorMath, MultCuda) { const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
