SINGA-371 Implement functional operations in C++ for autograd - merge the definitions of handles with their init functions
- modified conv2d operation in python part Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/aa9c52ae Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/aa9c52ae Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/aa9c52ae Branch: refs/heads/master Commit: aa9c52aeba2e71638c2d8905a8bc37fd8603a510 Parents: e68ea2e Author: xuewanqi <[email protected]> Authored: Sat Jun 30 09:09:30 2018 +0000 Committer: xuewanqi <[email protected]> Committed: Mon Jul 2 06:09:07 2018 +0000 ---------------------------------------------------------------------- examples/autograd/mlp.py | 4 +- examples/autograd/mnist_cnn.py | 11 +- python/singa/autograd.py | 106 +++--- python/singa/tensor.py | 2 +- src/api/model_operation.i | 36 +- src/core/tensor/tensor_math_cpp.h | 44 ++- src/model/operation/convolution_operation.cc | 366 ++++++++++++++++++ src/model/operation/convolution_operation.h | 78 ++++ src/model/operation/convolution_related.cc | 431 ---------------------- src/model/operation/convolution_related.h | 75 ---- 10 files changed, 567 insertions(+), 586 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/aa9c52ae/examples/autograd/mlp.py ---------------------------------------------------------------------- diff --git a/examples/autograd/mlp.py b/examples/autograd/mlp.py index 3910369..f7c4353 100644 --- a/examples/autograd/mlp.py +++ b/examples/autograd/mlp.py @@ -26,6 +26,8 @@ import numpy as np if __name__ == '__main__': + autograd.training = True + # prepare training data in numpy array # generate the boundary @@ -60,7 +62,7 @@ if __name__ == '__main__': label = to_categorical(label, 2).astype(np.float32) print('train_data_shape:', data.shape) print('train_label_shape:', label.shape) - + # 1 inputs = Tensor(data=data) target = Tensor(data=label) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/aa9c52ae/examples/autograd/mnist_cnn.py ---------------------------------------------------------------------- diff --git a/examples/autograd/mnist_cnn.py b/examples/autograd/mnist_cnn.py index 7b72c75..cbb5650 100644 --- a/examples/autograd/mnist_cnn.py +++ b/examples/autograd/mnist_cnn.py @@ -21,10 +21,13 @@ import numpy as np import argparse import os +import singa from singa import tensor from singa import autograd from singa import optimizer +singa.layer.engine = 'singacpp' + def load_data(path): f = np.load(path) @@ -97,8 +100,8 @@ if __name__ == '__main__': print('the shape of testing label is', y_test.shape) # operations initialization - conv1 = autograd.Conv2d(3, 32) - conv2 = autograd.Conv2d(32, 32) + conv1 = autograd.Conv2D(1, 32, 3, padding=1) + conv2 = autograd.Conv2D(32, 32, 3, padding=1) linear = autograd.Linear(32 * 28 * 28, 10) def forward(x, t): @@ -121,8 +124,8 @@ if __name__ == '__main__': loss, y = forward(inputs, targets) - accuracy_rate = accuracy(autograd.ctensor2numpy( - y.data), autograd.ctensor2numpy(targets.data)) + accuracy_rate = accuracy(autograd.ctensor2numpy(y.data), + autograd.ctensor2numpy(targets.data)) if (i % 5 == 0): print('accuracy is:', accuracy_rate, 'loss is:', autograd.ctensor2numpy(loss.data)[0]) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/aa9c52ae/python/singa/autograd.py ---------------------------------------------------------------------- diff --git a/python/singa/autograd.py b/python/singa/autograd.py index e301e51..b1475bb 100644 --- 
a/python/singa/autograd.py +++ b/python/singa/autograd.py @@ -88,7 +88,7 @@ class Operation(object): ys = (ys,) # create Tensor based on CTensor(data); # assume outputs are all Tensor instances - ys = tuple(Tensor(device=y.device, + ys = tuple(Tensor(device=y.device(), data=y, requires_grad=self.requires_grad, creator=self) for y in ys) @@ -442,7 +442,7 @@ class Conv2d(Operation): param_data = self.PyLayer.layer.param_values() if not hasattr(self, 'w'): - self.w = Tensor(device=param_data[0].device, data=param_data[ + self.w = Tensor(device=param_data[0].device(), data=param_data[ 0], requires_grad=True, stores_grad=True) std = math.sqrt( 2.0 / (self.in_channels * self.kernel_size[0] * self.kernel_size[1] + self.out_channels)) @@ -452,7 +452,7 @@ class Conv2d(Operation): if len(param_data) == 2: if not hasattr(self, 'b'): - self.b = Tensor(device=param_data[1].device, data=param_data[ + self.b = Tensor(device=param_data[1].device(), data=param_data[ 1], requires_grad=True, stores_grad=True) self.b.set_value(0.0) @@ -638,79 +638,75 @@ class Conv2D(Operation): else: #to keep consistency when to do forward. self.b = Tensor(data=CTensor([]), requires_grad=False, stores_grad=False) - - self.reset = False - def __call__(self, x): - assert x.ndim() == 4, 'The dimensions of input should be 4D.' - assert x.shape[1] == self.in_channels, 'in_channels dismatched.' - assert 0 == 0, 'invalid padding.' - # TODO valid padding check. - - if not hasattr (self, 'recorder'): - self.recorder = singa.SetupRecorder(x.data, self.kernel_size, self.stride, - self.padding, self.in_channels, self.out_channels, self.bias) - elif x.shape[0] != self.recorder.batchsize: - self.recorder = singa.SetupRecorder(x.data, self.kernel_size, self.stride, - self.padding, self.in_channels, self.out_channels, self.bias) - self.reset = True - - if training: - self.x = x + def __call__(self, x): + if not hasattr(self, 'device_id'): + self.device_id = x.device.id() + else: + assert self.device_id == x.device.id(),'Not the same device.' - self.dev = x.device + if self.W.device.id() != self.device_id: + self.W.to_device(x.device) - self.W.to_device(self.dev) - xs = [x, self.W] - if self.bias: - self.b.to_device(self.dev) - xs.append(self.b) + if self.b.device.id() != self.device_id: + self.b.to_device(x.device) + + xs = [x, self.W, self.b] + return self._do_forward(*xs)[0] def forward(self, *xs): - if self.dev.lang()==1: #kCuda = 1 - if not hasattr(self, 'cudnnconvhandles'): - self.cudnnconvhandles=singa.InitCudnnConvHandles(xs[0], self.recorder, - self.inner_params['workspace_MB_limit']*1024*1024, self.inner_params['cudnn_prefer']) - elif self.reset: - self.cudnnconvhandles=singa.InitCudnnConvHandles(xs[0], self.recorder, - self.inner_params['workspace_MB_limit']*1024*1024, self.inner_params['cudnn_prefer']) + assert xs[0].nDim() == 4, 'The dimensions of input should be 4D.' + assert xs[0].shape()[1] == self.in_channels, 'in_channels dismatched.' + #assert (xs[0].shape()[2]+2*self.padding[0]-self.kernel_size[0]-1)%self.stride[0] == 0, 'invalid padding.' 
+ assert 0==0, 'invalid padding' - return singa.GpuConvForward(xs[0], xs[1], xs[2], self.recorder, self.cudnnconvhandles) + if training: + self.x = xs[0] - elif self.dev.lang()==0: #kCpp = 0 - return singa.CpuConvForward(xs[0], xs[1], xs[2], self.recorder) + if self.device_id == -1: + if not hasattr (self, 'handles'): + self.handles = singa.ConvHandles(xs[0], self.kernel_size, self.stride, + self.padding, self.in_channels, self.out_channels, self.bias) + elif xs[0].shape()[0] != self.handles.batchsize: + self.handles = singa.ConvHandles(xs[0], self.kernel_size, self.stride, + self.padding, self.in_channels, self.out_channels, self.bias) + return singa.CpuConvForward(xs[0], xs[1], xs[2], self.handles) else: - TypeError('Not implemented yet') - + if not hasattr(self, 'handles'): + self.handles = singa.CudnnConvHandles(xs[0], self.kernel_size, self.stride, + self.padding, self.in_channels, self.out_channels, self.bias, + self.inner_params['workspace_MB_limit']*1024*1024, self.inner_params['cudnn_prefer']) + elif xs[0].shape()[0] != self.handles.batchsize: + self.handles = singa.CudnnConvHandles(xs[0], self.kernel_size, self.stride, + self.padding, self.in_channels, self.out_channels, self.bias, + self.inner_params['workspace_MB_limit']*1024*1024, self.inner_params['cudnn_prefer']) + return singa.GpuConvForward(xs[0], xs[1], xs[2], self.handles) def backward(self, dy): assert training is True and hasattr(self, 'x'), 'Please set training as True before do BP. ' - # todo check device? - dy.ToDevice(self.dev) + if dy.device().id() != self.device_id: + dy.ToDevice(self.x.device()) - if self.dev.lang()==1: #kCuda = 1 - dx = singa.GpuConvBackwardx(dy, self.W.data, self.x.data, self.cudnnconvhandles) - dW = singa.GpuConvBackwardW(dy, self.x.data, self.W.data, self.cudnnconvhandles) + if self.device_id == -1: + dx = singa.CpuConvBackwardx(dy, self.W.data, self.x, self.handles) + dW = singa.CpuConvBackwardW(dy, self.x, self.W.data, self.handles) if self.bias: - db = singa.GpuConvBackwardb(dy, self.b.data, self.cudnnconvhandles) - return dx, dW, db + db = singa.CpuConvBackwardb(dy, self.b.data, self.handles) + return dx, dW, db else: - return dx, dW - - elif self.dev.lang()==0: #kCpp = 0 - dx = singa.CpuConvBackwardx(dy, self.W.data, self.x.data, self.recorder) - dW = singa.CpuConvBackwardW(dy, self.x.data, self.W.data, self.recorder) + return dx, dW + else: + dx = singa.GpuConvBackwardx(dy, self.W.data, self.x, self.handles) + dW = singa.GpuConvBackwardW(dy, self.x, self.W.data, self.handles) if self.bias: - db = singa.CpuConvBackwardb(dy, self.b.data, self.recorder) + db = singa.GpuConvBackwardb(dy, self.b.data, self.handles) return dx, dW, db else: return dx, dW - else: - TypeError('Not implemented yet') def infer_dependency(op): ''' @@ -813,7 +809,7 @@ def backward(y, dy=None): if y_stores_grad: # store the gradient for final return, e.g. 
if x is parameter g = not_ready[src_op][y_idx] - gradients[y] = Tensor(device=g.device, data=g) + gradients[y] = Tensor(device=g.device(), data=g) dependency[src_op] -= 1 if src_op.requires_grad is True: if dependency[src_op] == 0: http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/aa9c52ae/python/singa/tensor.py ---------------------------------------------------------------------- diff --git a/python/singa/tensor.py b/python/singa/tensor.py index 8f36775..eddce28 100644 --- a/python/singa/tensor.py +++ b/python/singa/tensor.py @@ -98,7 +98,7 @@ class Tensor(object): copy_from_numpy(self.data, data) elif isinstance(data, CTensor): self.data = data - assert data.device == device, 'not the same device' + assert data.device().id() == device.id(), 'not the same device' else: self.data = CTensor(list(shape), device, dtype) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/aa9c52ae/src/api/model_operation.i ---------------------------------------------------------------------- diff --git a/src/api/model_operation.i b/src/api/model_operation.i index 1d31b9d..29f8f58 100644 --- a/src/api/model_operation.i +++ b/src/api/model_operation.i @@ -1,24 +1,32 @@ %module model_operation %{ -#include "../src/model/operation/convolution_related.h" +#include "../src/model/operation/convolution_operation.h" %} namespace singa{ -struct Recorder{size_t batchsize;}; +struct ConvHandles{ -struct CudnnConvHandles{}; + size_t batchsize; + ConvHandles(const Tensor &input, const std::vector<size_t> kernel_size, + const std::vector<size_t> stride, const std::vector<size_t> padding, + const size_t in_channels, const size_t out_channels, + const bool bias_term_); + }; -Recorder SetupRecorder(const Tensor &input, const std::vector<size_t> kernel_size, - const std::vector<size_t> stride, const std::vector<size_t> padding, - const size_t in_channels, const size_t out_channels, - const bool bias_term_); +struct CudnnConvHandles{ -CudnnConvHandles InitCudnnConvHandles(const Tensor &input, const Recorder r, - const size_t workspace_byte_limit_=1024*1024*1024, const std::string prefer_="fastest"); + size_t batchsize; + + CudnnConvHandles(const Tensor &input, const std::vector<size_t> kernel_size, + const std::vector<size_t> stride, const std::vector<size_t> padding, + const size_t in_channels, const size_t out_channels, + const bool bias_term_, const size_t workspace_byte_limit_=1024*1024*1024, + const std::string prefer_="fastest"); + }; -Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const Recorder r, const CudnnConvHandles cch); +Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const CudnnConvHandles cch); Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandles cch); @@ -27,12 +35,12 @@ Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, cons Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandles cch); -Tensor CpuConvForward(const Tensor &x, Tensor &W, Tensor &b, const Recorder r); +Tensor CpuConvForward(const Tensor &x, Tensor &W, Tensor &b, const ConvHandles ch); -Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const Recorder r); +Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const ConvHandles ch); -Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const Recorder r); +Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const ConvHandles ch); -Tensor 
CpuConvBackwardb(const Tensor &dy, const Tensor &b, const Recorder r); +Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandles ch); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/aa9c52ae/src/core/tensor/tensor_math_cpp.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h index bfdd026..67f1f20 100644 --- a/src/core/tensor/tensor_math_cpp.h +++ b/src/core/tensor/tensor_math_cpp.h @@ -506,18 +506,52 @@ void Asum<float, lang::Cpp>(const Tensor& in, float *out, *out = cblas_sasum(in.Size(), inPtr, 1); //not using strided traversal } +// template <> +// void Axpy<float, lang::Cpp>(const float alpha, +// const Tensor& in, Tensor *out, Context *ctx) { +// //check input tensor for strides first +// if (in.strides() == out->strides()) { +// const float *inPtr = static_cast<const float *>(in.block()->data()); +// float *outPtr = static_cast<float *>(out->block()->mutable_data()); +// cblas_saxpy(in.Size(), alpha, inPtr, 1, outPtr, 1); +// } else { +// //LOG(FATAL) << "Axpy, input and output strides do not match." ; +// EltwiseMult<float, lang::Cpp>(in, alpha, out, ctx); +// } +// } + template <> void Axpy<float, lang::Cpp>(const float alpha, const Tensor& in, Tensor *out, Context *ctx) { //check input tensor for strides first + const float *inPtr = static_cast<const float *>(in.block()->data()); + float *outPtr = static_cast<float *>(out->block()->mutable_data()); + if (in.strides() == out->strides()) { - const float *inPtr = static_cast<const float *>(in.block()->data()); - float *outPtr = static_cast<float *>(out->block()->mutable_data()); cblas_saxpy(in.Size(), alpha, inPtr, 1, outPtr, 1); } else { - LOG(FATAL) << "Axpy, input and output strides do not match." ; - } -} + //LOG(FATAL) << "Axpy, input and output strides do not match." ; + Tensor t(in.shape(), in.device(), in.data_type()); + EltwiseMult<float, lang::Cpp>(in, alpha, &t, ctx); + float* tPtr = static_cast<float*>(t.block()->mutable_data()); + cblas_saxpy(in.Size(), 1, tPtr, 1, outPtr, 1); + } +} + +// template <> +// void Axpy<float, lang::Cpp>(const float alpha, +// const Tensor& in, Tensor *out, Context *ctx) { +// //check input tensor for strides first +// if (in.strides() == out->strides()) { +// const float *inPtr = static_cast<const float *>(in.block()->data()); +// float *outPtr = static_cast<float *>(out->block()->mutable_data()); +// cblas_saxpy(in.Size(), alpha, inPtr, 1, outPtr, 1); +// } else if(out->transpose()) { +// LOG(FATAL) << "output is already transposed." ; +// } else { +// LOG(FATAL) << "Axpy, input and output strides do not match." 
; +// } +// } template <> void Dot<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/aa9c52ae/src/model/operation/convolution_operation.cc ---------------------------------------------------------------------- diff --git a/src/model/operation/convolution_operation.cc b/src/model/operation/convolution_operation.cc new file mode 100644 index 0000000..90b1b4a --- /dev/null +++ b/src/model/operation/convolution_operation.cc @@ -0,0 +1,366 @@ +#include "./convolution_operation.h" +#include "../layer/convolution.h" +#include<iostream> + +namespace singa{ + +ConvHandles::ConvHandles(const Tensor &input, const std::vector<size_t> kernel_size, + const std::vector<size_t> stride, const std::vector<size_t> padding, + const size_t in_channels, const size_t out_channels, + const bool bias_term_){ + kernel_h_=kernel_size[0]; + kernel_w_=kernel_size[1]; + + pad_h_=padding[0]; + pad_w_=padding[1]; + + stride_h_=stride[0]; + stride_w_=stride[1]; + + channels_=in_channels; + num_filters_=out_channels; + + batchsize = input.shape(0); + CHECK(input.shape(1) == in_channels)<<"the number of input channels mismatched."; + height_ = input.shape(2); + width_ = input.shape(3); + + conv_height_ = 1; + if (stride_h_ > 0) + conv_height_ = (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1; + conv_width_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1; + + col_height_ = in_channels * kernel_w_ * kernel_h_; + col_width_ = conv_height_ * conv_width_; + imagesize = input.Size() / batchsize; +}; + +CudnnConvHandles::CudnnConvHandles(const Tensor &input, const std::vector<size_t> kernel_size, + const std::vector<size_t> stride, const std::vector<size_t> padding, + const size_t in_channels, const size_t out_channels,const bool bias_term_, + const size_t workspace_byte_limit_,const std::string prefer_) + :ConvHandles(input, kernel_size, stride, padding, in_channels, out_channels, bias_term_){ + + DataType dtype = input.data_type(); + auto dev = input.device(); + Context *ctx = dev->context(0); + + CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_)); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_)); + if (bias_term_) + CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_)); + CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_)); + CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_)); + + + CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, + GetCudnnDataType(dtype), batchsize, + channels_, height_, width_)); + CUDNN_CHECK(cudnnSetTensor4dDescriptor( + y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize, + num_filters_, conv_height_, conv_width_)); + if (bias_term_) + CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW, + GetCudnnDataType(dtype), 1, + num_filters_, 1, 1)); + CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_, + stride_h_, stride_w_, 1, 1, + CUDNN_CROSS_CORRELATION, + GetCudnnDataType(dtype))); + CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype), + CUDNN_TENSOR_NCHW, num_filters_, + channels_, kernel_h_, kernel_w_)); + if (prefer_ == "fastest" || prefer_ == "limited_workspace" || + prefer_ == "no_workspace") { + cudnnConvolutionFwdPreference_t fwd_pref; + cudnnConvolutionBwdFilterPreference_t bwd_filt_pref; + cudnnConvolutionBwdDataPreference_t bwd_data_pref; + if (prefer_ == "fastest") { + fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST; + bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST; + bwd_data_pref = 
CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST; + } else if (prefer_ == "limited_workspace") { + fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT; + bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT; + bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT; + } else { + fwd_pref = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; + bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; + bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT; + } + CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm( + ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref, + workspace_byte_limit_, &fp_alg_)); + CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm( + ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, + bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_)); + // deprecated in cudnn v7 + CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm( + ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, + bwd_data_pref, workspace_byte_limit_, &bp_data_alg_)); + } else if (prefer_ == "autotune") { + const int topk = 1; + int num_fp_alg, num_bp_filt_alg, num_bp_data_alg; + cudnnConvolutionFwdAlgoPerf_t fp_alg_perf[topk]; + cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk]; + cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk]; + CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm( + ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk, + &num_fp_alg, fp_alg_perf)); + fp_alg_ = fp_alg_perf[0].algo; + CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm( + ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk, + &num_bp_filt_alg, bp_filt_perf)); + bp_filter_alg_ = bp_filt_perf[0].algo; + CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm( + ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk, + &num_bp_data_alg, bp_data_perf)); + bp_data_alg_ = bp_data_perf[0].algo; + } else { + LOG(FATAL) << "Preferred algorithm is not available!"; + } + + size_t fp_byte, bp_data_byte, bp_filter_byte; + CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize( + ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_, + &fp_byte)); + CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize( + ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, + bp_data_alg_, &bp_data_byte)); + CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize( + ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, + bp_filter_alg_, &bp_filter_byte)); + workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) / + sizeof(float) + + 1; + if (workspace_count_ * sizeof(float) > workspace_byte_limit_) + LOG(WARNING) << "The required memory for workspace (" + << workspace_count_ * sizeof(float) + << ") is larger than the expected Bytes (" + << workspace_byte_limit_ << ")"; + workspace_ = Tensor(Shape{workspace_count_}, dev, dtype); +}; + +Convolution C; + +Tensor CpuConvForward(const Tensor &x, Tensor &W, Tensor &b, const ConvHandles ch){ + CHECK_EQ(x.device()->lang(), kCpp); + + CHECK(x.shape(1) == ch.channels_ && x.shape(2) == ch.height_ && + x.shape(3) == ch.width_) << "input sample shape should not change"; + + CHECK(W.shape(0) == ch.num_filters_ && W.shape(1) == ch.channels_ && + W.shape(2) == ch.kernel_h_ && W.shape(3) == ch.kernel_w_) << "weights shape should not change"; + + Shape w_shape= W.shape(); + Shape b_shape= b.shape(); + + W.Reshape(Shape{ch.num_filters_, ch.col_height_}); + if (ch.bias_term_) + b.Reshape(Shape{ch.num_filters_}); + + DataType dtype = 
x.data_type(); + auto dev = x.device(); + Shape shape{ch.batchsize, ch.num_filters_, ch.conv_height_, ch.conv_width_}; + Tensor output(shape, dev, dtype); + + Tensor col_data(Shape{ch.col_height_, ch.col_width_});//broadcasted image + + float *data_col = new float[ch.col_height_ * ch.col_width_]; + auto in_data = x.data<float>(); + for (size_t num = 0; num < ch.batchsize; num++) { + C.Im2col(in_data + num * ch.imagesize, ch.channels_, ch.height_, ch.width_, ch.kernel_h_, + ch.kernel_w_, ch.pad_h_, ch.pad_w_, ch.stride_h_, ch.stride_w_, data_col); + + col_data.CopyDataFromHostPtr(data_col, ch.col_height_ * ch.col_width_); + Tensor each = Mult(W, col_data); + if (ch.bias_term_) { + AddColumn(b, &each); + } + CopyDataToFrom(&output, each, each.Size(), num * each.Size()); + }; + W.Reshape(w_shape); + b.Reshape(b_shape); + return output; +}; + +Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const ConvHandles ch){ + CHECK_EQ(dy.device()->lang(), kCpp); + + CHECK(dy.shape(1) == ch.num_filters_ && dy.shape(2) == ch.conv_height_ && + dy.shape(3) == ch.conv_width_) << "input gradients shape should not change"; + + CHECK(W.shape(0) == ch.num_filters_ && W.shape(1) == ch.channels_ && + W.shape(2) == ch.kernel_h_ && W.shape(3) == ch.kernel_w_) << "weights shape should not change"; + + Shape w_shape= W.shape(); + W.Reshape(Shape{ch.num_filters_, ch.col_height_}); + + Tensor dx; + dx.ResetLike(x); + + float *dx_b = new float[ch.imagesize]; + + for (size_t num = 0; num < ch.batchsize; num++) { + Tensor grad_b(Shape{ch.num_filters_, ch.conv_height_ * ch.conv_width_}); + CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size()); + Tensor dcol_b = Mult(W.T(), grad_b); + auto dcol_data = dcol_b.data<float>(); + C.Col2im(dcol_data, ch.channels_, ch.height_, ch.width_, ch.kernel_h_, ch.kernel_w_, ch.pad_h_, + ch.pad_w_, ch.stride_h_, ch.stride_w_, dx_b); + dx.CopyDataFromHostPtr(dx_b, ch.imagesize, num * ch.imagesize); + } + W.Reshape(w_shape); + return dx; +}; + +Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const ConvHandles ch){ + CHECK_EQ(dy.device()->lang(), kCpp); + + CHECK(dy.shape(1) == ch.num_filters_ && dy.shape(2) == ch.conv_height_ && + dy.shape(3) == ch.conv_width_) << "input gradients shape should not change"; + + CHECK(x.shape(1) == ch.channels_ && x.shape(2) == ch.height_ && + x.shape(3) == ch.width_) << "input sample shape should not change"; + + Tensor dW; + dW.ResetLike(W); + dW.SetValue(0.0f); + + Shape w_shape= W.shape(); + dW.Reshape(Shape{ch.num_filters_, ch.col_height_}); + + Tensor col_data(Shape{ch.col_height_, ch.col_width_});//broadcasted image + + float *data_col = new float[ch.col_height_ * ch.col_width_]; + auto in_data = dy.data<float>(); + for (size_t num = 0; num < ch.batchsize; num++) { + C.Im2col(in_data + num * ch.imagesize, ch.channels_, ch.height_, ch.width_, ch.kernel_h_, + ch.kernel_w_, ch.pad_h_, ch.pad_w_, ch.stride_h_, ch.stride_w_, data_col); + col_data.CopyDataFromHostPtr(data_col, ch.col_height_ * ch.col_width_); + Tensor grad_b(Shape{ch.num_filters_, ch.conv_height_ * ch.conv_width_}); + CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size()); + dW += Mult(grad_b, col_data.T()); + } + dW.Reshape(w_shape); + return dW; +}; + +Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandles ch){ + CHECK_EQ(dy.device()->lang(), kCpp); + + CHECK(dy.shape(1) == ch.num_filters_ && dy.shape(2) == ch.conv_height_ && + dy.shape(3) == ch.conv_width_) << "input gradients shape should not 
change"; + + CHECK(b.shape(0) == ch.num_filters_)<< "bias shape should not change"; + + Tensor db; + db.ResetLike(b); + + auto tmpshp = Shape{ch.batchsize * ch.num_filters_, dy.Size() / (ch.batchsize * ch.num_filters_)}; + Tensor tmp1 = Reshape(dy, tmpshp); + + Tensor tmp2(Shape{ch.batchsize * ch.num_filters_}); + SumColumns(tmp1, &tmp2); + Tensor tmp3 = Reshape(tmp2, Shape{ch.batchsize, ch.num_filters_}); + + SumRows(tmp3, &db); + + return db; +}; + +Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const CudnnConvHandles cch){ + CHECK_EQ(x.device()->lang(), kCuda); + + DataType dtype = x.data_type(); + auto dev = x.device(); + + Shape shape{cch.batchsize, cch.num_filters_, cch.conv_height_, cch.conv_width_}; + Tensor output(shape, dev, dtype); + + output.device()->Exec([output, x, W, cch](Context *ctx) { + Block *inblock = x.block(), *outblock = output.block(), + *wblock = W.block(); + float alpha = 1.f, beta = 0.f; + cudnnConvolutionForward(ctx->cudnn_handle, &alpha, cch.x_desc_, + inblock->data(), cch.filter_desc_, wblock->data(), + cch.conv_desc_, cch.fp_alg_, + cch.workspace_.block()->mutable_data(), + cch.workspace_count_ * sizeof(float), &beta, + cch.y_desc_, outblock->mutable_data()); + }, {x.block(), W.block()}, {output.block()}, cch.workspace_.block()); + + if (cch.bias_term_) { + output.device()->Exec([output, b, cch](Context *ctx) { + float beta = 1.f, alpha = 1.0f; + Block *outblock = output.block(), *bblock = b.block(); + cudnnAddTensor(ctx->cudnn_handle, &alpha, cch.bias_desc_, + bblock->data(), &beta, cch.y_desc_, + outblock->mutable_data()); + }, {output.block(), b.block()}, {output.block()}); + } + + return output; +}; + +Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandles cch){ + CHECK_EQ(dy.device()->lang(), kCuda); + + Tensor dx; + dx.ResetLike(x); + + dy.device()->Exec([dx, dy, W, cch](Context *ctx) { + Block *wblock = W.block(), *dyblock = dy.block(), + *dxblock = dx.block(); + float alpha = 1.f, beta = 0.f; + cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, cch.filter_desc_, + wblock->data(), cch.y_desc_, dyblock->data(), + cch.conv_desc_, cch.bp_data_alg_, + cch.workspace_.block()->mutable_data(), + cch.workspace_count_ * sizeof(float), &beta, + cch.x_desc_, dxblock->mutable_data()); + }, {dy.block(), W.block()}, {dx.block(), cch.workspace_.block()}); + + return dx; +}; + +Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandles cch){ + CHECK_EQ(dy.device()->lang(), kCuda); + + Tensor dW; + dW.ResetLike(W); + + dy.device()->Exec([dW, dy, x, W, cch](Context *ctx) { + Block *inblock = x.block(), *dyblock = dy.block(), + *dwblock = dW.block(); + float alpha = 1.f, beta = 0.f; + cudnnConvolutionBackwardFilter( + ctx->cudnn_handle, &alpha, cch.x_desc_, inblock->data(), + cch.y_desc_, dyblock->data(), cch.conv_desc_, cch.bp_filter_alg_, + cch.workspace_.block()->mutable_data(), + cch.workspace_count_ * sizeof(float), &beta, cch.filter_desc_, + dwblock->mutable_data()); + }, {dy.block(), x.block()}, {dW.block(), cch.workspace_.block()}); + + return dW; +}; + +// input Tensor b for Reset db purpose, can avoid this later. 
+Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandles cch){ + CHECK_EQ(dy.device()->lang(), kCuda); + + Tensor db; + db.ResetLike(b); + + dy.device()->Exec([db, dy, b, cch](Context *ctx) { + Block *dyblock = dy.block(), *dbblock = db.block(); + float alpha = 1.f, beta = 0.f; + cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, cch.y_desc_, + dyblock->data(), &beta, cch.bias_desc_, + dbblock->mutable_data()); + }, {dy.block()}, {db.block()}); + + return db; +}; + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/aa9c52ae/src/model/operation/convolution_operation.h ---------------------------------------------------------------------- diff --git a/src/model/operation/convolution_operation.h b/src/model/operation/convolution_operation.h new file mode 100644 index 0000000..835581e --- /dev/null +++ b/src/model/operation/convolution_operation.h @@ -0,0 +1,78 @@ +#include <string> +#include <vector> +#include <cudnn.h> +#include "../layer/cudnn_convolution.h" +#include "../layer/cudnn_utils.h" +#include "singa/utils/logging.h" + +namespace singa{ + +struct ConvHandles{ + size_t kernel_w_; + size_t pad_w_; + size_t stride_w_; + size_t kernel_h_; + size_t pad_h_; + size_t stride_h_; + + size_t channels_; + size_t num_filters_; + + bool bias_term_; + + size_t height_; + size_t width_; + size_t conv_height_; + size_t conv_width_; + size_t batchsize; + + size_t col_height_; + size_t col_width_; + size_t imagesize; + + ConvHandles(const Tensor &input, const std::vector<size_t> kernel_size, + const std::vector<size_t> stride, const std::vector<size_t> padding, + const size_t in_channels, const size_t out_channels, + const bool bias_term_); + +}; + +struct CudnnConvHandles:ConvHandles{ + cudnnTensorDescriptor_t x_desc_ ; + cudnnTensorDescriptor_t y_desc_ ; + cudnnTensorDescriptor_t bias_desc_ ; + cudnnFilterDescriptor_t filter_desc_ ; + cudnnConvolutionDescriptor_t conv_desc_ ; + cudnnConvolutionFwdAlgo_t fp_alg_; + cudnnConvolutionBwdFilterAlgo_t bp_filter_alg_; + cudnnConvolutionBwdDataAlgo_t bp_data_alg_; + + size_t workspace_count_; + Tensor workspace_; + + CudnnConvHandles(const Tensor &input, const std::vector<size_t> kernel_size, + const std::vector<size_t> stride, const std::vector<size_t> padding, + const size_t in_channels, const size_t out_channels, + const bool bias_term_, const size_t workspace_byte_limit_=1024*1024*1024, + const std::string prefer_="fastest"); +}; + +Tensor CpuConvForward(const Tensor &x, Tensor &W, Tensor &b, const ConvHandles ch); + +Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const ConvHandles ch); + +Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const ConvHandles ch); + +Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const ConvHandles ch); + + +Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const CudnnConvHandles cch); + +Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandles cch); + +Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandles cch); + +Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandles cch); + + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/aa9c52ae/src/model/operation/convolution_related.cc ---------------------------------------------------------------------- diff --git a/src/model/operation/convolution_related.cc 
b/src/model/operation/convolution_related.cc deleted file mode 100644 index c828f90..0000000 --- a/src/model/operation/convolution_related.cc +++ /dev/null @@ -1,431 +0,0 @@ -#include "./convolution_related.h" -#include "../layer/convolution.h" -#include<iostream> - -namespace singa{ - -Recorder SetupRecorder(const Tensor &input, const std::vector<size_t> kernel_size, - const std::vector<size_t> stride, const std::vector<size_t> padding, - const size_t in_channels, const size_t out_channels, - const bool bias_term_){ - size_t kernel_w_; - size_t pad_w_; - size_t stride_w_; - size_t kernel_h_; - size_t pad_h_; - size_t stride_h_; - - size_t height_; - size_t width_; - size_t conv_height_; - size_t conv_width_; - size_t batchsize; - - size_t col_height_; - size_t col_width_; - size_t imagesize; - - kernel_h_=kernel_size[0]; - kernel_w_=kernel_size[1]; - - pad_h_=padding[0]; - pad_w_=padding[1]; - - stride_h_=stride[0]; - stride_w_=stride[1]; - - batchsize = input.shape(0); - CHECK(input.shape(1) == in_channels)<<"the number of input channels mismatched."; - height_ = input.shape(2); - width_ = input.shape(3); - - conv_height_ = 1; - if (stride_h_ > 0) - conv_height_ = (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1; - conv_width_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1; - - col_height_ = in_channels * kernel_w_ * kernel_h_; - col_width_ = conv_height_ * conv_width_; - imagesize = input.Size() / batchsize; - - return Recorder{ - kernel_w_, - pad_w_, - stride_w_, - kernel_h_, - pad_h_, - stride_h_, - - in_channels, - out_channels, - - bias_term_, - - height_, - width_, - conv_height_, - conv_width_, - batchsize, - - col_height_, - col_width_, - imagesize - }; -}; - -Convolution C; - -Tensor CpuConvForward(const Tensor &x, Tensor &W, Tensor &b, const Recorder r){ - CHECK_EQ(x.device()->lang(), kCpp); - - CHECK(x.shape(1) == r.channels_ && x.shape(2) == r.height_ && - x.shape(3) == r.width_) << "input sample shape should not change"; - - CHECK(W.shape(0) == r.num_filters_ && W.shape(1) == r.channels_ && - W.shape(2) == r.kernel_h_ && W.shape(3) == r.kernel_w_) << "weights shape should not change"; - - Shape w_shape= W.shape(); - Shape b_shape= b.shape(); - - W.Reshape(Shape{r.num_filters_, r.col_height_}); - if (r.bias_term_) - b.Reshape(Shape{r.num_filters_}); - - DataType dtype = x.data_type(); - auto dev = x.device(); - Shape shape{r.batchsize, r.num_filters_, r.conv_height_, r.conv_width_}; - Tensor output(shape, dev, dtype); - - Tensor col_data(Shape{r.col_height_, r.col_width_});//broadcasted image - - float *data_col = new float[r.col_height_ * r.col_width_]; - auto in_data = x.data<float>(); - for (size_t num = 0; num < r.batchsize; num++) { - C.Im2col(in_data + num * r.imagesize, r.channels_, r.height_, r.width_, r.kernel_h_, - r.kernel_w_, r.pad_h_, r.pad_w_, r.stride_h_, r.stride_w_, data_col); - - col_data.CopyDataFromHostPtr(data_col, r.col_height_ * r.col_width_); - Tensor each = Mult(W, col_data); - if (r.bias_term_) { - AddColumn(b, &each); - } - CopyDataToFrom(&output, each, each.Size(), num * each.Size()); - }; - W.Reshape(w_shape); - b.Reshape(b_shape); - return output; -}; - -Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const Recorder r){ - CHECK_EQ(dy.device()->lang(), kCpp); - - CHECK(dy.shape(1) == r.num_filters_ && dy.shape(2) == r.conv_height_ && - dy.shape(3) == r.conv_width_) << "input gradients shape should not change"; - - CHECK(W.shape(0) == r.num_filters_ && W.shape(1) == r.channels_ && - W.shape(2) == r.kernel_h_ && 
W.shape(3) == r.kernel_w_) << "weights shape should not change"; - - Shape w_shape= W.shape(); - W.Reshape(Shape{r.num_filters_, r.col_height_}); - - Tensor dx; - dx.ResetLike(x); - - float *dx_b = new float[r.imagesize]; - - for (size_t num = 0; num < r.batchsize; num++) { - Tensor grad_b(Shape{r.num_filters_, r.conv_height_ * r.conv_width_}); - CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size()); - Tensor dcol_b = Mult(W.T(), grad_b); - auto dcol_data = dcol_b.data<float>(); - C.Col2im(dcol_data, r.channels_, r.height_, r.width_, r.kernel_h_, r.kernel_w_, r.pad_h_, - r.pad_w_, r.stride_h_, r.stride_w_, dx_b); - dx.CopyDataFromHostPtr(dx_b, r.imagesize, num * r.imagesize); - } - W.Reshape(w_shape); - return dx; -}; - -Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const Recorder r){ - CHECK_EQ(dy.device()->lang(), kCpp); - - CHECK(dy.shape(1) == r.num_filters_ && dy.shape(2) == r.conv_height_ && - dy.shape(3) == r.conv_width_) << "input gradients shape should not change"; - - CHECK(x.shape(1) == r.channels_ && x.shape(2) == r.height_ && - x.shape(3) == r.width_) << "input sample shape should not change"; - - Tensor dW; - dW.ResetLike(W); - dW.SetValue(0.0f); - - Shape w_shape= W.shape(); - dW.Reshape(Shape{r.num_filters_, r.col_height_}); - - Tensor col_data(Shape{r.col_height_, r.col_width_});//broadcasted image - - float *data_col = new float[r.col_height_ * r.col_width_]; - auto in_data = dy.data<float>(); - for (size_t num = 0; num < r.batchsize; num++) { - C.Im2col(in_data + num * r.imagesize, r.channels_, r.height_, r.width_, r.kernel_h_, - r.kernel_w_, r.pad_h_, r.pad_w_, r.stride_h_, r.stride_w_, data_col); - col_data.CopyDataFromHostPtr(data_col, r.col_height_ * r.col_width_); - Tensor grad_b(Shape{r.num_filters_, r.conv_height_ * r.conv_width_}); - CopyDataToFrom(&grad_b, dy, grad_b.Size(), 0, num * grad_b.Size()); - dW += Mult(grad_b, col_data.T()); - } - dW.Reshape(w_shape); - return dW; -}; - -Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const Recorder r){ - CHECK_EQ(dy.device()->lang(), kCpp); - - CHECK(dy.shape(1) == r.num_filters_ && dy.shape(2) == r.conv_height_ && - dy.shape(3) == r.conv_width_) << "input gradients shape should not change"; - - CHECK(b.shape(0) == r.num_filters_)<< "bias shape should not change"; - - Tensor db; - db.ResetLike(b); - - auto tmpshp = Shape{r.batchsize * r.num_filters_, dy.Size() / (r.batchsize * r.num_filters_)}; - Tensor tmp1 = Reshape(dy, tmpshp); - - Tensor tmp2(Shape{r.batchsize * r.num_filters_}); - SumColumns(tmp1, &tmp2); - Tensor tmp3 = Reshape(tmp2, Shape{r.batchsize, r.num_filters_}); - - SumRows(tmp3, &db); - - return db; -}; - -CudnnConvHandles InitCudnnConvHandles(const Tensor &input, const Recorder r, const size_t workspace_byte_limit_, - const std::string prefer_){ - - CHECK(input.shape(0) == r.batchsize && input.shape(1) == r.channels_ && input.shape(2) == r.height_ && - input.shape(3) == r.width_) << "input sample shape dismatched"; - - cudnnTensorDescriptor_t x_desc_ ; - cudnnTensorDescriptor_t y_desc_ ; - cudnnTensorDescriptor_t bias_desc_ ; - cudnnFilterDescriptor_t filter_desc_ ; - cudnnConvolutionDescriptor_t conv_desc_ ; - cudnnConvolutionFwdAlgo_t fp_alg_; - cudnnConvolutionBwdFilterAlgo_t bp_filter_alg_; - cudnnConvolutionBwdDataAlgo_t bp_data_alg_; - - size_t workspace_count_; - Tensor workspace_; - - DataType dtype = input.data_type(); - auto dev = input.device(); - Context *ctx = dev->context(0); - - CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_)); - 
CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_)); - if (r.bias_term_) - CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_)); - CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_)); - CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_)); - - - CUDNN_CHECK(cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, - GetCudnnDataType(dtype), r.batchsize, - r.channels_, r.height_, r.width_)); - CUDNN_CHECK(cudnnSetTensor4dDescriptor( - y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), r.batchsize, - r.num_filters_, r.conv_height_, r.conv_width_)); - if (r.bias_term_) - CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW, - GetCudnnDataType(dtype), 1, - r.num_filters_, 1, 1)); - CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, r.pad_h_, r.pad_w_, - r.stride_h_, r.stride_w_, 1, 1, - CUDNN_CROSS_CORRELATION, - GetCudnnDataType(dtype))); - CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype), - CUDNN_TENSOR_NCHW, r.num_filters_, - r.channels_, r.kernel_h_, r.kernel_w_)); - if (prefer_ == "fastest" || prefer_ == "limited_workspace" || - prefer_ == "no_workspace") { - cudnnConvolutionFwdPreference_t fwd_pref; - cudnnConvolutionBwdFilterPreference_t bwd_filt_pref; - cudnnConvolutionBwdDataPreference_t bwd_data_pref; - if (prefer_ == "fastest") { - fwd_pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST; - bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST; - bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST; - } else if (prefer_ == "limited_workspace") { - fwd_pref = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT; - bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT; - bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT; - } else { - fwd_pref = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; - bwd_filt_pref = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; - bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT; - } - CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm( - ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref, - workspace_byte_limit_, &fp_alg_)); - CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm( - ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, - bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_)); - // deprecated in cudnn v7 - CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm( - ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, - bwd_data_pref, workspace_byte_limit_, &bp_data_alg_)); - } else if (prefer_ == "autotune") { - const int topk = 1; - int num_fp_alg, num_bp_filt_alg, num_bp_data_alg; - cudnnConvolutionFwdAlgoPerf_t fp_alg_perf[topk]; - cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk]; - cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk]; - CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm( - ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk, - &num_fp_alg, fp_alg_perf)); - fp_alg_ = fp_alg_perf[0].algo; - CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm( - ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk, - &num_bp_filt_alg, bp_filt_perf)); - bp_filter_alg_ = bp_filt_perf[0].algo; - CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm( - ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk, - &num_bp_data_alg, bp_data_perf)); - bp_data_alg_ = bp_data_perf[0].algo; - } else { - LOG(FATAL) << "Preferred algorithm is not available!"; - } - - size_t fp_byte, bp_data_byte, bp_filter_byte; - CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize( - 
ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_, - &fp_byte)); - CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize( - ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, - bp_data_alg_, &bp_data_byte)); - CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize( - ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, - bp_filter_alg_, &bp_filter_byte)); - workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) / - sizeof(float) + - 1; - if (workspace_count_ * sizeof(float) > workspace_byte_limit_) - LOG(WARNING) << "The required memory for workspace (" - << workspace_count_ * sizeof(float) - << ") is larger than the expected Bytes (" - << workspace_byte_limit_ << ")"; - workspace_ = Tensor(Shape{workspace_count_}, dev, dtype); - - return CudnnConvHandles{ - x_desc_, - y_desc_, - bias_desc_, - filter_desc_, - conv_desc_, - fp_alg_, - bp_filter_alg_, - bp_data_alg_, - - workspace_count_, - workspace_, - }; - -}; - -Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const Recorder r, const CudnnConvHandles cch){ - CHECK_EQ(x.device()->lang(), kCuda); - - DataType dtype = x.data_type(); - auto dev = x.device(); - - Shape shape{r.batchsize, r.num_filters_, r.conv_height_, r.conv_width_}; - Tensor output(shape, dev, dtype); - - output.device()->Exec([output, x, W, cch](Context *ctx) { - Block *inblock = x.block(), *outblock = output.block(), - *wblock = W.block(); - float alpha = 1.f, beta = 0.f; - cudnnConvolutionForward(ctx->cudnn_handle, &alpha, cch.x_desc_, - inblock->data(), cch.filter_desc_, wblock->data(), - cch.conv_desc_, cch.fp_alg_, - cch.workspace_.block()->mutable_data(), - cch.workspace_count_ * sizeof(float), &beta, - cch.y_desc_, outblock->mutable_data()); - }, {x.block(), W.block()}, {output.block()}, cch.workspace_.block()); - - if (r.bias_term_) { - output.device()->Exec([output, b, cch](Context *ctx) { - float beta = 1.f, alpha = 1.0f; - Block *outblock = output.block(), *bblock = b.block(); - cudnnAddTensor(ctx->cudnn_handle, &alpha, cch.bias_desc_, - bblock->data(), &beta, cch.y_desc_, - outblock->mutable_data()); - }, {output.block(), b.block()}, {output.block()}); - } - - return output; -}; - -Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandles cch){ - CHECK_EQ(dy.device()->lang(), kCuda); - - Tensor dx; - dx.ResetLike(x); - - dy.device()->Exec([dx, dy, W, cch](Context *ctx) { - Block *wblock = W.block(), *dyblock = dy.block(), - *dxblock = dx.block(); - float alpha = 1.f, beta = 0.f; - cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, cch.filter_desc_, - wblock->data(), cch.y_desc_, dyblock->data(), - cch.conv_desc_, cch.bp_data_alg_, - cch.workspace_.block()->mutable_data(), - cch.workspace_count_ * sizeof(float), &beta, - cch.x_desc_, dxblock->mutable_data()); - }, {dy.block(), W.block()}, {dx.block(), cch.workspace_.block()}); - - return dx; -}; - -Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandles cch){ - CHECK_EQ(dy.device()->lang(), kCuda); - - Tensor dW; - dW.ResetLike(W); - - dy.device()->Exec([dW, dy, x, W, cch](Context *ctx) { - Block *inblock = x.block(), *dyblock = dy.block(), - *dwblock = dW.block(); - float alpha = 1.f, beta = 0.f; - cudnnConvolutionBackwardFilter( - ctx->cudnn_handle, &alpha, cch.x_desc_, inblock->data(), - cch.y_desc_, dyblock->data(), cch.conv_desc_, cch.bp_filter_alg_, - cch.workspace_.block()->mutable_data(), - cch.workspace_count_ * 
sizeof(float), &beta, cch.filter_desc_, - dwblock->mutable_data()); - }, {dy.block(), x.block()}, {dW.block(), cch.workspace_.block()}); - - return dW; -}; - -// input Tensor b for Reset db purpose, can avoid this later. -Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandles cch){ - CHECK_EQ(dy.device()->lang(), kCuda); - - Tensor db; - db.ResetLike(b); - - dy.device()->Exec([db, dy, b, cch](Context *ctx) { - Block *dyblock = dy.block(), *dbblock = db.block(); - float alpha = 1.f, beta = 0.f; - cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, cch.y_desc_, - dyblock->data(), &beta, cch.bias_desc_, - dbblock->mutable_data()); - }, {dy.block()}, {db.block()}); - - return db; -}; - -} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/aa9c52ae/src/model/operation/convolution_related.h ---------------------------------------------------------------------- diff --git a/src/model/operation/convolution_related.h b/src/model/operation/convolution_related.h deleted file mode 100644 index 49aab5b..0000000 --- a/src/model/operation/convolution_related.h +++ /dev/null @@ -1,75 +0,0 @@ -#include <string> -#include <vector> -#include <cudnn.h> -#include "../layer/cudnn_convolution.h" -#include "../layer/cudnn_utils.h" -#include "singa/utils/logging.h" - -namespace singa{ - -struct Recorder{ - size_t kernel_w_; - size_t pad_w_; - size_t stride_w_; - size_t kernel_h_; - size_t pad_h_; - size_t stride_h_; - - size_t channels_; - size_t num_filters_; - - bool bias_term_; - - size_t height_; - size_t width_; - size_t conv_height_; - size_t conv_width_; - size_t batchsize; - - size_t col_height_; - size_t col_width_; - size_t imagesize; -}; - -struct CudnnConvHandles{ - cudnnTensorDescriptor_t x_desc_ ; - cudnnTensorDescriptor_t y_desc_ ; - cudnnTensorDescriptor_t bias_desc_ ; - cudnnFilterDescriptor_t filter_desc_ ; - cudnnConvolutionDescriptor_t conv_desc_ ; - cudnnConvolutionFwdAlgo_t fp_alg_; - cudnnConvolutionBwdFilterAlgo_t bp_filter_alg_; - cudnnConvolutionBwdDataAlgo_t bp_data_alg_; - - size_t workspace_count_; - Tensor workspace_; -}; - - -Recorder SetupRecorder(const Tensor &input, const std::vector<size_t> kernel_size, - const std::vector<size_t> stride, const std::vector<size_t> padding, - const size_t in_channels, const size_t out_channels, - const bool bias_term_); - -CudnnConvHandles InitCudnnConvHandles(const Tensor &input, const Recorder r, const size_t workspace_byte_limit_=1024*1024*1024, - const std::string prefer_="fastest"); - -Tensor GpuConvForward(const Tensor &x, const Tensor &W, const Tensor &b, const Recorder r, const CudnnConvHandles cch); - -Tensor GpuConvBackwardx(const Tensor &dy, const Tensor &W, const Tensor &x, const CudnnConvHandles cch); - -Tensor GpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const CudnnConvHandles cch); - -Tensor GpuConvBackwardb(const Tensor &dy, const Tensor &b, const CudnnConvHandles cch); - - - -Tensor CpuConvForward(const Tensor &x, Tensor &W, Tensor &b, const Recorder r); - -Tensor CpuConvBackwardx(const Tensor &dy, Tensor &W, const Tensor &x, const Recorder r); - -Tensor CpuConvBackwardW(const Tensor &dy, const Tensor &x, const Tensor &W, const Recorder r); - -Tensor CpuConvBackwardb(const Tensor &dy, const Tensor &b, const Recorder r); - -} \ No newline at end of file
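For reference, below is a minimal sketch of how the merged handle API introduced by this commit is driven from Python on the CPU path, mirroring the new Conv2D code in python/singa/autograd.py. The tensor setup and the singa_wrap import path are assumptions for illustration; only the ConvHandles constructor and the CpuConv* functions are taken from the diff above.

    # Sketch only: exercises the merged ConvHandles API on CPU, assuming the
    # SWIG bindings declared in src/api/model_operation.i are built and the
    # wrapped module is importable as singa_wrap (import path is an assumption).
    from singa import tensor
    from singa import singa_wrap as singa

    x = tensor.Tensor((2, 1, 28, 28))   # NCHW input batch on the host device
    W = tensor.Tensor((32, 1, 3, 3))    # out_channels x in_channels x kH x kW
    b = tensor.Tensor((32,))
    x.gaussian(0.0, 0.1)
    W.gaussian(0.0, 0.1)
    b.set_value(0.0)

    # A single handle object now carries both the recorded shapes and the
    # initialization logic; it replaces the old Recorder/SetupRecorder pair.
    handles = singa.ConvHandles(x.data, [3, 3], [1, 1], [1, 1], 1, 32, True)

    y = singa.CpuConvForward(x.data, W.data, b.data, handles)

    dy = y  # dummy upstream gradient with the same shape as the output
    dx = singa.CpuConvBackwardx(dy, W.data, x.data, handles)
    dW = singa.CpuConvBackwardW(dy, x.data, W.data, handles)
    db = singa.CpuConvBackwardb(dy, b.data, handles)

On a CUDA device the same call pattern applies with singa.CudnnConvHandles (which extends ConvHandles with the cuDNN descriptors, workspace, and algorithm choices) and the GpuConv* functions.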
