Repository: incubator-singa
Updated Branches:
  refs/heads/master b30d7ea55 -> f134a24e2
SINGA-379 Implement batchnorm operation and its related functions for autograd

- format existing code and rename some variables
- fix some make errors

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/10274f3b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/10274f3b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/10274f3b

Branch: refs/heads/master
Commit: 10274f3bf82106595305c58644c69353f4b414a8
Parents: a105b24
Author: xuewanqi <xue_wa...@outlook.com>
Authored: Wed Jul 11 02:59:35 2018 +0000
Committer: Wang Wei <wangwei...@gmail.com>
Committed: Wed Jul 11 21:57:47 2018 +0800

----------------------------------------------------------------------
 src/api/model_operation.i        |  16 +--
 src/model/operation/batchnorm.cc | 246 +++++++++++++++++-----------------
 src/model/operation/batchnorm.h  |  48 +++----
 3 files changed, 158 insertions(+), 152 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/10274f3b/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 783a1f8..95efd26 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -51,28 +51,28 @@ Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle

 class BatchNormHandle{
  public:
-  BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+  BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);

   size_t batchsize;
-  Tensor runningMean_;
-  Tensor runningVariance_;
+  Tensor runningMean;
+  Tensor runningVariance;
 };

 class CudnnBatchNormHandle: public BatchNormHandle{
  public:
-  CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+  CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);

   size_t batchsize;
-  Tensor runningMean_;
-  Tensor runningVariance_;
+  Tensor runningMean;
+  Tensor runningVariance;
 };

-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
                                    std::vector<Tensor>& cache, CudnnBatchNormHandle &cbnh);

-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_, const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, const CudnnBatchNormHandle &cbnh);

 std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
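Two bugs fixed by the batchnorm.cc hunk below are worth calling out: the old constructor read the channel count from input.shape()[2] and height/width from indices 3 and 4, which is out of range for an NCHW tensor, and its else-branch declared fresh locals named height_, width_ and is_2d_ that shadowed the members instead of assigning them. A minimal standalone sketch of the corrected bookkeeping, with std::vector<size_t> standing in for singa::Shape (the struct and function names here are illustrative, not part of the patch):

#include <cstddef>
#include <stdexcept>
#include <vector>

// Shape metadata kept by BatchNormHandle; member names mirror the patch.
struct BnShape {
  size_t batchsize, channels, height, width;
  bool is_2d;
};

// `dims` stands in for input.shape(): NCHW for 4D input, NC for 2D input.
BnShape InferBnShape(const std::vector<size_t>& dims) {
  BnShape s{};
  s.batchsize = dims.at(0);        // N
  s.channels  = dims.at(1);        // C
  if (dims.size() == 4u) {         // convolutional feature maps
    s.height = dims.at(2);         // H
    s.width  = dims.at(3);         // W
    s.is_2d  = false;
  } else if (dims.size() == 2u) {  // fully-connected activations
    s.height = 1;
    s.width  = 1;
    s.is_2d  = true;
  } else {
    throw std::invalid_argument(   // the patch uses LOG(FATAL) here
        "The dimension of input should either be 4D or 2D.");
  }
  return s;
}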
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/10274f3b/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
old mode 100644
new mode 100755
index 9b6f9cd..145b90b
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -1,164 +1,170 @@
 #include "./batchnorm.h"

-namespace singa{
-
-BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_,
-                                 const Tensor& RunningVariance_){
-  factor_ = momentum;
-  batchsize = input.shape()[0];
-  channels_= input.shape()[2];
-  if (input.nDim()== 4u){
-    height_= input.shape()[3];
-    width_=input.shape()[4];
-    is_2d_= false;
-  }else{
-    size_t height_ = 1;
-    size_t width_ = 1;
-    bool is_2d_ = true;
+namespace singa {
+
+BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean,
+                                 const Tensor& RunningVariance) {
+  factor = momentum;
+  batchsize = input.shape(0);
+  channels = input.shape(1);
+  if (input.nDim() == 4u) {
+    height = input.shape(2);
+    width = input.shape(3);
+    is_2d = false;
+  } else if (input.nDim() == 2u) {
+    height = 1;
+    width = 1;
+    is_2d = true;
+  } else {
+    LOG(FATAL) << "The dimension of input should either be 4D or 2D.";
   }
-  runningMean_= RunningMean_;
-  runningVariance_= RunningVariance_;
+  runningMean = RunningMean;
+  runningVariance = RunningVariance;
 };

-CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_,
-                                           const Tensor& RunningVariance_):BatchNormHandle(momentum, input, RunningMean_, RunningVariance_){
-  if (is_2d_)
-    mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
+CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean,
+    const Tensor& RunningVariance): BatchNormHandle(momentum, input, RunningMean, RunningVariance) {
+  if (is_2d)
+    mode = CUDNN_BATCHNORM_PER_ACTIVATION;
   else
-    mode_ = CUDNN_BATCHNORM_SPATIAL;
-  auto dtype = input.data_type();
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc_));
-  CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_, CUDNN_TENSOR_NCHW,
-                                         GetCudnnDataType(dtype), batchsize,
-                                         channels_, height_, width_));
-  CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc_, CUDNN_TENSOR_NCHW,
-                                         GetCudnnDataType(dtype), 1, channels_,
+    mode = CUDNN_BATCHNORM_SPATIAL;
+  DataType dtype = input.data_type();
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc));
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc, CUDNN_TENSOR_NCHW,
+                                         GetCudnnDataType(dtype),
+                                         batchsize,
+                                         channels, height, width));
+  CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc, CUDNN_TENSOR_NCHW,
                                          GetCudnnDataType(dtype), 1, channels,
                                          1, 1));
-  };
+};
+
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {

-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
-                                   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
-
-  auto shape = x.shape();
+  Shape shape = x.shape();

   Tensor output;
   Tensor input;  //for unification of 2d and 4d cases.
-  if (cbnh.is_2d_)
+  if (cbnh.is_2d)
     input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
   else
     input = x;

   output.ResetLike(x);

-  Tensor resultSaveMean_;
-  Tensor resultSaveVariance_;
+  Tensor resultSaveMean;
+  Tensor resultSaveVariance;

-  resultSaveMean_.Reshape(Shape{cbnh.channels_});
-  resultSaveVariance_.Reshape(Shape{cbnh.channels_});
+  resultSaveMean.Reshape(Shape{cbnh.channels});
+  resultSaveVariance.Reshape(Shape{cbnh.channels});

-  cache.push_back(resultSaveMean_);
-  cache.push_back(resultSaveVariance_);
-  cache.push_back(bnScale_);
+  cache.push_back(resultSaveMean);
+  cache.push_back(resultSaveVariance);
+  cache.push_back(bnScale);
   //cache={x, mean, var, scale}

-  output.device()->Exec(
-      [&output, &input, &bnScale_, &bnBias_, &cache, &cbnh](Context* ctx) {
-        Block* inBlock = input.block(), * outBlock = output.block(),
-               * saveMeanBlock = cache[1].block(),
-               * saveVarBlock = cache[2].block(),
-               * runningMeanBlock = cbnh.runningMean_.block(),
-               * runningVarBlock = cbnh.runningVariance_.block(),
-               * bnScaleBlock = bnScale_.block(),
-               * bnBiasBlock = bnBias_.block();
-        const float alpha = 1.0f, beta = 0.0f;
-        double epsilon = CUDNN_BN_MIN_EPSILON;
-        CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
-            ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, cbnh.shape_desc_,
-            inBlock->data(), cbnh.shape_desc_, outBlock->mutable_data(),
-            cbnh.param_desc_, bnScaleBlock->data(), bnBiasBlock->data(), cbnh.factor_,
-            runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
-            epsilon, saveMeanBlock->mutable_data(),
-            saveVarBlock->mutable_data()));
-      },
-      {input.block(), bnScale_.block(), bnBias_.block()},
-      {output.block(), cbnh.runningMean_.block(), cbnh.runningVariance_.block(),
-       cache[1].block(), cache[2].block()});
-  if (cbnh.is_2d_) output.Reshape(Shape{shape.at(0), shape.at(1)});
+  output.device()->Exec(
+  [&output, &input, &bnScale, &bnBias, &cache, &cbnh](Context * ctx) {
+    Block* inBlock = input.block(), * outBlock = output.block(),
+           * saveMeanBlock = cache[1].block(),
+           * saveVarBlock = cache[2].block(),
+           * runningMeanBlock = cbnh.runningMean.block(),
+           * runningVarBlock = cbnh.runningVariance.block(),
+           * bnScaleBlock = bnScale.block(),
+           * bnBiasBlock = bnBias.block();
+    const float alpha = 1.0f, beta = 0.0f;
+    double epsilon = CUDNN_BN_MIN_EPSILON;
+    CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
+                  ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
+                  inBlock->data(), cbnh.shape_desc, outBlock->mutable_data(),
+                  cbnh.param_desc, bnScaleBlock->data(), bnBiasBlock->data(), cbnh.factor,
+                  runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
+                  epsilon, saveMeanBlock->mutable_data(),
+                  saveVarBlock->mutable_data()));
+  },
+  {input.block(), bnScale.block(), bnBias.block()},
+  { output.block(), cbnh.runningMean.block(), cbnh.runningVariance.block(),
+    cache[1].block(), cache[2].block()
+  });
+  if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
   return output;
 };

-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
-                                    const CudnnBatchNormHandle &cbnh) {
-  auto shape = x.shape();
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                    const CudnnBatchNormHandle &cbnh) {
+  Shape shape = x.shape();
   Tensor output;
   Tensor input;  //for unification of 2d and 4d cases.
-  if (cbnh.is_2d_)
+  if (cbnh.is_2d)
     input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
   else
     input = x;
   output.ResetLike(x);
-  output.device()->Exec(
-      [&output, &input, &bnScale_, &bnBias_, &cbnh](Context* ctx) {
-        Block* inBlock = input.block(), * outBlock = output.block(),
-               * runningMeanBlock = cbnh.runningMean_.block(),
-               * runningVarBlock = cbnh.runningVariance_.block(),
-               * bnScaleBlock = bnScale_.block(),
-               * bnBiasBlock = bnBias_.block();
-        const float alpha = 1.0f, beta = 0.0f;
-        double epsilon = CUDNN_BN_MIN_EPSILON;
-        CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
-            ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, cbnh.shape_desc_,
-            inBlock->data(), cbnh.shape_desc_, outBlock->mutable_data(),
-            cbnh.param_desc_, bnScaleBlock->data(), bnBiasBlock->data(),
-            runningMeanBlock->data(), runningVarBlock->data(), epsilon));
-      },
-      {input.block(), bnScale_.block(), bnBias_.block(), cbnh.runningMean_.block(),
-       cbnh.runningVariance_.block()},
-      {output.block()});
-  if (cbnh.is_2d_) output.Reshape(Shape{shape.at(0), shape.at(1)});
+  output.device()->Exec(
+  [&output, &input, &bnScale, &bnBias, &cbnh](Context * ctx) {
+    Block* inBlock = input.block(), * outBlock = output.block(),
+           * runningMeanBlock = cbnh.runningMean.block(),
+           * runningVarBlock = cbnh.runningVariance.block(),
+           * bnScaleBlock = bnScale.block(),
+           * bnBiasBlock = bnBias.block();
+    const float alpha = 1.0f, beta = 0.0f;
+    double epsilon = CUDNN_BN_MIN_EPSILON;
+    CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
+                  ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
+                  inBlock->data(), cbnh.shape_desc, outBlock->mutable_data(),
+                  cbnh.param_desc, bnScaleBlock->data(), bnBiasBlock->data(),
+                  runningMeanBlock->data(), runningVarBlock->data(), epsilon));
+  },
+  { input.block(), bnScale.block(), bnBias.block(), cbnh.runningMean.block(),
+    cbnh.runningVariance.block()
+  },
+  {output.block()});
+  if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
   return output;
 };

-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh){
+std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {

   vector<Tensor> out_grads;
   Tensor dx;
   dx.ResetLike(dy);

-  Tensor dbnScale_;
-  dbnScale_.ResetLike(cache[3]);
+  Tensor dbnScale;
+  dbnScale.ResetLike(cache[3]);

-  Tensor dbnBias_;
-  dbnBias_.ResetLike(cache[3]);
-  //dbnBias_.ResetLike(bnBias_);
+  Tensor dbnBias;
+  dbnBias.ResetLike(cache[3]);
+  //dbnBias.ResetLike(bnBias);

   dx.device()->Exec(
-      [&dx, &dbnScale_, &dbnBias_, &dy, &cache, &cbnh](Context* ctx) {
-        Block* dyblock = dy.block(), * dxblock = dx.block(),
-               * xblock = cache[0].block(), * bnScaleBlock = cache[3].block(),
-               * dbnScaleBlock = dbnScale_.block(),
-               * dbnBiasBlock = dbnBias_.block(),
-               * saveMeanBlock = cache[1].block(),
-               * saveVarBlock = cache[2].block();
-        const float alpha = 1.0f, beta = .0f;
-        double epsilon = CUDNN_BN_MIN_EPSILON;
-        CUDNN_CHECK(cudnnBatchNormalizationBackward(
-            ctx->cudnn_handle, cbnh.mode_, &alpha, &beta, &alpha, &beta,
-            cbnh.shape_desc_, xblock->data(), cbnh.shape_desc_, dyblock->data(),
-            cbnh.shape_desc_, dxblock->mutable_data(), cbnh.param_desc_,
-            bnScaleBlock->data(), dbnScaleBlock->mutable_data(),
-            dbnBiasBlock->mutable_data(), epsilon, saveMeanBlock->data(),
-            saveVarBlock->data()));
-      },
-      {cache[0].block(), dy.block(), cache[3].block(), cache[1].block(),
-       cache[2].block()},
-      {dx.block(), dbnScale_.block(), dbnBias_.block()});
-
-  if (cbnh.is_2d_) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
+  [&dx, &dbnScale, &dbnBias, &dy, &cache, &cbnh](Context * ctx) {
+    Block* dyblock = dy.block(), * dxblock = dx.block(),
+           * xblock = cache[0].block(), * bnScaleBlock = cache[3].block(),
+           * dbnScaleBlock = dbnScale.block(),
+           * dbnBiasBlock = dbnBias.block(),
+           * saveMeanBlock = cache[1].block(),
+           * saveVarBlock = cache[2].block();
+    const float alpha = 1.0f, beta = .0f;
+    double epsilon = CUDNN_BN_MIN_EPSILON;
+    CUDNN_CHECK(cudnnBatchNormalizationBackward(
+                  ctx->cudnn_handle, cbnh.mode, &alpha, &beta, &alpha, &beta,
+                  cbnh.shape_desc, xblock->data(), cbnh.shape_desc, dyblock->data(),
+                  cbnh.shape_desc, dxblock->mutable_data(), cbnh.param_desc,
+                  bnScaleBlock->data(), dbnScaleBlock->mutable_data(),
+                  dbnBiasBlock->mutable_data(), epsilon, saveMeanBlock->data(),
+                  saveVarBlock->data()));
+  },
+  { cache[0].block(), dy.block(), cache[3].block(), cache[1].block(),
+    cache[2].block()
+  },
+  {dx.block(), dbnScale.block(), dbnBias.block()});
+
+  if (cbnh.is_2d) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
   out_grads.push_back(dx);
-  out_grads.push_back(dbnScale_);
-  out_grads.push_back(dbnBias_);
-return out_grads;
+  out_grads.push_back(dbnScale);
+  out_grads.push_back(dbnBias);
+  return out_grads;
 };
 }
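To make the calling contract of the functions above concrete, here is a hypothetical training step. The caller seeds the cache with x; GpuBatchNormForwardTraining then appends the saved mean, the saved variance, and the scale, producing the {x, mean, var, scale} layout that GpuBatchNormBackward indexes as cache[0]..cache[3]. The helper name BatchNormTrainStep, the momentum value, and the assumption that dy arrives from the surrounding autograd code are illustrative, not part of the patch:

#include <vector>
#include "./batchnorm.h"  // the header changed by this commit

using singa::CudnnBatchNormHandle;
using singa::Tensor;

// Hypothetical helper: one training step, returning {dx, dScale, dBias}.
std::vector<Tensor> BatchNormTrainStep(const Tensor& x, const Tensor& scale,
                                       const Tensor& bias, Tensor& running_mean,
                                       Tensor& running_var, const Tensor& dy) {
  // The handle records the input shape and wraps the running statistics;
  // 0.9f is an assumed momentum, not a value taken from the patch.
  CudnnBatchNormHandle cbnh(0.9f, x, running_mean, running_var);

  std::vector<Tensor> cache;
  cache.push_back(x);  // cache[0] = x; the forward call appends mean, var, scale

  Tensor y = singa::GpuBatchNormForwardTraining(x, scale, bias, cache, cbnh);
  (void)y;  // y would feed the next layer; dy is its gradient coming back

  return singa::GpuBatchNormBackward(dy, cache, cbnh);  // {dx, dScale, dBias}
}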
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/10274f3b/src/model/operation/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.h b/src/model/operation/batchnorm.h
old mode 100644
new mode 100755
index f2da4cd..f21bd1d
--- a/src/model/operation/batchnorm.h
+++ b/src/model/operation/batchnorm.h
@@ -9,24 +9,24 @@
 #include "../layer/cudnn_utils.h" // check_cudnn
 #endif  // USE_CUDNN

-namespace singa{
+namespace singa {

-class BatchNormHandle{
-  public:
-  BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+class BatchNormHandle {
+public:
+  BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);

-  float factor_;
-  size_t channels_;
-  size_t batchsize;
+  float factor;

-  Tensor runningMean_;
-  Tensor runningVariance_;
+  size_t batchsize;
+  size_t channels;
+  size_t height;
+  size_t width;

-  bool is_2d_ ;
-  //bool train = true;
+  Tensor runningMean;
+  Tensor runningVariance;

-  size_t height_;
-  size_t width_;
+  bool is_2d;
+  //bool train = true;
 };

 //Tensor CpuBatchNormForwardTraining();

@@ -38,22 +38,22 @@ class BatchNormHandle{

 #ifdef USE_CUDNN

-class CudnnBatchNormHandle: public BatchNormHandle{
-  public:
-  CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean_, const Tensor& RunningVariance_);
+class CudnnBatchNormHandle: public BatchNormHandle {
+public:
+  CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);

-  //~CudnnBatchNormHandle();
+  //~CudnnBatchNormHandle();

-  cudnnBatchNormMode_t mode_;
-  cudnnTensorDescriptor_t shape_desc_ = nullptr;
-  cudnnTensorDescriptor_t param_desc_ = nullptr;
+  cudnnBatchNormMode_t mode;
+  cudnnTensorDescriptor_t shape_desc = nullptr;
+  cudnnTensorDescriptor_t param_desc = nullptr;
 };

-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
-                                   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);

-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale_, const Tensor& bnBias_,
-                                    const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                    const CudnnBatchNormHandle &cbnh);

 std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
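At inference time there is no cache and no backward pass; the handle's running statistics are read directly. A correspondingly minimal, hypothetical wrapper, assuming the same names as in the sketch above:

#include "./batchnorm.h"

using singa::CudnnBatchNormHandle;
using singa::Tensor;

// Hypothetical helper: normalize with the stored running mean/variance.
Tensor BatchNormInfer(const Tensor& x, const Tensor& scale, const Tensor& bias,
                      const CudnnBatchNormHandle& cbnh) {
  return singa::GpuBatchNormForwardInference(x, scale, bias, cbnh);
}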