Repository: incubator-singa
Updated Branches:
  refs/heads/master f134a24e2 -> a36291824
SINGA-378 Implement maxpooling operation and its related functions for autograd

- implement corresponding functions for maxpooling, GPU part.
- write interface file for maxpooling functions.
- implement maxpooling layer and maxpooling operation in Python.
- modify example code.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/571818eb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/571818eb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/571818eb

Branch: refs/heads/master
Commit: 571818eb06ca1e88e51a1ad8adb61f43c349ee62
Parents: f134a24
Author: xuewanqi <xue_wa...@outlook.com>
Authored: Wed Jul 11 14:30:02 2018 +0000
Committer: Wang Wei <wangwei...@gmail.com>
Committed: Thu Jul 12 17:08:21 2018 +0800

----------------------------------------------------------------------
 examples/autograd/mnist_cnn.py |  23 +++--
 python/singa/autograd.py       | 162 ++++++++++++++++++++----------
 src/api/model_operation.i      |  38 ++++++++-
 src/model/operation/pooling.cc | 126 ++++++++++++++++++++++++++++
 src/model/operation/pooling.h  |  63 ++++++++++++++
 5 files changed, 327 insertions(+), 85 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/571818eb/examples/autograd/mnist_cnn.py
----------------------------------------------------------------------
diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py
index b1d8dbe..2cb3cae 100755
--- a/examples/autograd/mnist_cnn.py
+++ b/examples/autograd/mnist_cnn.py
@@ -84,7 +84,7 @@ if __name__ == '__main__':
         dev = device.get_default_device()
     else:
         print('Using GPU')
-        dev = device.create_cuda_gpu_on(1)
+        dev = device.create_cuda_gpu()

     train, test = load_data(args.file_path)

@@ -92,7 +92,7 @@ if __name__ == '__main__':
     num_classes = 10
     epochs = 1

-    sgd = optimizer.SGD(0.01)
+    sgd = optimizer.SGD(0.001)

     x_train = preprocess(train[0])
     y_train = to_categorical(train[1], num_classes)
@@ -110,27 +110,32 @@ if __name__ == '__main__':
     conv2 = autograd.Conv2D(32, 32, 3, padding=1)
     bn2 = autograd.BatchNorm(32)
     linear = autograd.Linear(32 * 28 * 28, 10)
-
+    pooling1 = autograd.MaxPool2D(3, 1, padding=1)
+    pooling2 = autograd.MaxPool2D(3, 1, padding=1)

     def forward(x, t):
         y = conv1(x)
         y = autograd.relu(y)
         y = bn1(y)
-        y = autograd.max_pool_2d(y)
+        y = pooling1(y)
+
         y = conv2(y)
-        y = bn2(y)
         y = autograd.relu(y)
-        y = autograd.max_pool_2d(y)
-        y=autograd.flatten(y)
+        y = bn2(y)
+        y = pooling2(y)
+        y = autograd.flatten(y)
         y = linear(y)
         loss = autograd.softmax_cross_entropy(y, t)
         return loss, y

     autograd.training = True
-    for epoch in range(epochs):
+    for epoch in range(50):
         for i in range(batch_number):
-            inputs = tensor.Tensor(device=dev, data=x_train[ i * 100:(1 + i) * 100], stores_grad=False)
-            targets = tensor.Tensor(device=dev, data=y_train[i * 100:(1 + i) * 100], requires_grad=False, stores_grad=False)
+            inputs = tensor.Tensor(device=dev, data=x_train[
+                                   i * 100:(1 + i) * 100], stores_grad=False)
+            targets = tensor.Tensor(device=dev, data=y_train[
+                                    i * 100:(1 + i) * 100], requires_grad=False, stores_grad=False)

             loss, y = forward(inputs, targets)
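A note on shapes: with kernel 3, stride 1 and padding 1, each MaxPool2D keeps the 28 x 28 feature maps, which is why linear = autograd.Linear(32 * 28 * 28, 10) still matches the flattened size. A quick sanity check (plain Python, not part of the commit):

    k, s, p = 3, 1, 1                    # kernel_size, stride, padding as above
    h_out = (28 + 2 * p - k) // s + 1    # floor-mode formula used by PoolingHandle
    assert h_out == 28                   # spatial size is preserved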
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/571818eb/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index d272dcd..fcdc020 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -37,12 +37,9 @@ def infer_dependency(op):
     '''
     Infer the dependency of all operations with the given op as the last
    operation.
-
     Operation A is depending on B is A uses the output(s) of B.
-
     Args:
         op: an Operation instance, e.g. the loss operation.
-
     Return:
         a Counter instance with the operation as the key,
         and the number of operations that are depending on it as the value
@@ -74,12 +71,10 @@ def gradients(y, dy=None):
 def backward(y, dy=None):
     '''
     Run the backward propagation starting at y.
-
     Args:
         y: a Tensor instance, usually the loss
         dy: a number or a Tensor instance, for the gradient of the
             objective/loss w.r.t y, usually 1.0
-
     Return:
         a dictionary storing the gradient tensors of all tensors
             whose stores_grad is true (e.g. parameter tensors)
@@ -156,7 +151,6 @@ class Operation(object):
     '''
     An operation includes the forward and backward function of
     tensor calculation.
-
     Steps to add a specific operation Xxxx:
     1. create a subclass of Operation, name it as Xxxx
     2. override the forward() and backward(); The arguments of forward()
@@ -169,10 +163,8 @@ class Operation(object):
     def _do_forward(self, *xs):
         '''
         Do not call this function from user code. It is called by __call__().
-
         Args:
             xs, Tensor instance(s)
-
         Returns:
             Tensor instance(s)
         '''
@@ -218,10 +210,8 @@ class Operation(object):
     def forward(self, *xs):
         '''Forward propagation.
-
         Args:
             xs: input args consisting of only CTensors.
-
         Returns:
             CTensor instance(s)
         '''
@@ -229,10 +219,8 @@ class Operation(object):
     def backward(self, *dys):
         '''
         Backward propagation.
-
         Args:
             dys: input args consisting of only CTensors.
-
         Returns:
             CTensor instance(s)
         '''
@@ -244,7 +232,6 @@ class Operation(object):
 class Dummy(Operation):
     '''Dummy operation whice serves as a placehoder for autograd
-
     Args:
         name(string): set it for debug
     '''
@@ -262,7 +249,6 @@ class ReLU(Operation):
         '''
         Args:
             x(CTensor): input tensor
-
         Returns:
             a new CTensor whose element y = x if x >= 0; otherwise 0;
         '''
@@ -274,7 +260,6 @@ class ReLU(Operation):
         '''
         Args:
             dy(CTensor): dL / dy
-
         Returns:
             dx(CTensor): dL / dx = dy if x >= 0; otherwise 0;
         '''
@@ -291,13 +276,10 @@ class Matmul(Operation):
     def forward(self, x, w):
         '''Do forward propgation.
-
         Store the x(or w) if w(or x) requires gradient.
-
         Args:
             x (CTensor): matrix
             w (CTensor): matrix
-
         Returns:
             a CTensor for the result
         '''
@@ -309,7 +291,6 @@ class Matmul(Operation):
         '''
         Args:
             dy (CTensor): data for the dL / dy, L is the loss
-
         Returns:
             a tuple for (dx, dw)
         '''
@@ -329,7 +310,6 @@ class AddBias(Operation):
     def __init__(self, axis=0):
         '''
         To indicate the calculation axis, 0 for row, 1 for column.
-
         Args:
             axis: 0 or 1, default is 0.
         '''
@@ -340,7 +320,6 @@ class AddBias(Operation):
         Args:
             x: matrix.
             b: bias to be added.
-
         Return:
             the result Tensor
         '''
@@ -354,7 +333,6 @@ class AddBias(Operation):
         '''
         Args:
             dy (CTensor): data for the dL / dy, L is the loss.
-
         Return:
             a tuple for (db, dx), db is data for dL / db, dx is data
                 for dL / dx.
@@ -382,7 +360,6 @@ class SoftMax(Operation):
         '''
         Args:
             x(data): the input 1d or 2d tensor
-
         Returns:
             the result Tensor
         '''
@@ -398,7 +375,6 @@ class SoftMax(Operation):
         '''
         Args:
             dy (CTensor): data for the dL / dy, L is the loss
-
         Returns:
             dx (Ctensor): data for the dL / dx, L is the loss,
                 x is the input of current Opertion
@@ -435,7 +411,6 @@ def soft_max(x, axis=0):
 class CrossEntropy(Operation):
     '''
     Calculte negative log likelihood loss for a batch of training data.
-
     '''

     def forward(self, x, t):
@@ -444,7 +419,6 @@ class CrossEntropy(Operation):
             x (CTensor): 1d or 2d tensor, the prediction data(output)
                          of current network.
             t (CTensor): 1d or 2d tensor, the target data for training.
-
         Returns:
             loss (CTensor): scalar.
         '''
@@ -461,7 +435,6 @@ class CrossEntropy(Operation):
         Args:
             dy (float or CTensor): scalar, accumulate gradient from outside
                                 of current network, usually equal to 1.0
-
         Returns:
             dx (CTensor): data for the dL /dx, L is the loss, x is the output
                           of current network. note that this is true for
@@ -510,60 +483,33 @@ def ctensor2numpy(x):
     return np_array.reshape(x.shape())


-class MaxPool2d(Operation):
-
-    def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1,
-                 return_indices=False, ceil_mode=False, **kwargs):
-
-        inner_params = {'name': 'MaxPool2d',
-                        'border_mode': 'same',
-                        'data_format': 'NCHW',
-                        'input_sample_shape': None
-                        }
+class _MaxPool2D(Operation):

-        for kwarg in kwargs:
-            if kwarg not in inner_params:
-                raise TypeError('Keyword argument not understood:', kwarg)
-            else:
-                inner_params[kwarg] = kwargs[kwarg]
+    def __init__(self, handle):
+        self.handle = handle

-        if padding == 0:
-            pad = None
+    def forward(self, x):
+        if self.handle.device_id == -1:
+            raise NotImplementedError
         else:
-            pad = padding
-
-        if dilation != 1 or return_indices or ceil_mode:
-            raise ValueError('Not implemented yet')
-
-        self.PyLayer = layer.Pooling2D(inner_params['name'],
-                                       model_pb2.PoolingConf.MAX,
-                                       kernel_size, stride, inner_params[
-                                           'border_mode'],
-                                       pad, inner_params['data_format'],
-                                       inner_params['input_sample_shape'])
+            y = singa.GpuPoolingForward(x, self.handle)

-    def __call__(self, x):
         if training:
-            self.flag = model_pb2.kTrain
-        else:
-            self.flag = model_pb2.kEval
-
-        if not self.PyLayer.has_setup:
-            self.PyLayer.setup(x.shape[1:])
+            self.cache = (x, y)

-        return self._do_forward(x)
-
-    def forward(self, *xs):
-        return self.PyLayer.layer.Forward(self.flag, xs[0])
+        return y

     def backward(self, dy):
-        return self.PyLayer.layer.Backward(0, dy)[0]
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            dx = singa.GpuPoolingBackward(
+                dy, self.cache[0], self.cache[1], self.handle)
+            return dx


-def max_pool_2d(x, kernel_size=3, stride=1, padding=0, dilation=1,
-                return_indices=False, ceil_mode=False, **kwargs):
-    return MaxPool2d(kernel_size, stride, padding, dilation, return_indices,
-                     ceil_mode, **kwargs)(x)[0]
+def max_pool_2d(x, handle):
+    return _MaxPool2D(handle)(x)[0]


 class Flatten(Operation):
@@ -771,6 +717,9 @@ class Conv2D(Layer):
         return y


+<< << << < HEAD
+
+
 class BatchNorm(Layer):

     def __init__(self, num_features, momentum=0.9):
@@ -811,7 +760,7 @@ class BatchNorm(Layer):
         self.handle.device_id = x.device.id()

         y = batchnorm(x, self.scale, self.bias,
-                     self.running_mean, self.running_var, self.handle)
+                      self.running_mean, self.running_var, self.handle)
         return y

@@ -857,3 +806,72 @@ class _BatchNorm(Operation):

 def batchnorm(x, scale, bias, running_mean, running_var, handle):
     return _BatchNorm(running_mean, running_var, handle)(x, scale, bias)[0]
+
+
+class MaxPool2D(Layer):
+
+    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
+                 return_indices=False, ceil_mode=False):
+        if isinstance(kernel_size, int):
+            self.kernel_size = (kernel_size, kernel_size)
+        elif isinstance(kernel_size, tuple):
+            self.kernel_size = kernel_size
+        else:
+            raise TypeError('Wrong kernel_size type.')
+
+        if stride is None:
+            self.stride = self.kernel_size
+        elif isinstance(stride, int):
+            self.stride = (stride, stride)
+        elif isinstance(stride, tuple):
+            self.stride = stride
+        else:
+            raise TypeError('Wrong stride type.')
+
+        if isinstance(padding, int):
+            self.padding = (padding, padding)
+        elif isinstance(padding, tuple):
+            self.padding = padding
+        else:
+            raise TypeError('Wrong padding type.')
+
+        if dilation != 1:
+            raise ValueError('Not implemented yet')
+
+        if return_indices is not False:
+            raise ValueError('Not implemented yet')
+
+        self.ceil_mode = ceil_mode
+
+    def __call__(self, x):
+        if self.ceil_mode:
+            out_shape_h = int(math.ceil(
+                (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]) / self.stride[0])) + 1
+            out_shape_w = int(math.ceil(
+                (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]) / self.stride[1])) + 1
+        else:
+            out_shape_h = int(
+                (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]) // self.stride[0]) + 1
+            out_shape_w = int(
+                (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]) // self.stride[1]) + 1
+        if x.device.id() == -1:
+            if not hasattr(self, 'handle'):
+                self.handle = singa.PoolingHandle(x.data, self.kernel_size, self.stride,
+                                                  self.padding, self.ceil_mode, 'MAX')
+            elif x.shape[0] != self.handle.batchsize or out_shape_h != self.handle.pooled_height or \
+                    out_shape_w != self.handle.pooled_width:
+                self.handle = singa.PoolingHandle(x.data, self.kernel_size, self.stride,
+                                                  self.padding, self.ceil_mode, 'MAX')
+        else:
+            if not hasattr(self, 'handle'):
+                self.handle = singa.CudnnPoolingHandle(x.data, self.kernel_size, self.stride,
+                                                       self.padding, self.ceil_mode, 'MAX', False)  # False for nan_prop
+            elif x.shape[0] != self.handle.batchsize or out_shape_h != self.handle.pooled_height or \
+                    out_shape_w != self.handle.pooled_width:
+                self.handle = singa.CudnnPoolingHandle(x.data, self.kernel_size, self.stride,
+                                                       self.padding, self.ceil_mode, 'MAX', False)  # False for nan_prop
+
+        self.handle.device_id = x.device.id()
+
+        y = max_pool_2d(x, self.handle)
+        return y
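The Python side now splits pooling into a thin _MaxPool2D operation and a MaxPool2D layer that creates and caches the pooling handle. A minimal usage sketch (not part of the commit; it assumes a CUDA build, since _MaxPool2D raises NotImplementedError for CPU devices in this version):

    from singa import autograd, device, tensor

    dev = device.create_cuda_gpu()
    autograd.training = True

    x = tensor.Tensor(shape=(100, 32, 28, 28), device=dev)
    x.gaussian(0.0, 0.1)

    pool = autograd.MaxPool2D(3, 1, padding=1)  # handle is built lazily on first call
    y1 = pool(x)                                # layer style; reuses the cached handle
    y2 = autograd.max_pool_2d(x, pool.handle)   # functional style now takes a handle

The layer rebuilds its handle only when the batch size or the pooled output shape changes, so repeated calls with same-shaped inputs reuse one set of cuDNN descriptors.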
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/571818eb/src/api/model_operation.i
----------------------------------------------------------------------
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index eb41fd0..4800ff1 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -6,6 +6,8 @@
 %{
 #include "../src/model/operation/convolution.h"
 #include "../src/model/operation/batchnorm.h"
+#include "../src/model/operation/pooling.h"
+
 %}

 namespace singa {
@@ -29,7 +31,6 @@ Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandle &ch);

-
 class BatchNormHandle{
   public:
     BatchNormHandle(const float momentum, const Tensor& input);
@@ -38,6 +39,18 @@ class BatchNormHandle{
 };

+class PoolingHandle {
+ public:
+  PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+                const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                const bool ceil_mode = false, const std::string pooling_method = "MAX");
+
+  size_t batchsize;
+
+  size_t pooled_height;
+  size_t pooled_width;
+};
+
 #if USE_CUDNN

 class CudnnConvHandle: public ConvHandle {
@@ -60,8 +73,6 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons
 Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandle &cch);

-
-
 class CudnnBatchNormHandle: public BatchNormHandle{
   public:
     CudnnBatchNormHandle(const float momentum, const Tensor& input);
@@ -78,6 +89,25 @@ Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle &cbnh, const Tens
 const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh, const Tensor& dy, const Tensor& x, const Tensor& bnScale, const Tensor& mean, const Tensor& var);

+
+class CudnnPoolingHandle : public PoolingHandle {
+ public:
+  CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+                     const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                     const bool ceil_mode = false, const std::string pooling_method = "MAX",
+                     const bool NaN_prop = false);
+
+  size_t batchsize;
+
+  size_t pooled_height;
+  size_t pooled_width;
+};
+
+Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph);
+
+Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y,
+                          const CudnnPoolingHandle &cph);
+
 #endif  // USE_CUDNN

-}
+}  //namespace singa
\ No newline at end of file
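The interface file exposes the handle constructors plus the batchsize / pooled_height / pooled_width fields that MaxPool2D.__call__ reads to decide when a handle must be rebuilt. A sketch of driving the wrapper directly (assuming the SWIG module is importable as singa_wrap, as in autograd.py, and a CUDA build):

    from singa import device, tensor
    from singa import singa_wrap as singa

    dev = device.create_cuda_gpu()
    x = tensor.Tensor(shape=(100, 32, 28, 28), device=dev)

    # x.data is the underlying CTensor, which is what MaxPool2D passes in
    handle = singa.CudnnPoolingHandle(x.data, (3, 3), (1, 1), (1, 1), False, 'MAX', False)
    print(handle.batchsize, handle.pooled_height, handle.pooled_width)  # 100 28 28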
"Not implemented!"; + + CUDNN_CHECK(cudnnSetPooling2dDescriptor(pool_desc, pool_method, nan_prop, + kernel_h, kernel_w, pad_h, pad_w, + stride_h, stride_w)); +}; + +CudnnPoolingHandle::~CudnnPoolingHandle() { + if (pool_desc != nullptr) + CUDNN_CHECK(cudnnDestroyPoolingDescriptor(pool_desc)); + if (x_desc != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_desc)); + if (y_desc != nullptr) CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc)); +}; + +Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph) { + CHECK_EQ(x.device()->lang(), kCuda); + CHECK_EQ(x.nDim(), 4u); + + DataType dtype = x.data_type(); + auto dev = x.device(); + Shape shape{cph.batchsize, cph.channels, cph.pooled_height, cph.pooled_width}; + Tensor output = Tensor(shape, dev, dtype); + + output.device()->Exec([&x, &output, &cph](Context * ctx) { + Block *inblock = x.block(), *outblock = output.block(); + float alpha = 1.0f, beta = 0.0f; + cudnnPoolingForward(ctx->cudnn_handle, cph.pool_desc, &alpha, + cph.x_desc, inblock->data(), &beta, cph.y_desc, + outblock->mutable_data()); + }, {x.block()}, {output.block()}); + return output; +}; + +Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y, + const CudnnPoolingHandle &cph) { + CHECK_EQ(dy.device()->lang(), kCuda); + CHECK_EQ(dy.nDim(), 4u); + + Tensor dx; + dx.ResetLike(x); + + dx.device()->Exec([&dx, &dy, &x, &y, &cph](Context * ctx) { + Block *dyblock = dy.block(), *dxblock = dx.block(), *yblock = y.block(), + *xblock = x.block(); + float alpha = 1.0f, beta = 0.0f; + cudnnPoolingBackward(ctx->cudnn_handle, cph.pool_desc, &alpha, + cph.y_desc, yblock->data(), cph.y_desc, + dyblock->data(), cph.x_desc, xblock->data(), &beta, + cph.x_desc, dxblock->mutable_data()); + }, {dy.block(), y.block(), x.block()}, {dx.block()}); + return dx; +}; +#endif //USE_CUDNN + +} //namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/571818eb/src/model/operation/pooling.h ---------------------------------------------------------------------- diff --git a/src/model/operation/pooling.h b/src/model/operation/pooling.h new file mode 100644 index 0000000..9ed7e33 --- /dev/null +++ b/src/model/operation/pooling.h @@ -0,0 +1,63 @@ +#ifndef SINGA_MODEL_OPERATION_POOLING_H_ +#define SINGA_MODEL_OPERATION_POOLING_H_ + +#include <string> +#include "singa/core/tensor.h" + +#ifdef USE_CUDNN +#include <cudnn.h> +#include "../layer/cudnn_utils.h" +#endif + +namespace singa { + +class PoolingHandle { +public: + PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size, + const std::vector<size_t>& stride, const std::vector<size_t>& padding, + const bool ceil_mode = false, const std::string pooling_method = "MAX"); + + size_t kernel_w; + size_t pad_w; + size_t stride_w; + size_t kernel_h; + size_t pad_h; + size_t stride_h; + + size_t batchsize; + size_t channels; + size_t height; + size_t width; + + size_t pooled_height; + size_t pooled_width; + + std::string method; +}; + +#ifdef USE_CUDNN +class CudnnPoolingHandle : public PoolingHandle { +public: + CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size, + const std::vector<size_t>& stride, const std::vector<size_t>& padding, + const bool ceil_mode = false, const std::string pooling_method = "MAX", + const bool NaN_prop = false); + ~CudnnPoolingHandle(); + + cudnnTensorDescriptor_t x_desc = nullptr; + cudnnTensorDescriptor_t y_desc = nullptr; + cudnnPoolingDescriptor_t pool_desc = nullptr; + cudnnNanPropagation_t nan_prop; + +}; + +Tensor 
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/571818eb/src/model/operation/pooling.h
----------------------------------------------------------------------
diff --git a/src/model/operation/pooling.h b/src/model/operation/pooling.h
new file mode 100644
index 0000000..9ed7e33
--- /dev/null
+++ b/src/model/operation/pooling.h
@@ -0,0 +1,63 @@
+#ifndef SINGA_MODEL_OPERATION_POOLING_H_
+#define SINGA_MODEL_OPERATION_POOLING_H_
+
+#include <string>
+#include "singa/core/tensor.h"
+
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#include "../layer/cudnn_utils.h"
+#endif
+
+namespace singa {
+
+class PoolingHandle {
+public:
+  PoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+                const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                const bool ceil_mode = false, const std::string pooling_method = "MAX");
+
+  size_t kernel_w;
+  size_t pad_w;
+  size_t stride_w;
+  size_t kernel_h;
+  size_t pad_h;
+  size_t stride_h;
+
+  size_t batchsize;
+  size_t channels;
+  size_t height;
+  size_t width;
+
+  size_t pooled_height;
+  size_t pooled_width;
+
+  std::string method;
+};
+
+#ifdef USE_CUDNN
+class CudnnPoolingHandle : public PoolingHandle {
+public:
+  CudnnPoolingHandle(const Tensor &input, const std::vector<size_t>& kernel_size,
+                     const std::vector<size_t>& stride, const std::vector<size_t>& padding,
+                     const bool ceil_mode = false, const std::string pooling_method = "MAX",
+                     const bool NaN_prop = false);
+  ~CudnnPoolingHandle();
+
+  cudnnTensorDescriptor_t x_desc = nullptr;
+  cudnnTensorDescriptor_t y_desc = nullptr;
+  cudnnPoolingDescriptor_t pool_desc = nullptr;
+  cudnnNanPropagation_t nan_prop;
+
+};
+
+Tensor GpuPoolingForward(const Tensor &x, const CudnnPoolingHandle &cph);
+
+Tensor GpuPoolingBackward(const Tensor &dy, const Tensor& x, const Tensor& y,
+                          const CudnnPoolingHandle &cph);
+
+#endif  //USE_CUDNN
+
+}  // namespace singa
+
+#endif  // SINGA_MODEL_OPERATION_POOLING_H_
\ No newline at end of file
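For completeness, a forward/backward round trip through the two new kernels, mirroring what _MaxPool2D.forward and backward do in autograd.py (a sketch; it assumes a CUDA/cuDNN build and the singa_wrap import used above):

    from singa import device, tensor
    from singa import singa_wrap as singa

    dev = device.create_cuda_gpu()

    x = tensor.Tensor(shape=(2, 1, 4, 4), device=dev)
    x.gaussian(0.0, 1.0)

    handle = singa.CudnnPoolingHandle(x.data, (2, 2), (2, 2), (0, 0), False, 'MAX', False)
    cy = singa.GpuPoolingForward(x.data, handle)   # CTensor of shape (2, 1, 2, 2)

    dy = tensor.Tensor(shape=(2, 1, 2, 2), device=dev)
    dy.set_value(1.0)
    cdx = singa.GpuPoolingBackward(dy.data, x.data, cy, handle)  # dL/dx, routed to argmax positions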