SINGA-379 Implement batchnorm operation and its related functions for autograd
Change the API (arguments) of batchnorm functions.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/ce1a7335
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/ce1a7335
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/ce1a7335

Branch: refs/heads/master
Commit: ce1a73359a6e3eb2c3e7ec5ac861ac1829144dad
Parents: 8654f89
Author: Wang Wei <[email protected]>
Authored: Thu Jul 12 00:31:40 2018 +0800
Committer: wang wei <[email protected]>
Committed: Thu Jul 12 12:32:50 2018 +0800

----------------------------------------------------------------------
 python/singa/autograd.py         |  42 +++++-----
 src/api/model_operation.i        |  38 ++++-----
 src/model/operation/batchnorm.cc | 147 +++++++++++++++-------------------
 src/model/operation/batchnorm.h  |  26 +++---
 4 files changed, 118 insertions(+), 135 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 4ba0b11..3a2eddd 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -33,7 +33,6 @@ CTensor = singa.Tensor
 
 training = False
 
-
 def infer_dependency(op):
     '''
     Infer the dependency of all operations with the
@@ -483,6 +482,7 @@ def cross_entropy(y, t):
 
 
 class SoftMaxCrossEntropy(Operation):
+
     def forward(self, x, t):
         self.p = singa.SoftMax(x)
         self.t = t
@@ -771,7 +771,7 @@ class Conv2D(Layer):
         return y
 
 
-class BatchNorm2d(NewLayer):
+class BatchNorm2d(Layer):
 
     def __init__(self, num_features, momentum=0.9):
         self.channels = num_features
@@ -787,16 +787,16 @@ class BatchNorm2d(NewLayer):
             requires_grad=True, stores_grad=True)
         self.bias.set_value(0.0)
 
-        self.runningmean = Tensor(
+        self.running_mean = Tensor(
             shape=param_shape, requires_grad=False, stores_grad=False)
-        self.runningvariance = Tensor(
+        self.running_var = Tensor(
             shape=param_shape, requires_grad=False, stores_grad=False)
 
     def __call__(self, x):
         assert x.shape[1] == self.channels, 'number of channels dismatched.'
         self.device_check(x, self.scale, self.bias,
-                          self.runningmean, self.runningvariance)
+                          self.running_mean, self.running_var)
         if x.device.id() == -1:
             raise NotImplementedError
@@ -804,39 +804,40 @@ class BatchNorm2d(NewLayer):
         else:
             if not hasattr(self, 'handle'):
                 self.handle = singa.CudnnBatchNormHandle(
-                    self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                    self.momentum, x.data)
             elif x.shape[0] != self.handle.batchsize:
                 self.handle = singa.CudnnBatchNormHandle(
-                    self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                    self.momentum, x.data)
         self.handle.device_id = x.device.id()
 
-        y = batchnorm2d(x, self.scale, self.bias, self.handle)
+        y = batchnorm2d(x, self.scale, self.bias,
+                        self.running_mean, self.running_var, self.handle)
         return y
 
 
 class _BatchNorm2d(Operation):
-    def __init__(self, handle):
+    def __init__(self, running_mean, running_var, handle):
+        self.running_mean = running_mean.data
+        self.running_var = running_var.data
         self.handle = handle
 
     def forward(self, x, scale, bias):
         if training:
-            resultmean = CTensor([scale.shape(0)])
-            resultvariance = CTensor([scale.shape(0)])
-            self.cache = (x, resultmean, resultvariance, scale)
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
-                resultmean.ToDevice(x.device())
-                resultvariance.ToDevice(x.device())
-                return singa.GpuBatchNormForwardTraining(x, scale, bias, self.cache, self.handle)
-
+                y, mean, var = singa.GpuBatchNormForwardTraining(self.handle,
+                    x, scale, bias, self.running_mean, self.running_var)
+                self.cache = (x, scale, mean, var)
         else:
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
-                return singa.GpuBatchNormForwardInference(x, scale, bias, self.handle)
+                y, _, _ = singa.GpuBatchNormForwardInference(
+                    self.handle, x, scale, bias, self.running_mean, self.running_var)
+        return y
 
     def backward(self, dy):
         assert training is True and hasattr(
@@ -848,10 +849,11 @@ class _BatchNorm2d(Operation):
         if self.handle.device_id == -1:
             raise NotImplementedError
         else:
+            x, scale, mean, var = self.cache
             dx, ds, db = singa.GpuBatchNormBackward(
-                dy, self.cache, self.handle)
+                self.handle, dy, x, scale, mean, var)
             return dx, ds, db
 
 
-def batchnorm2d(x, scale, bias, handle):
-    return _BatchNorm2d(handle)(x, scale, bias)[0]
+def batchnorm2d(x, scale, bias, running_mean, running_var, handle):
+    return _BatchNorm2d(running_mean, running_var, handle)(x, scale, bias)[0]
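
For context on how the reworked Python layer is meant to be driven, here is a rough usage
sketch (not part of this commit); it assumes SINGA built with cuDNN and an available CUDA
device, and the input shape and momentum value are purely illustrative:

    from singa import autograd, device, tensor

    dev = device.create_cuda_gpu()                      # assumes a CUDA-enabled build
    x = tensor.Tensor(shape=(4, 3, 6, 6), device=dev)   # NCHW input with 3 channels
    x.gaussian(0.0, 1.0)

    bn = autograd.BatchNorm2d(3, momentum=0.9)

    autograd.training = True    # training path: batch statistics are computed and
    y_train = bn(x)             # bn.running_mean / bn.running_var are updated in place

    autograd.training = False   # inference path: the stored running statistics are used
    y_eval = bn(x)
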
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index a1d59ed..6f2d1fa 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -27,6 +27,17 @@ Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 
 Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch);
 
+
+
+class BatchNormHandle{
+ public:
+  BatchNormHandle(const float momentum, const Tensor& input);
+
+  size_t batchsize;
+};
+
+
+
 #if USE_CUDNN
 class CudnnConvHandle: public ConvHandle {
  public:
@@ -47,36 +58,25 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 
 Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle &cch);
 
-#endif // USE_CUDNN
-class BatchNormHandle{
- public:
-  BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
-
-  size_t batchsize;
-  Tensor runningMean;
-  Tensor runningVariance;
-
-};
 
 class CudnnBatchNormHandle: public BatchNormHandle{
  public:
-  CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
+  CudnnBatchNormHandle(const float momentum, const Tensor& input);
 
   size_t batchsize;
-  Tensor runningMean;
-  Tensor runningVariance;
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   const std::vector<Tensor>& cache, CudnnBatchNormHandle &cbnh);
+const vector<Tensor> GpuBatchNormForwardTraining(const CudnnBatchNormHandle &cbnh,
+    const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, Tensor& running_mean, Tensor& running_var);
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias, const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh, const Tensor& x,
+    const Tensor& bnScale, const Tensor& bnBias, const Tensor& running_mean, const Tensor& running_var);
 
-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy,
-                                         const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
+    const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean, const Tensor& var);
 
-
+#endif // USE_CUDNN
 }


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
index 6b2421d..7040895 100755
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -2,8 +2,7 @@
 
 namespace singa {
 
-BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean,
-                                 const Tensor& RunningVariance) {
+BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input) {
   factor = momentum;
   batchsize = input.shape(0);
   channels = input.shape(1);
@@ -18,12 +17,11 @@ BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, cons
   } else {
     LOG(FATAL) << "The dimension of input should either be 4D or 2D.";
   }
-  runningMean = RunningMean;
-  runningVariance = RunningVariance;
 };
 
-CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean,
-    const Tensor& RunningVariance): BatchNormHandle(momentum, input, RunningMean, RunningVariance) {
+#if USE_CUDNN
+CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum,
+    const Tensor& input): BatchNormHandle(momentum, input) {
   if (is_2d)
     mode = CUDNN_BATCHNORM_PER_ACTIVATION;
   else
@@ -40,85 +38,77 @@ CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& i
                                          1, 1));
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+Tensor GpuBatchNormForwardTraining(const CudnnBatchNormHandle &cbnh,
+                                   const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+                                   Tensor& running_mean, Tensor& running_var) {
   CHECK_EQ(x.device()->lang(), kCuda);
   CHECK_EQ(bnScale.device()->lang(), kCuda);
   CHECK_EQ(bnBias.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningMean.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningVariance.device()->lang(), kCuda);
-  CHECK_EQ(cache[1].device()->lang(), kCuda);  //resultmean
-  CHECK_EQ(cache[2].device()->lang(), kCuda);  //resultvariance
+  CHECK_EQ(runningMean.device()->lang(), kCuda);
+  CHECK_EQ(runningVariance.device()->lang(), kCuda);
+
+  Tensor mean, var;
+  mean.ResetLike(running_mean);
+  var.ResetLike(running_var);
 
   Shape shape = x.shape();
-  Tensor output;
-  Tensor input;  //for unification of 2d and 4d cases.
+
+  Tensor input = x;  //for unification of 2d and 4d cases.
   if (cbnh.is_2d)
-    input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
-  else
-    input = x;
+    input.Reshape(Shape{shape.at(0), shape.at(1), 1, 1});
+
+  Tensor output;
   output.ResetLike(x);
   output.device()->Exec(
-  [&output, &input, &bnScale, &bnBias, &cache, &cbnh](Context * ctx) {
-    Block* inBlock = input.block(), * outBlock = output.block(),
-           * saveMeanBlock = cache[1].block(),
-           * saveVarBlock = cache[2].block(),
-           * runningMeanBlock = cbnh.runningMean.block(),
-           * runningVarBlock = cbnh.runningVariance.block(),
-           * bnScaleBlock = bnScale.block(),
-           * bnBiasBlock = bnBias.block();
+  [&](Context * ctx) {
    const float alpha = 1.0f, beta = 0.0f;
    double epsilon = CUDNN_BN_MIN_EPSILON;
    CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
                  ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
-                  inBlock->data(), cbnh.shape_desc, outBlock->mutable_data(),
-                  cbnh.param_desc, bnScaleBlock->data(), bnBiasBlock->data(), cbnh.factor,
-                  runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
-                  epsilon, saveMeanBlock->mutable_data(),
-                  saveVarBlock->mutable_data()));
+                  input.block()->data(), cbnh.shape_desc, output.block()->mutable_data(),
+                  cbnh.param_desc, bnScale.block()->data(), bnBias.block()->data(), cbnh.factor,
+                  running_mean.block()->mutable_data(), running_var.block()->mutable_data(),
+                  epsilon, mean.block()->mutable_data(),
+                  var.block()->mutable_data()));
  },
-  {input.block(), bnScale.block(), bnBias.block()},
-  { output.block(), cbnh.runningMean.block(), cbnh.runningVariance.block(),
-    cache[1].block(), cache[2].block()
+  {input.block(), bnScale.block(), bnBias.block(), running_mean.block(), running_var.block()}, {
+    output.block(), running_mean.block(), running_var.block(),
+    mean.block(), var.block()
  });
  if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
-  return output;
+  return {output, mean, var};
};
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                    const CudnnBatchNormHandle &cbnh) {
+Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh,
+                                    const Tensor& x, const Tensor& bnScale,
+                                    const Tensor& bnBias, const Tensor& running_mean, const Tensor& running_var) {
  CHECK_EQ(x.device()->lang(), kCuda);
  CHECK_EQ(bnScale.device()->lang(), kCuda);
  CHECK_EQ(bnBias.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningMean.device()->lang(), kCuda);
-  CHECK_EQ(cbnh.runningVariance.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.running_mean.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.running_variance.device()->lang(), kCuda);
 
  Shape shape = x.shape();
-  Tensor output;
-  Tensor input;  //for unification of 2d and 4d cases.
+
+  Tensor input = x;  //for unification of 2d and 4d cases.
  if (cbnh.is_2d)
-    input = Reshape(x, Shape{shape.at(0), shape.at(1), 1, 1});
-  else
-    input = x;
+    input.Reshape(Shape{shape.at(0), shape.at(1), 1, 1});
+
+  Tensor output;
  output.ResetLike(x);
  output.device()->Exec(
-  [&output, &input, &bnScale, &bnBias, &cbnh](Context * ctx) {
-    Block* inBlock = input.block(), * outBlock = output.block(),
-           * runningMeanBlock = cbnh.runningMean.block(),
-           * runningVarBlock = cbnh.runningVariance.block(),
-           * bnScaleBlock = bnScale.block(),
-           * bnBiasBlock = bnBias.block();
+  [&](Context * ctx) {
    const float alpha = 1.0f, beta = 0.0f;
    double epsilon = CUDNN_BN_MIN_EPSILON;
    CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
                  ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
-                  inBlock->data(), cbnh.shape_desc, outBlock->mutable_data(),
-                  cbnh.param_desc, bnScaleBlock->data(), bnBiasBlock->data(),
-                  runningMeanBlock->data(), runningVarBlock->data(), epsilon));
-  },
-  { input.block(), bnScale.block(), bnBias.block(), cbnh.runningMean.block(),
-    cbnh.runningVariance.block()
+                  input.block()->data(), cbnh.shape_desc, output.block()->mutable_data(),
+                  cbnh.param_desc, bnScale.block()->data(), bnBias.block()->data(),
+                  running_mean.block()->data(), running_var.block()->data(), epsilon));
+  }, {
+    input.block(), bnScale.block(), bnBias.block(), running_mean.block(),
+    running_variance.block()
  }, {output.block()});
  if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
@@ -126,52 +116,43 @@ Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, cons
 };
 
 
-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
+    const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean,
+    const Tensor& var) {
  CHECK_EQ(dy.device()->lang(), kCuda);
-  CHECK_EQ(cache[0].device()->lang(), kCuda);
-  CHECK_EQ(cache[1].device()->lang(), kCuda);
-  CHECK_EQ(cache[2].device()->lang(), kCuda);
-  CHECK_EQ(cache[3].device()->lang(), kCuda);
+  CHECK_EQ(x.device()->lang(), kCuda);
+  CHECK_EQ(bnScale.device()->lang(), kCuda);
+  CHECK_EQ(mean.device()->lang(), kCuda);
+  CHECK_EQ(var.device()->lang(), kCuda);
 
  vector<Tensor> out_grads;
  Tensor dx;
  dx.ResetLike(dy);
 
  Tensor dbnScale;
-  dbnScale.ResetLike(cache[3]);
+  dbnScale.ResetLike(bnScale);
 
  Tensor dbnBias;
-  dbnBias.ResetLike(cache[3]);
-  //dbnBias.ResetLike(bnBias);
+  dbnBias.ResetLike(bnScale);
 
  dx.device()->Exec(
-  [&dx, &dbnScale, &dbnBias, &dy, &cache, &cbnh](Context * ctx) {
-    Block* dyblock = dy.block(), * dxblock = dx.block(),
-           * xblock = cache[0].block(), * bnScaleBlock = cache[3].block(),
-           * dbnScaleBlock = dbnScale.block(),
-           * dbnBiasBlock = dbnBias.block(),
-           * saveMeanBlock = cache[1].block(),
-           * saveVarBlock = cache[2].block();
+  [&](Context * ctx) {
+
    const float alpha = 1.0f, beta = .0f;
    double epsilon = CUDNN_BN_MIN_EPSILON;
    CUDNN_CHECK(cudnnBatchNormalizationBackward(
                  ctx->cudnn_handle, cbnh.mode, &alpha, &beta, &alpha, &beta,
-                  cbnh.shape_desc, xblock->data(), cbnh.shape_desc, dyblock->data(),
-                  cbnh.shape_desc, dxblock->mutable_data(), cbnh.param_desc,
-                  bnScaleBlock->data(), dbnScaleBlock->mutable_data(),
-                  dbnBiasBlock->mutable_data(), epsilon, saveMeanBlock->data(),
-                  saveVarBlock->data()));
-  },
-  { cache[0].block(), dy.block(), cache[3].block(), cache[1].block(),
-    cache[2].block()
-  },
+                  cbnh.shape_desc, x.block()->data(), cbnh.shape_desc, dy.block()->data(),
+                  cbnh.shape_desc, dx.block()->mutable_data(), cbnh.param_desc,
+                  bnScale.block()->data(), dbnScale.block()->mutable_data(),
+                  dbnBias.block()->mutable_data(), epsilon, mean.block()->data(),
+                  var.block()->data()));
+  }, {x.block(), dy.block(), bnScale.block(), mean.block(), var.block()},
  {dx.block(), dbnScale.block(), dbnBias.block()});
 
  if (cbnh.is_2d) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
-  out_grads.push_back(dx);
-  out_grads.push_back(dbnScale);
-  out_grads.push_back(dbnBias);
-  return out_grads;
+
+  return {dx, dbnScale, dbnBias};
 };
 
 }


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ce1a7335/src/model/operation/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.h b/src/model/operation/batchnorm.h
index ee182f9..f4372e3 100755
--- a/src/model/operation/batchnorm.h
+++ b/src/model/operation/batchnorm.h
@@ -12,8 +12,8 @@ namespace singa {
 
 class BatchNormHandle {
-public:
-  BatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
+ public:
+  BatchNormHandle(const float momentum, const Tensor& input);
 
   float factor;
 
@@ -21,10 +21,6 @@ public:
   size_t channels;
   size_t height;
   size_t width;
-
-  Tensor runningMean;
-  Tensor runningVariance;
-
   bool is_2d;
   //bool train = true;
 };
@@ -39,8 +35,8 @@ public:
 
 #ifdef USE_CUDNN
 class CudnnBatchNormHandle: public BatchNormHandle {
-public:
-  CudnnBatchNormHandle(const float momentum, const Tensor& input, const Tensor& RunningMean, const Tensor& RunningVariance);
+ public:
+  CudnnBatchNormHandle(const float momentum, const Tensor& input);
 
   //~CudnnBatchNormHandle();
 
@@ -49,13 +45,17 @@ public:
   cudnnTensorDescriptor_t param_desc = nullptr;
 };
 
-Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+const std::vector<Tensor> GpuBatchNormForwardTraining(const CudnnBatchNormHandle
+    &cbnh, const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+    Tensor& running_mean, Tensor& running_var);
 
-Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                    const CudnnBatchNormHandle &cbnh);
+Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh,
+    const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
+    const Tensor& running_mean, const Tensor& running_var);
 
-std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
+    const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean,
+    const Tensor& var);
 
 #endif  // USE_CUDNN
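
A note on the new training contract: GpuBatchNormForwardTraining now returns {output, mean, var}
and folds the batch statistics into running_mean / running_var in place through cuDNN's
exponential-average factor (cbnh.factor, i.e. the momentum handed to the handle), while the
Python operation caches (x, scale, mean, var) for GpuBatchNormBackward. A rough NumPy analogue
of that data flow, for illustration only (cuDNN applies its own epsilon and keeps its saved
statistics in its own internal form):

    import numpy as np

    def batchnorm2d_train_sketch(x, gamma, beta, running_mean, running_var,
                                 factor=0.9, eps=1e-5):
        # x: (N, C, H, W); gamma, beta, running_mean, running_var: (C,)
        batch_mean = x.mean(axis=(0, 2, 3))
        batch_var = x.var(axis=(0, 2, 3))
        # exponential-average update: the new batch statistic is weighted by `factor`
        running_mean *= (1.0 - factor)
        running_mean += factor * batch_mean
        running_var *= (1.0 - factor)
        running_var += factor * batch_var
        # normalize with the batch statistics, then scale and shift
        x_hat = (x - batch_mean[None, :, None, None]) / np.sqrt(
            batch_var[None, :, None, None] + eps)
        y = gamma[None, :, None, None] * x_hat + beta[None, :, None, None]
        # mirrors the new return value {output, mean, var}; mean/var are what
        # the backward pass consumes
        return y, batch_mean, batch_var
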
