This is an automated email from the ASF dual-hosted git repository.

wangwei pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/dev by this push:
     new 497a4fc  SINGA-505 SoftMax Backward to be bufferable
     new 8bf0c62  Merge pull request #588 from chrishkchris/SINGA-505
497a4fc is described below

commit 497a4fc86fd50ccaf6545b7ed9784b92ce55847e
Author: chrishkchris <chrishkch...@yahoo.com.hk>
AuthorDate: Tue Feb 11 14:39:32 2020 +0000

    SINGA-505 SoftMax Backward to be bufferable
---
 include/singa/core/tensor.h        |   1 +
 python/singa/autograd.py           | 104 +++++++-----------
 src/api/core_tensor.i              |   1 +
 src/core/tensor/tensor.cc          | 219 ++++++++++++++++++-----------------
 src/core/tensor/tensor_math.h      |   6 +-
 src/core/tensor/tensor_math_cpp.h  |  63 +++++------
 src/core/tensor/tensor_math_cuda.h |  42 ++++++-
 test/python/test_api.py            |  21 ++--
 8 files changed, 243 insertions(+), 214 deletions(-)

diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 93cf44a..846c14c 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -514,6 +514,7 @@ void MultRow(const Tensor &v, Tensor *M);
 /// Do softmax for each row. 'in' could be a 1-d or 2-d Tensor.
 Tensor SoftMax(const Tensor &in);
 Tensor SoftMax(const Tensor &in, int axis);
+Tensor SoftMaxBackward(const Tensor &in, int axis, const Tensor &fdout);
 Tensor RowMax(const Tensor &in);
 /// Do softmax for each row. 'in' could be a 1-d or 2-d Tensor.
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 0c5f456..01e4d82 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -647,10 +647,14 @@ class Reshape(Operation):
         self._shape = x.shape()
         shape = self.shape
         # handle the shape with 0
-        shape = [self._shape[i] if i < len(self._shape) and shape[i] == 0 else shape[i] for i in range(len(shape))]
+        shape = [
+            self._shape[i]
+            if i < len(self._shape) and shape[i] == 0 else shape[i]
+            for i in range(len(shape))
+        ]
         # handle the shape with -1
         hidden_shape = int(np.prod(self._shape) // np.abs(np.prod(shape)))
-        self.cache=[s if s != -1 else hidden_shape for s in shape]
+        self.cache = [s if s != -1 else hidden_shape for s in shape]

         return singa.Reshape(x, self.cache)

@@ -881,32 +885,10 @@ class SoftMax(Operation):
             dx (Ctensor): data for the dL / dx, L is the loss,
                           x is the input of current Opertion
         """
-        # calculations are made on numpy array
-        if self.axis == 1:
-            dy = singa.DefaultTranspose(dy)
-        grad = ctensor2numpy(dy)
-        output = ctensor2numpy(self.output)
-        out_1 = np.einsum("ki,ki->ki", grad, output)
-        medium_out = np.einsum("ki,kj->kij", output, output)
-        out_2 = np.einsum("kij,kj->ki", medium_out, grad)
-        out = out_1 - out_2
-        dx = CTensor(out_1.shape)
-        dx.CopyFloatDataFromHostPtr(out.flatten())
-        """grad = Tensor(data=dy)
-        output = Tensor(data=self.output)
-        out_1 = einsum('ki,ki->ki', grad, output)
-        medium_out = einsum('ki,kj->kij', output, output)
-        out_2 = einsum('kij,kj->ki', medium_out, grad)
-        out = out_1 - out_2
-        dx = CTensor(out_1.data.shape)
-        dx.CopyFloatDataFromHostPtr(out.data.flatten())"""
-        if self.axis == 0:
-            return dx
-        elif self.axis == 1:
-            return singa.DefaultTranspose(dx)
+        return singa.SoftMaxBackward(dy, self.axis, self.output)


-def softmax(x, axis=0):
+def softmax(x, axis=1):
     return SoftMax(axis)(x)[0]
@@ -1236,16 +1218,13 @@ class _Conv2d(Operation):

     def backward(self, dy):
         assert training is True and hasattr(
-            self, "inputs"
-        ), "Please set training as True before do BP. "
-
+            self, "inputs"), "Please set training as True before do BP. "
+
         if (type(self.handle) != singa.ConvHandle):
-            dx = singa.GpuConvBackwardx(
-                dy, self.inputs[1], self.inputs[0], self.handle
-            )
-            dW = singa.GpuConvBackwardW(
-                dy, self.inputs[0], self.inputs[1], self.handle
-            )
+            dx = singa.GpuConvBackwardx(dy, self.inputs[1], self.inputs[0],
+                                        self.handle)
+            dW = singa.GpuConvBackwardW(dy, self.inputs[0], self.inputs[1],
+                                        self.handle)
             if self.handle.bias_term:
                 db = singa.GpuConvBackwardb(dy, self.inputs[2], self.handle)
                 return dx, dW, db
@@ -1420,13 +1399,13 @@ class Conv2d(Layer):

 class SeparableConv2d(Layer):
     def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        padding=0,
-        bias=False,
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            bias=False,
     ):
         self.depthwise_conv = Conv2d(
             in_channels,
@@ -1600,9 +1579,8 @@ class _Pooling2d(Operation):

     def backward(self, dy):
         if (type(self.handle) != singa.PoolingHandle):
-            dx = singa.GpuPoolingBackward(
-                self.handle, dy, self.cache[0], self.cache[1]
-            )
+            dx = singa.GpuPoolingBackward(self.handle, dy, self.cache[0],
+                                          self.cache[1])
         else:
             dx = singa.CpuPoolingBackward(self.handle, dy, self.cache[0],
                                           self.cache[1])
@@ -2120,15 +2098,15 @@ class RNN_Base(Layer):

 class RNN(RNN_Base):
     def __init__(
-        self,
-        input_size,
-        hidden_size,
-        num_layers=1,
-        nonlinearity="tanh",
-        bias=True,
-        batch_first=False,
-        dropout=0,
-        bidirectional=False,
+            self,
+            input_size,
+            hidden_size,
+            num_layers=1,
+            nonlinearity="tanh",
+            bias=True,
+            batch_first=False,
+            dropout=0,
+            bidirectional=False,
     ):
         self.nonlinearity = nonlinearity

@@ -2181,15 +2159,15 @@ class RNN(RNN_Base):

 class LSTM(RNN_Base):
     def __init__(
-        self,
-        input_size,
-        hidden_size,
-        nonlinearity="tanh",
-        num_layers=1,
-        bias=True,
-        batch_first=False,
-        dropout=0,
-        bidirectional=False,
+            self,
+            input_size,
+            hidden_size,
+            nonlinearity="tanh",
+            num_layers=1,
+            bias=True,
+            batch_first=False,
+            dropout=0,
+            bidirectional=False,
     ):
         self.nonlinearity = nonlinearity
diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i
index d54beed..4550e6a 100755
--- a/src/api/core_tensor.i
+++ b/src/api/core_tensor.i
@@ -201,6 +201,7 @@ namespace singa{
   Tensor Average(const Tensor &t, int axis);
   Tensor SoftMax(const Tensor &t);
   Tensor SoftMax(const Tensor &t, int axis);
+  Tensor SoftMaxBackward(const Tensor &t, int axis, const Tensor &fdout);

   Tensor Pow(const Tensor &base, const Tensor &exp);
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index c61d4fa..8b90932 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -627,13 +627,11 @@ void RepeatDataToFrom(bool broadcast_flag, const vector<size_t> &repeats,
 float Tensor::l1() const {
   float nrm = 0.0f;
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
-    device_->Exec(
-        [&nrm, this](Context *ctx) {
-          DType ret = DType(0);
-          Asum<DType, Lang>(*this, &ret, ctx);
-          nrm = TypeCast<DType, float>(ret);
-        },
-        {this->block()}, {});
+    device_->Exec([&nrm, this](Context *ctx) {
+      DType ret = DType(0);
+      Asum<DType, Lang>(*this, &ret, ctx);
+      nrm = TypeCast<DType, float>(ret);
+    }, {this->block()}, {});
   });
   return nrm / Size();
 }
@@ -645,13 +643,11 @@ float Tensor::L1() const { return l1(); }
 float Tensor::l2() const {
   float nrm = 0.0f;
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
-    device_->Exec(
-        [&nrm, this](Context *ctx) {
-          DType ret = DType(0);
-          Nrm2<DType, Lang>(*this, &ret, ctx);
-          nrm = TypeCast<DType, float>(ret);
-        },
-        {this->block()}, {});
+    device_->Exec([&nrm, this](Context *ctx) {
+      DType ret = DType(0);
+      Nrm2<DType, Lang>(*this, &ret, ctx);
+      nrm = TypeCast<DType, float>(ret);
+    }, {this->block()}, {});
   });
   return nrm / Size();
 }
@@ -667,9 +663,9 @@ void Tensor::SetValue(const SType x) {
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     // TODO(wangwei) cast x to DType
-    device_->Exec(
-        [this, x, ptr](Context *ctx) { Set<DType, Lang>(x, this, ctx); }, {},
-        {ptr});
+    device_->Exec([this, x, ptr](Context *ctx) {
+      Set<DType, Lang>(x, this, ctx);
+    }, {}, {ptr});
   });
 }
 template void Tensor::SetValue<float>(const float x);
@@ -698,9 +694,9 @@ template void Tensor::GetValue<int>(int *value, const size_t num);
 #define EltwiseUnaryTensorFn(fn, t, ret)                               \
   do {                                                                 \
     TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { \
-      ret->device()->Exec(                                             \
-          [t, ret](Context *ctx) { fn<DType, Lang>(t, ret, ctx); },    \
-          {t.block()}, {ret->block()});                                \
+      ret->device()->Exec([t, ret](Context *ctx) {                     \
+        fn<DType, Lang>(t, ret, ctx);                                  \
+      }, {t.block()}, {ret->block()});                                 \
     });                                                                \
   } while (0)

@@ -778,16 +774,55 @@ Tensor SoftMax(const Tensor &in, int axis) {
   SoftMax(in, retptr, axis);
   return ret;
 }
+
+void SoftMaxBackward(const Tensor &in, Tensor *out, int axis,
+                     const Tensor &fdout) {
+  // {a_0, a_1, ..., a_k-1, a_k, ... a_n-1}
+  // reshape to
+  // { a_0 * a_1 * ... a_k-1, a_k * ... a_n-1 }
+
+  // assert axis \in {-r, r-1}
+  CHECK_LE(axis, (int)in.shape().size() - 1);
+  CHECK_GE(axis, -1 * (int)in.nDim());
+
+  Shape original_shape = in.shape();
+  if (axis < 0) axis = in.shape().size() + axis;
+
+  Shape coerced_shape = {1, 1};
+  for (std::size_t i = 0, max = in.shape().size(); i != max; ++i) {
+    if (i < axis)
+      coerced_shape[0] *= in.shape()[i];
+    else
+      coerced_shape[1] *= in.shape()[i];
+  }
+
+  Tensor in_reshaped = Reshape(in, coerced_shape);
+  out->Reshape(coerced_shape);
+
+  do {
+    TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
+      out->device()->Exec([in, out, fdout](Context *ctx) {
+        SoftMaxBackward<DType, Lang>(in, out, fdout, ctx);
+      }, {in.block(), fdout.block()}, {out->block()});
+    });
+  } while (0);
+
+  out->Reshape(original_shape);
+}
+
+Tensor SoftMaxBackward(const Tensor &in, int axis, const Tensor &fdout) {
+  Tensor ret(in.shape(), in.device(), in.data_type());
+  auto *retptr = &ret;
+  SoftMaxBackward(in, retptr, axis, fdout);
+  return ret;
+}
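
Note: the wrapper above collapses an n-d input to 2-d before dispatching to
the device kernel, exactly as its comment describes: dimensions before axis
multiply into the row count, the remaining ones into the column count, and
the original shape is restored afterwards. The same shape arithmetic in a
short numpy sketch (illustrative only, not part of this commit):

    import numpy as np

    def coerce_to_2d(shape, axis):
        if axis < 0:               # normalize a negative axis first
            axis += len(shape)
        rows = int(np.prod(shape[:axis]))   # np.prod(()) == 1.0
        cols = int(np.prod(shape[axis:]))
        return (rows, cols)

    assert coerce_to_2d((4, 3, 2), -1) == (12, 2)
    assert coerce_to_2d((4, 3, 2), 1) == (4, 6)
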

 #define EltwiseBinaryTensorFn(fn, lhs, rhs, ret)                           \
   do {                                                                     \
     TYPE_LANG_SWITCH(lhs.data_type(), DType, lhs.device()->lang(), Lang, { \
       CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type()));                    \
-      ret->device()->Exec(                                                 \
-          [lhs, rhs, ret](Context *ctx) {                                  \
-            fn<DType, Lang>(lhs, rhs, ret, ctx);                           \
-          },                                                               \
-          {lhs.block(), rhs.block()}, {ret->block()});                     \
+      ret->device()->Exec([lhs, rhs, ret](Context *ctx) {                  \
+        fn<DType, Lang>(lhs, rhs, ret, ctx);                               \
+      }, {lhs.block(), rhs.block()}, {ret->block()});                      \
     });                                                                    \
   } while (0)

@@ -832,15 +867,15 @@ GenBinaryTensorFn(operator>, GT);
 GenBinaryTensorFn(operator>=, GE);
 GenBinaryTensorFn(ReLUBackward, ReLUBackward);

-#define EltwiseTensorScalarFn(fn, t, x, ret)                            \
-  do {                                                                  \
-    TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, {  \
-      static_assert(std::is_same<SType, DType>::value,                  \
-                    "The Scalar type must match the Tensor data type"); \
-      ret->device()->Exec(                                              \
-          [t, x, ret](Context *ctx) { fn<DType, Lang>(t, x, ret, ctx); }, \
-          {t.block()}, {ret->block()});                                 \
-    });                                                                 \
+#define EltwiseTensorScalarFn(fn, t, x, ret)                            \
+  do {                                                                  \
+    TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, {  \
+      static_assert(std::is_same<SType, DType>::value,                  \
+                    "The Scalar type must match the Tensor data type"); \
+      ret->device()->Exec([t, x, ret](Context *ctx) {                   \
+        fn<DType, Lang>(t, x, ret, ctx);                                \
+      }, {t.block()}, {ret->block()});                                  \
+    });                                                                 \
   } while (0)

 #define GenTensorScalarFn(op, fn)                                   \
@@ -880,11 +915,9 @@ void Div(const SType alpha, const Tensor &in, Tensor *out) {
   CHECK(in.shape() == out->shape());
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
     // TODO(wangwei) type cast SType to DType;
-    in.device()->Exec(
-        [alpha, in, out](Context *ctx) {
-          Div<DType, Lang>(alpha, in, out, ctx);
-        },
-        {in.block()}, {out->block()});
+    in.device()->Exec([alpha, in, out](Context *ctx) {
+      Div<DType, Lang>(alpha, in, out, ctx);
+    }, {in.block()}, {out->block()});
   });
 }
 template void Div<float>(const float, const Tensor &, Tensor *);
@@ -919,13 +952,11 @@ float Sum<float>(const Tensor &in) {
   Tensor one(in.shape(), in.device(), in.data_type());
   one.SetValue(1.0f);
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
-    one.device()->Exec(
-        [in, one, &s](Context *ctx) {
-          DType ret = DType(0);
-          Dot<DType, Lang>(in, one, &ret, ctx);
-          s = ret;
-        },
-        {in.block(), one.block()}, {});
+    one.device()->Exec([in, one, &s](Context *ctx) {
+      DType ret = DType(0);
+      Dot<DType, Lang>(in, one, &ret, ctx);
+      s = ret;
+    }, {in.block(), one.block()}, {});
   });
   return s;
 }
@@ -950,24 +981,22 @@ Tensor SumAll(const Tensor &in) {
   auto *outPtr = &out;
   one.SetValue(1.0f);
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
-    one.device()->Exec([in, one, outPtr](Context * ctx) {
+    one.device()->Exec([in, one, outPtr](Context *ctx) {
       Dot<DType, Lang>(in, one, outPtr, ctx);
     }, {in.block(), one.block()}, {outPtr->block()});
   });
   return out;
 }
- 
+
 Tensor RowMax(const Tensor &in) {
   Tensor ret({in.shape(0)}, in.device(), in.data_type());
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
-    in.device()->Exec(
-        [&in, &ret](Context *ctx) {
-          // size_t nrow = 1;
-          // if (in.nDim() > 1) nrow = in.shape(0);
-          // size_t ncol = in.Size() / nrow;
-          RowMax<DType, Lang>(in, &ret, ctx);
-        },
-        {in.block()}, {ret.block()});
+    in.device()->Exec([&in, &ret](Context *ctx) {
+      // size_t nrow = 1;
+      // if (in.nDim() > 1) nrow = in.shape(0);
+      // size_t ncol = in.Size() / nrow;
+      RowMax<DType, Lang>(in, &ret, ctx);
+    }, {in.block()}, {ret.block()});
   });
   return ret;
 }
@@ -1179,9 +1208,9 @@ void MultColumn(const Tensor &v, Tensor *M) {
   CHECK_EQ(v.Size(), M->shape(0));
   CheckDataTypeAndLang(*M, v);
   TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, {
-    v.device()->Exec(
-        [M, v](Context *ctx) { DGMM<DType, Lang>(false, *M, v, M, ctx); },
-        {M->block(), v.block()}, {M->block()});
+    v.device()->Exec([M, v](Context *ctx) {
+      DGMM<DType, Lang>(false, *M, v, M, ctx);
+    }, {M->block(), v.block()}, {M->block()});
   });
 }

@@ -1193,9 +1222,9 @@ void MultRow(const Tensor &v, Tensor *M) {
   CHECK_EQ(v.Size(), M->shape(1));
   CheckDataTypeAndLang(*M, v);
   TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, {
-    v.device()->Exec(
-        [M, v](Context *ctx) { DGMM<DType, Lang>(true, *M, v, M, ctx); },
-        {M->block(), v.block()}, {M->block()});
+    v.device()->Exec([M, v](Context *ctx) {
+      DGMM<DType, Lang>(true, *M, v, M, ctx);
+    }, {M->block(), v.block()}, {M->block()});
  });
 }
@@ -1239,9 +1268,9 @@ template <typename SType>
 void Bernoulli(const SType p, Tensor *out) {
   TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
     auto prob = TypeCast<SType, DType>(p);
-    out->device()->Exec(
-        [prob, out](Context *ctx) { Bernoulli<DType, Lang>(prob, out, ctx); },
-        {}, {out->block()}, true);
+    out->device()->Exec([prob, out](Context *ctx) {
+      Bernoulli<DType, Lang>(prob, out, ctx);
+    }, {}, {out->block()}, true);
   });
 }

@@ -1252,9 +1281,9 @@ void Uniform(const SType low, const SType high, Tensor *out) {
   TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
     auto l = TypeCast<SType, DType>(low);
     auto h = TypeCast<SType, DType>(high);
-    out->device()->Exec(
-        [l, h, out](Context *ctx) { Uniform<DType, Lang>(l, h, out, ctx); }, {},
-        {out->block()}, true);
+    out->device()->Exec([l, h, out](Context *ctx) {
+      Uniform<DType, Lang>(l, h, out, ctx);
+    }, {}, {out->block()}, true);
   });
 }

@@ -1265,9 +1294,9 @@ void Gaussian(const SType mean, const SType std, Tensor *out) {
   TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
     auto m = TypeCast<SType, DType>(mean);
     auto s = TypeCast<SType, DType>(std);
-    out->device()->Exec(
-        [m, s, out](Context *ctx) { Gaussian<DType, Lang>(m, s, out, ctx); },
-        {}, {out->block()}, true);
+    out->device()->Exec([m, s, out](Context *ctx) {
+      Gaussian<DType, Lang>(m, s, out, ctx);
+    }, {}, {out->block()}, true);
   });
 }
 template void Gaussian<float>(const float mean, const float std, Tensor *out);
@@ -1278,9 +1307,9 @@ template <typename SType>
 void Axpy(const SType alpha, const Tensor &in, Tensor *out) {
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
     auto a = TypeCast<SType, DType>(alpha);
-    out->device()->Exec(
-        [a, in, out](Context *ctx) { Axpy<DType, Lang>(a, in, out, ctx); },
-        {in.block(), out->block()}, {out->block()});
+    out->device()->Exec([a, in, out](Context *ctx) {
+      Axpy<DType, Lang>(a, in, out, ctx);
+    }, {in.block(), out->block()}, {out->block()});
   });
 }

@@ -1307,22 +1336,18 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
     TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
       auto a = TypeCast<SType, DType>(alpha);
       auto b = TypeCast<SType, DType>(beta);
-      C->device()->Exec(
-          [a, A, b, B, C](Context *ctx) {
-            GEMV<DType, Lang>(a, A, B, b, C, ctx);
-          },
-          {A.block(), B.block()}, {C->block()});
+      C->device()->Exec([a, A, b, B, C](Context *ctx) {
+        GEMV<DType, Lang>(a, A, B, b, C, ctx);
+      }, {A.block(), B.block()}, {C->block()});
     });
   } else {
     CHECK(!C->transpose());
     TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
       auto a = TypeCast<SType, DType>(alpha);
       auto b = TypeCast<SType, DType>(beta);
-      C->device()->Exec(
-          [a, A, b, B, C](Context *ctx) {
-            GEMM<DType, Lang>(a, A, B, b, C, ctx);
-          },
-          {A.block(), B.block()}, {C->block()});
+      C->device()->Exec([a, A, b, B, C](Context *ctx) {
+        GEMM<DType, Lang>(a, A, B, b, C, ctx);
+      }, {A.block(), B.block()}, {C->block()});
     });
   }
 }
@@ -1349,14 +1374,11 @@ void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss) {
   if (p.nDim() == 2u) batchsize = p.shape(0);
   size_t dim = p.Size() / batchsize;
   TYPE_LANG_SWITCH(p.data_type(), DType, p.device()->lang(), Lang, {
-    p.device()->Exec(
-        [batchsize, dim, t, p, loss](Context *ctx) {
-          bool int_target = t.Size() == batchsize;
-          ComputeCrossEntropy<DType, Lang>(int_target, batchsize, dim,
-                                           p.block(), t.block(), loss->block(),
-                                           ctx);
-        },
-        {p.block(), t.block()}, {loss->block()});
+    p.device()->Exec([batchsize, dim, t, p, loss](Context *ctx) {
+      bool int_target = t.Size() == batchsize;
+      ComputeCrossEntropy<DType, Lang>(int_target, batchsize, dim, p.block(),
+                                       t.block(), loss->block(), ctx);
+    }, {p.block(), t.block()}, {loss->block()});
   });
 }
@@ -1367,14 +1389,11 @@ void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p) {
   if (p->nDim() == 2u) batchsize = p->shape(0);
   size_t dim = p->Size() / batchsize;
   TYPE_LANG_SWITCH(p->data_type(), DType, p->device()->lang(), Lang, {
-    p->device()->Exec(
-        [batchsize, dim, t, p](Context *ctx) {
-          bool int_target = t.Size() == batchsize;
-          SoftmaxCrossEntropyBwd<DType, Lang>(int_target, batchsize, dim,
-                                              p->block(), t.block(), p->block(),
-                                              ctx);
-        },
-        {p->block(), t.block()}, {p->block()});
+    p->device()->Exec([batchsize, dim, t, p](Context *ctx) {
+      bool int_target = t.Size() == batchsize;
+      SoftmaxCrossEntropyBwd<DType, Lang>(
+          int_target, batchsize, dim, p->block(), t.block(), p->block(), ctx);
+    }, {p->block(), t.block()}, {p->block()});
   });
 }
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index a9b5c70..aef4a59 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -369,8 +369,7 @@ void Dot(const Tensor &in1, const Tensor &in2, DType *out, Context *ctx) {
   LOG(FATAL) << "Dot Not Implemented";
 }
 template <typename DType, typename Lang>
-void Dot(const Tensor &in1, const Tensor &in2, Tensor *out,
-         Context *ctx) {
+void Dot(const Tensor &in1, const Tensor &in2, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Dot Not Implemented";
 }

@@ -404,7 +403,8 @@ void SoftMax(const Tensor &in, Tensor *out, Context *ctx) {
 }

 template <typename DType, typename Lang>
-void SoftMax(const Tensor &in, Tensor *out, Context *ctx, int axis) {
+void SoftMaxBackward(const Tensor &in, Tensor *out, const Tensor &fdout,
+                     Context *ctx) {
   LOG(FATAL) << "Not Implemented";
 }
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index b592ecc..fb42576 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -240,36 +240,11 @@ void Abs<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {

 #ifdef USE_DNNL
 template <>
-void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx,
-                               int axis) {
-  CHECK_EQ(in.device()->lang(), kCpp);
-
-  CHECK_LE(axis, (int)in.shape().size() - 1);
-  CHECK_GE(axis, -1 * (int)in.nDim());
-
-  Shape original_shape = in.shape();
-  if (axis < 0) axis = in.shape().size() + axis;
-
-  Shape coerced_shape = {1, 1};
-  for (int i = 0; i < in.shape().size(); i++) {
-    if (i < axis)
-      coerced_shape[0] *= in.shape()[i];
-    else
-      coerced_shape[1] *= in.shape()[i];
-  }
-  Tensor in_reshaped = Reshape(in, coerced_shape);
-  out->Reshape(coerced_shape);
-
-  // optimise by minus x - x.max()
-  auto in_max = RowMax(in_reshaped);
-  in_max.Reshape({coerced_shape[0], 1});
-  in_reshaped = in_reshaped - in_max;
-
-  auto md = dnnl::memory::desc({coerced_shape[0], coerced_shape[1]},
+void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+  auto md = dnnl::memory::desc({in.shape()[0], in.shape()[1]},
                                dnnl::memory::data_type::f32,
                                dnnl::memory::format_tag::ab);
-  auto in_mem =
-      dnnl::memory(md, ctx->dnnl_engine, in_reshaped.block()->mutable_data());
+  auto in_mem = dnnl::memory(md, ctx->dnnl_engine, in.block()->mutable_data());
   auto out_mem =
       dnnl::memory(md, ctx->dnnl_engine, out->block()->mutable_data());

@@ -281,9 +256,35 @@ void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx,
   softmax.execute(ctx->dnnl_stream,
                   {{DNNL_ARG_SRC, in_mem}, {DNNL_ARG_DST, out_mem}});
   ctx->dnnl_stream.wait();
+}
+
+template <>
+void SoftMaxBackward<float, lang::Cpp>(const Tensor &in, Tensor *out,
+                                       const Tensor &fdout, Context *ctx) {
+  auto md = dnnl::memory::desc({in.shape()[0], in.shape()[1]},
+                               dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::ab);
+  auto in_mem = dnnl::memory(md, ctx->dnnl_engine, in.block()->mutable_data());
+  auto fdout_mem =
+      dnnl::memory(md, ctx->dnnl_engine, fdout.block()->mutable_data());
+  auto out_mem =
+      dnnl::memory(md, ctx->dnnl_engine, out->block()->mutable_data());

-  out->Reshape(original_shape);
+  auto softmax_desc =
+      dnnl::softmax_forward::desc(dnnl::prop_kind::forward_scoring, md, 1);
+  auto softmax_prim_desc =
+      dnnl::softmax_forward::primitive_desc(softmax_desc, ctx->dnnl_engine);
+
+  auto softmaxbwd_desc = dnnl::softmax_backward::desc(md, md, 1);
+  auto softmaxbwd_prim_desc = dnnl::softmax_backward::primitive_desc(
+      softmaxbwd_desc, ctx->dnnl_engine, softmax_prim_desc);
+  auto softmaxbwd = dnnl::softmax_backward(softmaxbwd_prim_desc);
+  softmaxbwd.execute(ctx->dnnl_stream, {{DNNL_ARG_DIFF_SRC, out_mem},
+                                        {DNNL_ARG_DIFF_DST, in_mem},
+                                        {DNNL_ARG_DST, fdout_mem}});
+  ctx->dnnl_stream.wait();
 }
+
 #endif  // USE_DNNL

 template <>
@@ -927,6 +928,8 @@ void RowMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
   }
 }

+// =========Matrix operations ================================================
+/*
 template <>
 void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
   CHECK_LE(in.nDim(), 2u)
@@ -947,8 +950,6 @@ void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
   out->Reshape(in.shape());
 }

-// =========Matrix operations ================================================
-/*
 template <>
 void AddCol<float, lang::Cpp>(const size_t nrow, const size_t ncol,
                               const Tensor& A, const Tensor& v, Tensor* out,
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 0a0b685..4b16af0 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -815,8 +815,8 @@ void Dot<float, lang::Cuda>(const Tensor& in1, const Tensor& in2, float* out,
   CUBLAS_CHECK(cublasSdot(handle, num, inPtr1, 1, inPtr2, 1, out));
 }
 template <>
-void Dot<float, lang::Cuda>(const Tensor& in1,
-                            const Tensor& in2, Tensor* out, Context* ctx) {
+void Dot<float, lang::Cuda>(const Tensor& in1, const Tensor& in2, Tensor* out,
+                            Context* ctx) {
   const float* inPtr1 = static_cast<const float*>(in1.block()->data());
   const float* inPtr2 = static_cast<const float*>(in2.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
@@ -828,8 +828,7 @@ void Dot<float, lang::Cuda>(const Tensor& in1,
 }

 template <>
-void Nrm2<float, lang::Cuda>(const Tensor& in, float* out,
-                             Context* ctx) {
+void Nrm2<float, lang::Cuda>(const Tensor& in, float* out, Context* ctx) {
   auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
   const float* inPtr = static_cast<const float*>(in.block()->data());
   const size_t num = in.Size();
@@ -937,6 +936,41 @@ void SoftMax<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
 }

 template <>
+void SoftMaxBackward<float, lang::Cuda>(const Tensor& in, Tensor* out,
+                                        const Tensor& fdout, Context* ctx) {
+  cudnnSoftmaxAlgorithm_t algorithm = CUDNN_SOFTMAX_FAST;
+  cudnnSoftmaxMode_t mode = CUDNN_SOFTMAX_MODE_INSTANCE;
+
+  /*
+   * tensor tmp is for generating cudnn descriptor
+   * as for cudnn softmax, it required shape of {N, C, 1, 1}
+   * while helper func `generate_shape_cuda` generate shape of {1, 1, N, C}
+   * Thus this part serve similar purpose as `generate_shape_cuda` but in
+   * reverse manner
+   */
+  CHECK_LE(in.shape().size(), 5)
+      << "Dimensions (shape) beyond 5 are currently not supported";
+  auto tmp = in;
+  while (tmp.shape().size() < 4) {
+    auto s = tmp.shape();
+    s.push_back(1);
+    tmp.Reshape(s);
+  }
+
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  const float* fdoutPtr = static_cast<const float*>(fdout.block()->data());
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+
+  float alpha = 1.0;
+  float beta = 0.0;
+
+  check_cudnn(cudnnSoftmaxBackward(
+      ctx->cudnn_handle, algorithm, mode, (void*)(&alpha),
+      generate_tensor_nd_desc(tmp), fdoutPtr, generate_tensor_nd_desc(tmp),
+      inPtr, (void*)(&beta), generate_tensor_nd_desc(tmp), outPtr));
+}
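
Note: as the block comment above explains, CUDNN_SOFTMAX_MODE_INSTANCE
normalizes over C*H*W for each sample, so a 2-d input has to be described to
cuDNN as {N, C, 1, 1}; appending trailing 1s (instead of the leading 1s that
generate_shape_cuda would produce) keeps N as the batch dimension. The same
padding in a short sketch (illustrative only, not part of this commit):

    def pad_for_cudnn(shape):
        # (N, C) -> (N, C, 1, 1): append trailing ones up to 4-d
        s = list(shape)
        while len(s) < 4:
            s.append(1)
        return tuple(s)

    assert pad_for_cudnn((32, 10)) == (32, 10, 1, 1)
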
+
+template <>
 void ComputeCrossEntropy<float, lang::Cuda>(bool int_target,
                                             const size_t batchsize,
                                             const size_t dim, const Block* p,
diff --git a/test/python/test_api.py b/test/python/test_api.py
index 197f884..518c4f9 100644
--- a/test/python/test_api.py
+++ b/test/python/test_api.py
@@ -340,27 +340,22 @@ class TestAPI(unittest.TestCase):

             hndl = singa_api.BatchNormHandle(
                 m_0,
                 tensor.Tensor(device=dev, data=x_0).data)
-            (y_2_c, rm_2_c, rv_2_c, bm_2_c,
-             bv_2_c) = singa_api.CpuBatchNormForwardTraining(
-                 hndl,
-                 tensor.Tensor(device=dev, data=x_0).data,
-                 tensor.Tensor(device=dev, data=s_0).data,
-                 tensor.Tensor(device=dev, data=b_0).data,
-                 tensor.Tensor(device=dev, data=rm_0).data,
-                 tensor.Tensor(device=dev, data=rv_0).data)
+            (y_2_c, bm_2_c, bv_2_c) = singa_api.CpuBatchNormForwardTraining(
+                hndl,
+                tensor.Tensor(device=dev, data=x_0).data,
+                tensor.Tensor(device=dev, data=s_0).data,
+                tensor.Tensor(device=dev, data=b_0).data,
+                tensor.Tensor(device=dev, data=rm_0).data,
+                tensor.Tensor(device=dev, data=rv_0).data)

             np.testing.assert_array_almost_equal(
                 y_1, tensor.to_numpy(_cTensor_to_pyTensor(y_2_c)))
-            #np.testing.assert_array_almost_equal(
-            #    bm_1, tensor.to_numpy(_cTensor_to_pyTensor(bm_2_c)))
             np.testing.assert_array_almost_equal(
-                rm_1, tensor.to_numpy(_cTensor_to_pyTensor(rm_2_c)))
+                bm_1, tensor.to_numpy(_cTensor_to_pyTensor(bm_2_c)))
             #print(bv_1)
             #print(tensor.to_numpy(_cTensor_to_pyTensor(bv_2_c)))
             #np.testing.assert_array_almost_equal(
             #    bv_1, tensor.to_numpy(_cTensor_to_pyTensor(bv_2_c)), decimal=3)
-            np.testing.assert_array_almost_equal(
-                rv_1, tensor.to_numpy(_cTensor_to_pyTensor(rv_2_c)), decimal=4)
             return

         x_0 = np.array([1, 1, 1, 1, 2, 2, 2, 2, 10, 10, 10, 10, 20, 20, 20, 20],
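
For completeness, a sketch of exercising the new path end to end. The helper
calls (tensor.from_numpy, device.get_default_device) are the usual SINGA
Python API rather than part of this diff, so treat the snippet as
illustrative only:

    import numpy as np
    from singa import autograd, tensor, device
    from singa import singa_wrap as singa

    autograd.training = True
    dev = device.get_default_device()

    x = tensor.from_numpy(np.random.randn(2, 3).astype(np.float32))
    x.to_device(dev)

    y = autograd.softmax(x)   # note: softmax() now defaults to axis=1
    dy = tensor.from_numpy(np.ones((2, 3), dtype=np.float32))
    dy.to_device(dev)

    # CTensor-level call matching SoftMax.backward above; the result should
    # agree with softmax_backward_ref(dy, y, axis=1) from the earlier note.
    dx = singa.SoftMaxBackward(dy.data, 1, y.data)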