SINGA-379 Implement batchnorm operation and its related functions for autograd
- fixed some bugs.
- modified the design of the batchnorm operation.
- wrote a test file for the batchnorm layer (operation); the unit test passes.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/8654f894
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/8654f894
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/8654f894

Branch: refs/heads/master
Commit: 8654f8942a73d7bd86a0bf2e4b2a9f154b124d1e
Parents: 10274f3
Author: xuewanqi <[email protected]>
Authored: Wed Jul 11 06:40:07 2018 +0000
Committer: Wang Wei <[email protected]>
Committed: Wed Jul 11 21:57:48 2018 +0800

----------------------------------------------------------------------
 python/singa/autograd.py         | 42 ++++++++++++++++++++++++-----------
 src/api/model_operation.i        |  2 +-
 src/model/operation/batchnorm.cc | 35 +++++++++++++++++------------
 src/model/operation/batchnorm.h  |  2 +-
 test/python/test_operation.py    |  4 ++--
 5 files changed, 54 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8654f894/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 97a75b4..4ba0b11 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -770,35 +770,44 @@ class Conv2D(Layer):
         y = conv2d(x, self.W, self.b, self.handle)
         return y
 
+
 class BatchNorm2d(NewLayer):
-    def __init__(self, num_features, momentum = 0.9):
+
+    def __init__(self, num_features, momentum=0.9):
         self.channels = num_features
         self.momentum = momentum
 
         param_shape = (self.channels,)
 
-        self.scale = Tensor(shape=param_shape, requires_grad=True, stores_grad=True)
+        self.scale = Tensor(shape=param_shape,
+                            requires_grad=True, stores_grad=True)
         self.scale.set_value(1.0)
 
-        self.bias = Tensor(shape=param_shape, requires_grad=True, stores_grad=True)
+        self.bias = Tensor(shape=param_shape,
+                           requires_grad=True, stores_grad=True)
         self.bias.set_value(0.0)
 
-        self.runningmean = Tensor(shape=param_shape, requires_grad=False, stores_grad=False)
-        self.runningvariance = Tensor(shape=param_shape, requires_grad=False, stores_grad=False)
+        self.runningmean = Tensor(
+            shape=param_shape, requires_grad=False, stores_grad=False)
+        self.runningvariance = Tensor(
+            shape=param_shape, requires_grad=False, stores_grad=False)
 
     def __call__(self, x):
         assert x.shape[1] == self.channels, 'number of channels dismatched.'
-        self.device_check(x, self.scale, self.bias, self.runningmean,self.runningvariance)
+        self.device_check(x, self.scale, self.bias,
+                          self.runningmean, self.runningvariance)
 
         if x.device.id() == -1:
             raise NotImplementedError
         else:
             if not hasattr(self, 'handle'):
-                self.handle = singa.CudnnBatchNormHandle(self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                self.handle = singa.CudnnBatchNormHandle(
+                    self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
             elif x.shape[0] != self.handle.batchsize:
-                self.handle = singa.CudnnBatchNormHandle(self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
+                self.handle = singa.CudnnBatchNormHandle(
+                    self.momentum, x.data, self.runningmean.data, self.runningvariance.data)
         self.handle.device_id = x.device.id()
 
         y = batchnorm2d(x, self.scale, self.bias, self.handle)
@@ -806,26 +815,32 @@ class BatchNorm2d(NewLayer):
 
 
 class _BatchNorm2d(Operation):
-    def __init(self, handle):
+
+    def __init__(self, handle):
         self.handle = handle
 
     def forward(self, x, scale, bias):
         if training:
-            self.cache=(x,)
+            resultmean = CTensor([scale.shape(0)])
+            resultvariance = CTensor([scale.shape(0)])
+            self.cache = (x, resultmean, resultvariance, scale)
+
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
+                resultmean.ToDevice(x.device())
+                resultvariance.ToDevice(x.device())
                 return singa.GpuBatchNormForwardTraining(x, scale, bias, self.cache, self.handle)
 
         else:
             if self.handle.device_id == -1:
                 raise NotImplementedError
             else:
-                return singa.GpuBatchNormForwardInference(x, scale, bias ,self.handle)
+                return singa.GpuBatchNormForwardInference(x, scale, bias, self.handle)
 
     def backward(self, dy):
         assert training is True and hasattr(
-            self, 'cahce'), 'Please set training as True before do BP. '
+            self, 'cache'), 'Please set training as True before do BP. '
 
         if dy.device().id() != self.handle.device_id:
             dy.ToDevice(self.cache[0].device())
@@ -833,7 +848,8 @@ class _BatchNorm2d(Operation):
         if self.handle.device_id == -1:
             raise NotImplementedError
         else:
-            dx, ds, db = singa.GpuBatchNormBackward(dy, self.cache, self.handle)
+            dx, ds, db = singa.GpuBatchNormBackward(
+                dy, self.cache, self.handle)
 
         return dx, ds, db


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8654f894/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 95efd26..a1d59ed 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -70,7 +70,7 @@ class CudnnBatchNormHandle: public BatchNormHandle{
 };
 
 Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   std::vector<Tensor>& cache, CudnnBatchNormHandle &cbnh);
+                                   const std::vector<Tensor>& cache, CudnnBatchNormHandle &cbnh);
 
 Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
                                     const CudnnBatchNormHandle &cbnh);


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8654f894/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
index 145b90b..6b2421d 100755
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -8,8 +8,8 @@ BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input, cons
   batchsize = input.shape(0);
   channels = input.shape(1);
   if (input.nDim() == 4u) {
-    height = input.shape(2);
-    width = input.shape(3);
+    height = input.shape().at(2);
+    width = input.shape().at(3);
     is_2d = false;
   } else if (input.nDim() == 2u) {
     height = 1;
@@ -41,7 +41,14 @@ CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum, const Tensor& i
 };
 
 Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+                                   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+  CHECK_EQ(x.device()->lang(), kCuda);
+  CHECK_EQ(bnScale.device()->lang(), kCuda);
+  CHECK_EQ(bnBias.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.runningMean.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.runningVariance.device()->lang(), kCuda);
+  CHECK_EQ(cache[1].device()->lang(), kCuda);  //resultmean
+  CHECK_EQ(cache[2].device()->lang(), kCuda);  //resultvariance
 
   Shape shape = x.shape();
   Tensor output;
@@ -52,17 +59,6 @@ Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const
   input = x;
   output.ResetLike(x);
 
-  Tensor resultSaveMean;
-  Tensor resultSaveVariance;
-
-  resultSaveMean.Reshape(Shape{cbnh.channels});
-  resultSaveVariance.Reshape(Shape{cbnh.channels});
-
-  cache.push_back(resultSaveMean);
-  cache.push_back(resultSaveVariance);
-  cache.push_back(bnScale);
-  //cache={x, mean, var, scale}
-
   output.device()->Exec(
   [&output, &input, &bnScale, &bnBias, &cache, &cbnh](Context * ctx) {
     Block* inBlock = input.block(), * outBlock = output.block(),
@@ -92,6 +88,12 @@ Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const
 
 Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
                                     const CudnnBatchNormHandle &cbnh) {
+  CHECK_EQ(x.device()->lang(), kCuda);
+  CHECK_EQ(bnScale.device()->lang(), kCuda);
+  CHECK_EQ(bnBias.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.runningMean.device()->lang(), kCuda);
+  CHECK_EQ(cbnh.runningVariance.device()->lang(), kCuda);
+
   Shape shape = x.shape();
   Tensor output;
   Tensor input;  //for unification of 2d and 4d cases.
@@ -125,6 +127,11 @@ Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, cons
 
 std::vector<Tensor> GpuBatchNormBackward(const Tensor& dy, const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh) {
+  CHECK_EQ(dy.device()->lang(), kCuda);
+  CHECK_EQ(cache[0].device()->lang(), kCuda);
+  CHECK_EQ(cache[1].device()->lang(), kCuda);
+  CHECK_EQ(cache[2].device()->lang(), kCuda);
+  CHECK_EQ(cache[3].device()->lang(), kCuda);
 
   vector<Tensor> out_grads;
   Tensor dx;


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8654f894/src/model/operation/batchnorm.h
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.h b/src/model/operation/batchnorm.h
index f21bd1d..ee182f9 100755
--- a/src/model/operation/batchnorm.h
+++ b/src/model/operation/batchnorm.h
@@ -50,7 +50,7 @@ public:
 };
 
 Tensor GpuBatchNormForwardTraining(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
-                                   std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
+                                   const std::vector<Tensor>& cache, const CudnnBatchNormHandle &cbnh);
 
 Tensor GpuBatchNormForwardInference(const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
                                     const CudnnBatchNormHandle &cbnh);


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8654f894/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 0e851d7..67018c1 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -79,12 +79,12 @@ class TestPythonOperation(unittest.TestCase):
         dy = CTensor([2, 3, 3, 3])
         singa.Gaussian(0.0, 1.0, dy)
 
-        y=batchnorm_0(gpu_input_tensor)
+        y = batchnorm_0(gpu_input_tensor)
         dx, ds, db = y.creator.backward(dy)
 
         self.check_shape(y.shape, (2, 3, 3, 3))
         self.check_shape(dx.shape(), (2, 3, 3, 3))
-        self.check_shape(dx.shape(), (3,))
+        self.check_shape(ds.shape(), (3,))
         self.check_shape(db.shape(), (3,))
 
 if __name__ == '__main__':
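----------------------------------------------------------------------

For reference, the sketch below shows how the reworked BatchNorm2d layer and the underlying _BatchNorm2d operation are driven end to end, modeled on the test_operation.py hunk above. It is a minimal sketch rather than code from this commit: the setup lines (device creation, the input tensor, the layer construction and the autograd.training flag) are assumptions based on the surrounding test file and SINGA's Python API, and may differ from the actual sources.

from singa import autograd, device, tensor
from singa import singa_wrap as singa

CTensor = singa.Tensor                  # assumed alias, as used in test_operation.py

autograd.training = True                # backward() asserts that training mode is on
dev = device.create_cuda_gpu()          # assumed; the operation only implements the CUDA path

# assumed setup: a random NCHW input on the GPU and a layer whose
# num_features matches the channel dimension (3)
gpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=dev)
gpu_input_tensor.gaussian(0.0, 1.0)
batchnorm_0 = autograd.BatchNorm2d(3)

# gradient w.r.t. the output; backward() moves it to the cached input's
# device if it lives elsewhere
dy = CTensor([2, 3, 3, 3])
singa.Gaussian(0.0, 1.0, dy)

y = batchnorm_0(gpu_input_tensor)       # forward -> GpuBatchNormForwardTraining
dx, ds, db = y.creator.backward(dy)     # backward -> GpuBatchNormBackward

# expected shapes, mirroring the updated assertions in the test:
#   y: (2, 3, 3, 3)   dx: (2, 3, 3, 3)   ds: (3,)   db: (3,)

# with training switched off, the same call goes through
# GpuBatchNormForwardInference and uses the running statistics instead
autograd.training = False
y_eval = batchnorm_0(gpu_input_tensor)

The design point behind the diff: the cache tuple (x, mean, var, scale) is now assembled on the Python side in _BatchNorm2d.forward, so GpuBatchNormForwardTraining can take the cache as const and the new CHECK_EQ guards can verify that every cached tensor already lives on a CUDA device before the cuDNN calls run.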
