Reformat the code
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3e2b75cb Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3e2b75cb Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3e2b75cb Branch: refs/heads/master Commit: 3e2b75cbe86908f551ac3f492a8aba07008b227b Parents: c52e2aa Author: Wang Wei <[email protected]> Authored: Sun May 13 20:42:52 2018 +0800 Committer: Wang Wei <[email protected]> Committed: Sun May 13 20:42:52 2018 +0800 ---------------------------------------------------------------------- include/singa/core/tensor.h | 55 +++--- src/core/tensor/tensor.cc | 291 ++++++++++++++++---------------- src/core/tensor/tensor_math_cpp.h | 163 +++++++++--------- src/core/tensor/tensor_math_cuda.h | 286 +++++++++++++++---------------- 4 files changed, 403 insertions(+), 392 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2b75cb/include/singa/core/tensor.h ---------------------------------------------------------------------- diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index e25aafd..3cc28ff 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -36,7 +36,8 @@ typedef vector<size_t> Shape; /// hardcode the width of types defined in DataType const size_t kDataWidth[] = {sizeof(float), sizeof(float) / 2, sizeof(int), sizeof(char), - sizeof(double), sizeof(unsigned char)}; + sizeof(double), sizeof(unsigned char) + }; inline size_t SizeOf(DataType t) { static_assert(kNumDataType == sizeof(kDataWidth) / sizeof(size_t), "Num of data types not match num of data width"); @@ -51,7 +52,7 @@ inline size_t SizeOf(DataType t) { /// Tensor. /// For all operations, if the result tensor is passed as an argument, /// then it must be set up correctly (shape, device). Otherwise, runtime error -/// like SegmentFault would happen. Simply type/device check would be conducted. +/// like SegmentFault would happen. Simple type/device check would be conducted. class Tensor { public: ~Tensor(); @@ -59,12 +60,17 @@ class Tensor { explicit Tensor(Shape &&shape, DataType dtype = kFloat32); explicit Tensor(const Shape &shape, DataType dtype = kFloat32); - Tensor(Shape &&shape, std::shared_ptr<Device> dev, DataType dtype = kFloat32); - Tensor(const Shape &shape, std::shared_ptr<Device> dev, DataType dtype = kFloat32); + Tensor(Shape &&shape, + std::shared_ptr<Device> dev, + DataType dtype = kFloat32); + Tensor(const Shape &shape, + std::shared_ptr<Device> dev, + DataType dtype = kFloat32); /// Copy Tensor to share the internal data. No deep copy. Tensor(const Tensor &from); - /// Copy Tensor to share the internal data. No deep copy. For 2 tensors sharing same block but different strides. + /// Copy Tensor to share the internal data. No deep copy. + /// For 2 tensors sharing same block but different strides. Tensor(const Tensor &from, Shape &new_shape, vector<int> &new_strides); /// Copy Tensor to share the internal data. No deep copy. 
Tensor(Tensor &&from); @@ -89,7 +95,7 @@ class Tensor { void GetValue(SType *value, const size_t num) { CHECK(device_ == defaultDevice); const SType* ptr = data<SType>(); - for(size_t i = 0; i < num; i++) value[i] = ptr[i]; + for (size_t i = 0; i < num; i++) value[i] = ptr[i]; } /// data type, including kFloat16, kFloat32, kInt @@ -106,7 +112,7 @@ class Tensor { bool empty() const { return nDim() == 0; } - //bool transpose() const { return transpose_; } + /// Check if the tensor's last stride==1 bool transpose() const { return (strides_.back() != 1); } const vector<int>& strides() const { return strides_; } @@ -131,9 +137,8 @@ class Tensor { void Reshape(Shape &&shape); /// Reset the shape, device, and data type as given tensor. - /// If block size changes, then reallocate a new block. The previous block - /// would - /// be deleted. + /// If block size changes, then reallocate a new block. + /// The previous block would be deleted. void ResetLike(const Tensor &t); /// Reset the data type, it would reallocate block if type changes. @@ -176,9 +181,11 @@ class Tensor { /// No data copy, just set the transpose_ filed of the returned tensor. Tensor T() const; + /// Reverse the shape vector Tensor Transpose() const; - Tensor Transpose(Shape axes) const; + /// Change the axes + Tensor Transpose(const vector<size_t>& axes) const; /// Copy the meta info with data block shared. Tensor &operator=(const Tensor &in); @@ -219,23 +226,24 @@ class Tensor { float L2() const; //generate strides automatically if stride field is not passed -void generate_strides(){ - if(shape_.size()==0){ - strides_ = {1}; - return void(); - } + void generate_strides() { strides_.clear(); + if (shape_.size() == 0) { + strides_.push_back(1); + return; + } + size_t dim = Size(); int cumulative_product = 1; - for (size_t n=0; n<shape_.size(); ++n) { - cumulative_product = cumulative_product*shape_[n]; - strides_.push_back(dim/cumulative_product); + for (size_t n = 0; n < shape_.size(); ++n) { + cumulative_product = cumulative_product * shape_[n]; + strides_.push_back(dim / cumulative_product); } -}; + } -void set_strides(const vector<int> new_strides){ - strides_ = new_strides; -} + void set_strides(const vector<int> new_strides) { + strides_ = new_strides; + } protected: DataType data_type_ = kFloat32; @@ -247,7 +255,6 @@ void set_strides(const vector<int> new_strides){ vector<int> strides_ = {}; }; //end of tensor class -typedef Shape::iterator ShapeIter; inline size_t Product(const Shape &shape, int start = 0, size_t len = 0) { if (len == 0) len = shape.size(); if (len == 0) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2b75cb/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index a4efd64..d98e6a6 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -21,7 +21,6 @@ #include "./tensor_math_cuda.h" #include "./tensor_math_opencl.h" #include <utility> -#include <iostream> namespace singa { @@ -31,21 +30,21 @@ Tensor::~Tensor() { block_ = nullptr; } -Tensor::Tensor() { +Tensor::Tensor() { device_ = defaultDevice; strides_ = {1}; } -//non-strided constructors +//non-strided constructors Tensor::Tensor(const Shape &shape, DataType dtype) - : data_type_(dtype), device_(defaultDevice), shape_(shape) { + : data_type_(dtype), device_(defaultDevice), shape_(shape) { size_t size = Product(shape_) * SizeOf(data_type_); if (size) block_ = device_->NewBlock((int)size); generate_strides(); } 
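For readers skimming the patch: the constructors here now call generate_strides(), the helper added in include/singa/core/tensor.h above, which fills strides_ with default row-major strides (stride[k] = Size() / (shape[0] * ... * shape[k]), so the innermost stride is 1), while transpose() now just checks whether the last stride is no longer 1. Below is a minimal standalone sketch of that stride rule; the names are illustrative and not part of the patch.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Row-major strides from a shape, mirroring the cumulative-product rule in
    // Tensor::generate_strides(); e.g. shape {2, 3, 4} gives strides {12, 4, 1}.
    std::vector<int> RowMajorStrides(const std::vector<size_t>& shape) {
      if (shape.empty()) return {1};
      size_t size = 1;
      for (size_t d : shape) size *= d;          // total number of elements
      std::vector<int> strides;
      size_t cumulative = 1;
      for (size_t d : shape) {
        cumulative *= d;
        strides.push_back(static_cast<int>(size / cumulative));
      }
      return strides;
    }

    int main() {
      for (int s : RowMajorStrides({2, 3, 4})) std::cout << s << ' ';  // prints: 12 4 1
      std::cout << '\n';
      return 0;
    }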
Tensor::Tensor(Shape &&shape, DataType dtype) - : data_type_(dtype), device_(defaultDevice), shape_(shape) { + : data_type_(dtype), device_(defaultDevice), shape_(shape) { size_t size = Product(shape_) * SizeOf(data_type_); if (size) block_ = device_->NewBlock((int)size); @@ -55,14 +54,14 @@ Tensor::Tensor(Shape &&shape, DataType dtype) //non-strided constructors with device Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device, DataType dtype) - : data_type_(dtype), device_(device), shape_(shape) { + : data_type_(dtype), device_(device), shape_(shape) { size_t size = Product(shape_) * SizeOf(data_type_); if (size) block_ = device_->NewBlock((int)size); generate_strides(); } Tensor::Tensor(Shape &&shape, std::shared_ptr<Device> device, DataType dtype) - : data_type_(dtype), device_(device), shape_(shape) { + : data_type_(dtype), device_(device), shape_(shape) { size_t size = Product(shape_) * SizeOf(data_type_); if (size) block_ = device_->NewBlock((int)size); @@ -71,34 +70,34 @@ Tensor::Tensor(Shape &&shape, std::shared_ptr<Device> device, DataType dtype) Tensor::Tensor(const Tensor &in) - : //transpose_(in.transpose_), - data_type_(in.data_type_), - device_(in.device_), - block_(in.block()), - shape_(in.shape_), - strides_(in.strides_) { + : //transpose_(in.transpose_), + data_type_(in.data_type_), + device_(in.device_), + block_(in.block()), + shape_(in.shape_), + strides_(in.strides_) { if (block_ != nullptr) block_->IncRefCount(); } //strided constructor taking in a tensor, shape and strides Tensor::Tensor(const Tensor &in, Shape &new_shape, vector<int> &new_strides) - : //transpose_(in.transpose_), - data_type_(in.data_type_), - device_(in.device_), - block_(in.block()), - shape_(new_shape), - strides_(new_strides) { + : //transpose_(in.transpose_), + data_type_(in.data_type_), + device_(in.device_), + block_(in.block()), + shape_(new_shape), + strides_(new_strides) { if (block_ != nullptr) block_->IncRefCount(); } Tensor::Tensor(Tensor &&in) - : //transpose_(in.transpose_), - data_type_(in.data_type_), - device_(in.device_), - shape_(std::move(in.shape_)), - strides_(in.strides_) { + : //transpose_(in.transpose_), + data_type_(in.data_type_), + device_(in.device_), + shape_(std::move(in.shape_)), + strides_(in.strides_) { block_ = in.block_; in.block_ = nullptr; } @@ -123,10 +122,13 @@ void Tensor::ResetLike(const Tensor &in) { strides_ = in.strides_; } -//if tensor is not transposed yet i.e strides == 1, then we simply change the shape and generate new default strides -//if tensor is already transposed i.e strides != 1, it should be copied to a new tensor with newly generated default strides +// if tensor is not transposed yet i.e strides == 1, +// then we simply change the shape and generate new default strides +// if tensor is already transposed i.e strides != 1, +// it should be copied to a new tensor with newly generated default strides +// TODO(wangwei) raise error if the shape not match void Tensor::Reshape(const Shape &shape) { - if(strides_.size()==0) + if (strides_.size() == 0) strides_.push_back(1); if (Product(shape_) != Product(shape)) { @@ -141,7 +143,7 @@ void Tensor::Reshape(const Shape &shape) { } void Tensor::Reshape(Shape &&shape) { - if(strides_.size()==0) + if (strides_.size() == 0) strides_.push_back(1); if (Product(shape_) != Product(shape)) { @@ -196,12 +198,12 @@ void Tensor::CopyDataFromHostPtr(const DType *src, const size_t num, } } template void Tensor::CopyDataFromHostPtr(const unsigned char *src, - const size_t num, - const size_t offset); + 
const size_t num, + const size_t offset); template void Tensor::CopyDataFromHostPtr(const float *src, const size_t num, - const size_t offset); + const size_t offset); template void Tensor::CopyDataFromHostPtr(const int *src, const size_t num, - const size_t offset); + const size_t offset); void Tensor::CopyData(const Tensor &src) { CHECK_EQ(Size(), src.Size()); @@ -224,44 +226,44 @@ void Tensor::FromProto(const singa::TensorProto &proto) { strides_.clear(); for (int32_t s : proto.strides()) strides_.push_back(s); switch (data_type_) { - case kFloat32: { - std::unique_ptr<float[]> data_ptr(new float[Product(shape_)]); - for (size_t i = 0; i < Product(shape_); ++i) - data_ptr[i] = static_cast<float>(proto.float_data((int)i)); - CopyDataFromHostPtr<float>(data_ptr.get(), Product(shape_)); - break; - } - case kDouble: { - std::unique_ptr<double[]> data(new double[Product(shape_)]); - for (size_t i = 0; i < Product(shape_); ++i) - data[i] = proto.double_data((int)i); - CopyDataFromHostPtr<double>(data.get(), Product(shape_)); - break; - } - case kInt: { - std::unique_ptr<int[]> data(new int[Product(shape_)]); - for (size_t i = 0; i < Product(shape_); ++i) data[i] = proto.int_data((int)i); - CopyDataFromHostPtr<int>(data.get(), Product(shape_)); - break; - } - ///TODO(wangji): Implement to support C++ type char using bytes type in protobuf - /// which is equivalent to string type is different from the other cases. The kchar - /// and kUChar case is to be implemented. - /* - case kChar: { - std::unique_ptr<char[]> data(new char[Product(shape_)]); - for (size_t i = 0; i < Product(shape_); ++i) - data[i] = static_cast<char>(proto.bytes_data(i)); - break; - } - case kUChar: { - std::unique_ptr<unsigned char[]> data(new unsigned char[Product(shape_)]); - for (size_t i = 0; i < Product(shape_); ++i) - data[i] = static_cast<unsigned char>(proto.bytes_data(i)); - break; - } - */ - default: { LOG(FATAL) << "Unsupported Type" << DataType_Name(data_type_); } + case kFloat32: { + std::unique_ptr<float[]> data_ptr(new float[Product(shape_)]); + for (size_t i = 0; i < Product(shape_); ++i) + data_ptr[i] = static_cast<float>(proto.float_data((int)i)); + CopyDataFromHostPtr<float>(data_ptr.get(), Product(shape_)); + break; + } + case kDouble: { + std::unique_ptr<double[]> data(new double[Product(shape_)]); + for (size_t i = 0; i < Product(shape_); ++i) + data[i] = proto.double_data((int)i); + CopyDataFromHostPtr<double>(data.get(), Product(shape_)); + break; + } + case kInt: { + std::unique_ptr<int[]> data(new int[Product(shape_)]); + for (size_t i = 0; i < Product(shape_); ++i) data[i] = proto.int_data((int)i); + CopyDataFromHostPtr<int>(data.get(), Product(shape_)); + break; + } + ///TODO(wangji): Implement to support C++ type char using bytes type in protobuf + /// which is equivalent to string type is different from the other cases. The kchar + /// and kUChar case is to be implemented. 
+ /* + case kChar: { + std::unique_ptr<char[]> data(new char[Product(shape_)]); + for (size_t i = 0; i < Product(shape_); ++i) + data[i] = static_cast<char>(proto.bytes_data(i)); + break; + } + case kUChar: { + std::unique_ptr<unsigned char[]> data(new unsigned char[Product(shape_)]); + for (size_t i = 0; i < Product(shape_); ++i) + data[i] = static_cast<unsigned char>(proto.bytes_data(i)); + break; + } + */ + default: { LOG(FATAL) << "Unsupported Type" << DataType_Name(data_type_); } } } @@ -277,44 +279,44 @@ void Tensor::ToProto(singa::TensorProto *proto) const { proto->add_strides(s); } switch (data_type_) { - case kFloat32: { - proto->clear_float_data(); - const float *data_ptr = data<float>(); - for (size_t i = 0; i < Product(shape_); ++i) - proto->add_float_data(data_ptr[i]); - break; - } - case kDouble: { - proto->clear_double_data(); - const double *data_ptr = data<double>(); - for (size_t i = 0; i < Product(shape_); ++i) - proto->add_double_data(data_ptr[i]); - break; - } - case kInt: { - proto->clear_int_data(); - const int *data_ptr = data<int>(); - for (size_t i = 0; i < Product(shape_); ++i) - proto->add_int_data(data_ptr[i]); - break; - } - /* - case kChar: { - proto->clear_bytes_data(); - const char *data = data<char>(); - for (size_t i = 0; i < Product(shape_); ++i) - proto->add_bytes_data(static_cast<unsigned char>(data[i])); - break; - } - case kUChar: { - proto->clear_bytes_data(); - const unsigned char *data = data<unsigned char>(); - for (size_t i = 0; i < Product(shape_); ++i) - proto->add_bytes_data(static_cast<unsigned char>(data[i])); - break; - } - */ - default: { LOG(FATAL) << "Unsupported Type" << DataType_Name(data_type_); } + case kFloat32: { + proto->clear_float_data(); + const float *data_ptr = data<float>(); + for (size_t i = 0; i < Product(shape_); ++i) + proto->add_float_data(data_ptr[i]); + break; + } + case kDouble: { + proto->clear_double_data(); + const double *data_ptr = data<double>(); + for (size_t i = 0; i < Product(shape_); ++i) + proto->add_double_data(data_ptr[i]); + break; + } + case kInt: { + proto->clear_int_data(); + const int *data_ptr = data<int>(); + for (size_t i = 0; i < Product(shape_); ++i) + proto->add_int_data(data_ptr[i]); + break; + } + /* + case kChar: { + proto->clear_bytes_data(); + const char *data = data<char>(); + for (size_t i = 0; i < Product(shape_); ++i) + proto->add_bytes_data(static_cast<unsigned char>(data[i])); + break; + } + case kUChar: { + proto->clear_bytes_data(); + const unsigned char *data = data<unsigned char>(); + for (size_t i = 0; i < Product(shape_); ++i) + proto->add_bytes_data(static_cast<unsigned char>(data[i])); + break; + } + */ + default: { LOG(FATAL) << "Unsupported Type" << DataType_Name(data_type_); } } } @@ -353,9 +355,9 @@ Tensor Tensor::Transpose() const { t.device_ = device_; t.data_type_ = data_type_; t.strides_.clear(); - for(size_t n=0; n<shape_.size(); ++n){ - t.shape_.push_back(shape_[shape_.size()-n-1]); - t.strides_.push_back(strides_[shape_.size()-n-1]); + for (size_t n = 0; n < shape_.size(); ++n) { + t.shape_.push_back(shape_[shape_.size() - n - 1]); + t.strides_.push_back(strides_[shape_.size() - n - 1]); } t.block_ = block_; block_->IncRefCount(); @@ -363,6 +365,7 @@ Tensor Tensor::Transpose() const { } //transpose with axes +// TODO(wangwei) the shape and axes should match Tensor Tensor::Transpose(Shape axes) const { // if(axes.size() != shape_.size()){ // std::cout << "Warning: Size of input axes doesn't match size of shape" << std::endl; @@ -375,7 +378,7 @@ Tensor 
Tensor::Transpose(Shape axes) const { t.device_ = device_; t.data_type_ = data_type_; t.strides_.clear(); - for(size_t n=0; n<axes.size(); ++n){ + for (size_t n = 0; n < axes.size(); ++n) { t.shape_.push_back(shape_[axes[n]]); t.strides_.push_back(strides_[axes[n]]); } @@ -404,7 +407,7 @@ Tensor &Tensor::operator=(Tensor &&in) { if (block_ != nullptr && block_->DecRefCount() == 0) device_->FreeBlock(block_); //transpose_ = in.transpose_; - strides_ = in.strides_; + strides_ = std::move(in.strides_); data_type_ = in.data_type_; shape_ = std::move(in.shape_); device_ = in.device_; @@ -470,7 +473,7 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num, (int)s_offset); } else if (src_dev->lang() == kCpp) { dst_dev->CopyDataToFrom(to, from, nBytes, kHostToDevice, (int)d_offset, - (int)s_offset); + (int)s_offset); } else { LOG(FATAL) << "Not support mem copy betwee Cuda and OpenCL device"; } @@ -548,7 +551,7 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num, float Tensor::L1() const { float nrm = 0.0f; TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, { - device_->Exec([&nrm, this](Context *ctx) { + device_->Exec([&nrm, this](Context * ctx) { DType ret = DType(0); Asum<DType, Lang>(*this, &ret, ctx); nrm = TypeCast<DType, float>(ret); @@ -561,7 +564,7 @@ float Tensor::L1() const { float Tensor::L2() const { float nrm = 0.0f; TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, { - device_->Exec([&nrm, this](Context *ctx) { + device_->Exec([&nrm, this](Context * ctx) { DType ret = DType(0); Nrm2<DType, Lang>(*this, &ret, ctx); nrm = TypeCast<DType, float>(ret); @@ -577,7 +580,7 @@ void Tensor::SetValue(const SType x) { auto ptr = block_; TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, { // TODO(wangwei) cast x to DType - device_->Exec([this, x, ptr](Context *ctx) { + device_->Exec([this, x, ptr](Context * ctx) { Set<DType, Lang>(x, this, ctx); }, {}, {ptr}); }); @@ -691,7 +694,7 @@ void Div(const SType alpha, const Tensor &in, Tensor *out) { CHECK(in.shape() == out->shape()); TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, { // TODO(wangwei) type cast SType to DType; - in.device()->Exec([alpha, in, out](Context *ctx) { + in.device()->Exec([alpha, in, out](Context * ctx) { Div<DType, Lang>(alpha, in, out, ctx); }, {in.block()}, {out->block()}); }); @@ -727,7 +730,7 @@ float Sum<float>(const Tensor &in) { Tensor one(in.shape(), in.device(), in.data_type()); one.SetValue(1.0f); TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, { - one.device()->Exec([in, one, &s](Context *ctx) { + one.device()->Exec([in, one, &s](Context * ctx) { DType ret = DType(0); Dot<DType, Lang>(in, one, &ret, ctx); s = ret; @@ -758,7 +761,7 @@ Tensor SoftMax(const Tensor &in) { Tensor RowMax(const Tensor &in) { Tensor ret({in.shape(0)}, in.device(), in.data_type()); TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, { - in.device()->Exec([&in, &ret](Context *ctx) { + in.device()->Exec([&in, &ret](Context * ctx) { //size_t nrow = 1; //if (in.nDim() > 1) nrow = in.shape(0); //size_t ncol = in.Size() / nrow; @@ -805,7 +808,7 @@ void AddColumn(const SType alpha, const SType beta, const Tensor &v, Tensor vmat = Reshape(v, Shape{nb_row, 1}); Mult(alpha, vmat, one, beta, M); } -} +} template void AddColumn(const float alpha, const float beta, const Tensor &v, Tensor *M); @@ -846,16 +849,16 @@ Tensor ConcatOn(const vector<Tensor> &in, int axis) { CHECK_GE(dim, 2u) << " Only work for tensor of dim >=2 "; size_t size = 
in[0].Size() / in[0].shape(axis); size_t new_size = 0u; - for (const auto& t: in) { + for (const auto& t : in) { CHECK_EQ(dim, t.shape().size()) << "All tensors should have the same dim"; CHECK_EQ(size, t.Size() / t.shape(axis)) << "The size of all axis should " - <<" be the same except the concatenated axis"; + << " be the same except the concatenated axis"; new_size += t.shape(axis); } out_shape[axis] = new_size; if (axis == 0) { size_t nrow = 0; - for (const auto& t: in) { + for (const auto& t : in) { nrow += t.shape(0); tmp.push_back(Reshape(t, {t.shape(0), t.Size() / t.shape(0)})); } @@ -863,7 +866,7 @@ Tensor ConcatOn(const vector<Tensor> &in, int axis) { ret.Reshape(out_shape); return ret; } else { - for (const auto& t: in) { + for (const auto& t : in) { size_t nrow = 1; for (int i = 0; i < axis; i++) nrow *= t.shape(i); @@ -944,7 +947,7 @@ Tensor SliceOn(const Tensor&in, const size_t start, const size_t end, int axis) out_shape[axis] = end - start; if (axis == 0) { auto ret = SliceRows(Reshape(in, {in.shape(0), in.Size() / in.shape(0)}), - start, end); + start, end); ret.Reshape(out_shape); return ret; } else { @@ -953,7 +956,7 @@ Tensor SliceOn(const Tensor&in, const size_t start, const size_t end, int axis) nrow *= in.shape(i); auto suffix = in.Size() / nrow / in.shape(axis); auto ret = SliceColumns(Reshape(in, {nrow, in.Size() / nrow}), - start * suffix, end * suffix); + start * suffix, end * suffix); ret.Reshape(out_shape); return ret; } @@ -997,9 +1000,9 @@ void MultColumn(const Tensor &v, Tensor *M) { CHECK_EQ(v.Size(), M->shape(0)); CheckDataTypeAndLang(*M, v); TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, { - v.device()->Exec([M, v](Context *ctx) { + v.device()->Exec([M, v](Context * ctx) { DGMM<DType, Lang>(false, *M, v, - M, ctx); + M, ctx); }, {M->block(), v.block()}, {M->block()}); }); } @@ -1012,9 +1015,9 @@ void MultRow(const Tensor &v, Tensor *M) { CHECK_EQ(v.Size(), M->shape(1)); CheckDataTypeAndLang(*M, v); TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, { - v.device()->Exec([M, v](Context *ctx) { + v.device()->Exec([M, v](Context * ctx) { DGMM<DType, Lang>(true, *M, v, - M, ctx); + M, ctx); }, {M->block(), v.block()}, {M->block()}); }); } @@ -1059,7 +1062,7 @@ template <typename SType> void Bernoulli(const SType p, Tensor *out) { TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, { auto prob = TypeCast<SType, DType>(p); - out->device()->Exec([prob, out](Context *ctx) { + out->device()->Exec([prob, out](Context * ctx) { Bernoulli<DType, Lang>(prob, out, ctx); }, {}, {out->block()}, true); }); @@ -1072,7 +1075,7 @@ void Uniform(const SType low, const SType high, Tensor *out) { TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, { auto l = TypeCast<SType, DType>(low); auto h = TypeCast<SType, DType>(high); - out->device()->Exec([l, h, out](Context *ctx) { + out->device()->Exec([l, h, out](Context * ctx) { Uniform<DType, Lang>(l, h, out, ctx); }, {}, {out->block()}, true); }); @@ -1085,7 +1088,7 @@ void Gaussian(const SType mean, const SType std, Tensor *out) { TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, { auto m = TypeCast<SType, DType>(mean); auto s = TypeCast<SType, DType>(std); - out->device()->Exec([m, s, out](Context *ctx) { + out->device()->Exec([m, s, out](Context * ctx) { Gaussian<DType, Lang>(m, s, out, ctx); }, {}, {out->block()}, true); }); @@ -1098,7 +1101,7 @@ template <typename SType> void Axpy(const SType alpha, const Tensor &in, Tensor *out) { 
TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, { auto a = TypeCast<SType, DType>(alpha); - out->device()->Exec([a, in, out](Context *ctx) { + out->device()->Exec([a, in, out](Context * ctx) { Axpy<DType, Lang>(a, in, out, ctx); }, {in.block(), out->block()}, {out->block()}); }); @@ -1128,7 +1131,7 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta, TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, { auto a = TypeCast<SType, DType>(alpha); auto b = TypeCast<SType, DType>(beta); - C->device()->Exec([a, A, b, B, C](Context *ctx) { + C->device()->Exec([a, A, b, B, C](Context * ctx) { GEMV<DType, Lang>(a, A, B, b, C, ctx); }, {A.block(), B.block()}, {C->block()}); }); @@ -1137,9 +1140,9 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta, TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, { auto a = TypeCast<SType, DType>(alpha); auto b = TypeCast<SType, DType>(beta); - C->device()->Exec([a, A, b, B, C](Context *ctx) { + C->device()->Exec([a, A, b, B, C](Context * ctx) { GEMM<DType, Lang>(a, A, B, b, C, - ctx); + ctx); }, {A.block(), B.block()}, {C->block()}); }); } @@ -1155,10 +1158,10 @@ void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss) { if (p.nDim() == 2u) batchsize = p.shape(0); size_t dim = p.Size() / batchsize; TYPE_LANG_SWITCH(p.data_type(), DType, p.device()->lang(), Lang, { - p.device()->Exec([batchsize, dim, t, p, loss](Context *ctx) { - bool int_target = t.Size() == batchsize; - ComputeCrossEntropy<DType, Lang>(int_target, batchsize, dim, p.block(), - t.block(), loss->block(), ctx); + p.device()->Exec([batchsize, dim, t, p, loss](Context * ctx) { + bool int_target = t.Size() == batchsize; + ComputeCrossEntropy<DType, Lang>(int_target, batchsize, dim, p.block(), + t.block(), loss->block(), ctx); }, {p.block(), t.block()}, {loss->block()}); }); } @@ -1170,10 +1173,10 @@ void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p) { if (p->nDim() == 2u) batchsize = p->shape(0); size_t dim = p->Size() / batchsize; TYPE_LANG_SWITCH(p->data_type(), DType, p->device()->lang(), Lang, { - p->device()->Exec([batchsize, dim, t, p](Context *ctx) { + p->device()->Exec([batchsize, dim, t, p](Context * ctx) { bool int_target = t.Size() == batchsize; SoftmaxCrossEntropyBwd<DType, Lang>(int_target, batchsize, dim, - p->block(), t.block(), p->block(), ctx); + p->block(), t.block(), p->block(), ctx); }, {p->block(), t.block()}, {p->block()}); }); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2b75cb/src/core/tensor/tensor_math_cpp.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h index 1ca312a..bfdd026 100644 --- a/src/core/tensor/tensor_math_cpp.h +++ b/src/core/tensor/tensor_math_cpp.h @@ -32,13 +32,14 @@ namespace singa { // ===================== Helper Functions ============================= -//generate a traversal_info vector based on the tensor's shape for the traverse_next function to work +// generate a traversal_info vector based on the tensor's shape for the +// traverse_next function to work vector<int> generate_traversal_info(const Tensor& x) { - vector<int> traversal_info = {}; - for(size_t n=0; n<(x.shape().size()+2); ++n) { - traversal_info.push_back(0); - } - return traversal_info; + vector<int> traversal_info = {}; + for (size_t n = 0; n < (x.shape().size() + 2); ++n) { + traversal_info.push_back(0); + } + return traversal_info; }; 
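The helper above builds the bookkeeping vector that traverse_next() updates, so the CPU kernels can walk a tensor in logical order even when its strides are not the default row-major ones (for example after Transpose(), which reverses shape and strides without copying). The sketch below is not the patch's incremental traversal algorithm, only the strided offset arithmetic it relies on; the variables are illustrative, not SINGA API.

    #include <cstdio>
    #include <vector>

    // Minimal sketch: reading a 3x2 "view" that is the transpose of a 2x3
    // row-major block, using only shape and strides (hypothetical example).
    int main() {
      float block[6] = {0, 1, 2, 3, 4, 5};   // 2x3 row-major data
      std::vector<int> shape  = {3, 2};      // transposed view: 3 rows, 2 cols
      std::vector<int> stride = {1, 3};      // reversed strides of the 2x3 tensor
      for (int r = 0; r < shape[0]; ++r) {
        for (int c = 0; c < shape[1]; ++c)
          std::printf("%g ", block[r * stride[0] + c * stride[1]]);
        std::printf("\n");
      }
      // prints 0 3 / 1 4 / 2 5, i.e. the transpose, without copying the data
      return 0;
    }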
//generate shape multipliers @@ -47,18 +48,18 @@ vector<int> generate_traversal_info(const Tensor& x) { //this means that the 3rd, 6th, and 9th index of the array will always be the starting element of their respective rows //so we need to need use the inner stride when jumping from 1st->2nd element, and outer stride when jumping from 2nd->3rd vector<int> generate_shape_multipliers(const Tensor& x) { - Shape y_shape = x.shape(); - if(y_shape.size()==0){ - return {1}; - } - vector<int> shape_multipliers = {1}; - int cumulative_product = 1; + Shape y_shape = x.shape(); + if (y_shape.size() == 0) { + return {1}; + } + vector<int> shape_multipliers = {1}; + int cumulative_product = 1; - for (size_t n=0; n<(y_shape.size()-1); ++n) { - cumulative_product = cumulative_product*y_shape[y_shape.size()-1-n]; - shape_multipliers.insert(shape_multipliers.begin(), cumulative_product); - } - return shape_multipliers; + for (size_t n = 0; n < (y_shape.size() - 1); ++n) { + cumulative_product = cumulative_product * y_shape[y_shape.size() - 1 - n]; + shape_multipliers.insert(shape_multipliers.begin(), cumulative_product); + } + return shape_multipliers; }; // ****************************************************************************************** @@ -71,20 +72,20 @@ vector<int> generate_shape_multipliers(const Tensor& x) { //this additional check only has 1 loop for 2d matrix //but runtime performance might degrade to O(nlog(n)) for higher dimensional tensors int determine_order(vector<int>& shape_multipliers, int counter) { - for (size_t n=0; n<(shape_multipliers.size()-1); ++n) { - if((counter%shape_multipliers[n])==0){ - return ((shape_multipliers.size()) - 1 - n); - } + for (size_t n = 0; n < (shape_multipliers.size() - 1); ++n) { + if ((counter % shape_multipliers[n]) == 0) { + return ((shape_multipliers.size()) - 1 - n); } - return 0; + } + return 0; }; //this function updates the base indexes with the current index after every single traversal step, //can be generalized beyond 2d cases void update_base_index(const Tensor& x, vector<int>& traversal_info) { - for (int n=0; n<(traversal_info[x.shape().size()+1]+1); ++n) { - traversal_info[n] = traversal_info[x.shape().size()]; - } + for (int n = 0; n < (traversal_info[x.shape().size() + 1] + 1); ++n) { + traversal_info[n] = traversal_info[x.shape().size()]; + } }; //function to traverse a const strided tensor object @@ -95,32 +96,32 @@ void update_base_index(const Tensor& x, vector<int>& traversal_info) { //index 3 stores the order of the traversal for e.g. 
if the order is 0, //it means the next element can be navigated to using the innermost stride void traverse_next(const Tensor& x, - vector<int>& shape_multipliers, + vector<int>& shape_multipliers, vector<int>& traversal_info, int counter) { - update_base_index(x, traversal_info); - traversal_info[x.shape().size()+1] = determine_order(shape_multipliers, counter); - traversal_info[x.shape().size()] = traversal_info[traversal_info[x.shape().size()+1]] + - x.strides()[x.strides().size()-traversal_info[x.shape().size()+1]-1]; + update_base_index(x, traversal_info); + traversal_info[x.shape().size() + 1] = determine_order(shape_multipliers, counter); + traversal_info[x.shape().size()] = traversal_info[traversal_info[x.shape().size() + 1]] + + x.strides()[x.strides().size() - traversal_info[x.shape().size() + 1] - 1]; }; template <typename DType> -void TraverseUnary(const Tensor &in, Tensor* out, std::function<DType(DType)> func){ +void TraverseUnary(const Tensor &in, Tensor* out, std::function<DType(DType)> func) { DType *outPtr = static_cast<DType *>(out->block()->mutable_data()); const DType *inPtr = static_cast<const DType *>(in.block()->data()); vector<int> traversal_info = generate_traversal_info(in); vector<int> shape_multipliers = generate_shape_multipliers(in); - for (size_t i = 0; i < in.Size(); i++) { + for (size_t i = 0; i < in.Size(); i++) { outPtr[i] = func(inPtr[traversal_info[in.shape().size()]]); - traverse_next(in, shape_multipliers, traversal_info, i+1); + traverse_next(in, shape_multipliers, traversal_info, i + 1); } } template <typename DType> -void TraverseBinary(const Tensor &in1, const Tensor &in2, Tensor* out, - std::function<DType(DType, DType)> func){ +void TraverseBinary(const Tensor &in1, const Tensor &in2, Tensor* out, + std::function<DType(DType, DType)> func) { DType *outPtr = static_cast<DType *>(out->block()->mutable_data()); const DType *in1Ptr = static_cast<const DType *>(in1.block()->data()); const DType *in2Ptr = static_cast<const DType *>(in2.block()->data()); @@ -132,8 +133,8 @@ void TraverseBinary(const Tensor &in1, const Tensor &in2, Tensor* out, for (size_t i = 0; i < in1.Size(); i++) { outPtr[i] = func(in1Ptr[traversal_info_in1[in1.shape().size()]], in2Ptr[traversal_info_in2[in2.shape().size()]]); - traverse_next(in1, shape_multipliers_in1, traversal_info_in1, i+1); - traverse_next(in2, shape_multipliers_in2, traversal_info_in2, i+1); + traverse_next(in1, shape_multipliers_in1, traversal_info_in1, i + 1); + traverse_next(in2, shape_multipliers_in2, traversal_info_in2, i + 1); } } @@ -151,7 +152,7 @@ void Abs<float, lang::Cpp>(const Tensor& in, Tensor* out, Context *ctx) { template <> void Add<float, lang::Cpp>(const Tensor& in, const float x, Tensor* out, Context *ctx) { auto add_lambda = [&x](float a) { - return (a+x); + return (a + x); }; TraverseUnary<float>(in, out, add_lambda); } @@ -160,10 +161,10 @@ template <> void Add<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out, Context *ctx) { // CHECK_EQ(ctx->stream, nullptr); auto add_lambda_binary = [](float a, float b) { - return (a+b); + return (a + b); }; TraverseBinary<float>(in1, in2, out, add_lambda_binary); - + } template <> @@ -171,8 +172,8 @@ void Clamp<float, lang::Cpp>(const float low, const float high, const Tensor& in, Tensor* out, Context *ctx) { auto clamp_lambda = [&low, &high](float a) { - if(a < low){return low;} - else if(a > high){return high;} + if (a < low) {return low;} + else if (a > high) {return high;} else {return a;} }; TraverseUnary<float>(in, out, 
clamp_lambda); @@ -189,7 +190,7 @@ void Div<float, lang::Cpp>(const float x, const Tensor& in, Tensor* out, for (size_t i = 0; i < in.Size(); i++) { CHECK_NE(inPtr[traversal_info[in.shape().size()]], 0.f); outPtr[i] = x / inPtr[traversal_info[in.shape().size()]]; - traverse_next(in, shape_multipliers, traversal_info, i+1); + traverse_next(in, shape_multipliers, traversal_info, i + 1); } } @@ -207,8 +208,8 @@ void Div<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, for (size_t i = 0; i < in1.Size(); i++) { CHECK_NE(in2Ptr[traversal_info_in2[in2.shape().size()]], 0.f); outPtr[i] = in1Ptr[traversal_info_in1[in1.shape().size()]] / in2Ptr[traversal_info_in2[in2.shape().size()]]; - traverse_next(in1, shape_multipliers_in1, traversal_info_in1, i+1); - traverse_next(in2, shape_multipliers_in2, traversal_info_in2, i+1); + traverse_next(in1, shape_multipliers_in1, traversal_info_in1, i + 1); + traverse_next(in2, shape_multipliers_in2, traversal_info_in2, i + 1); } } @@ -216,16 +217,16 @@ template <> void EltwiseMult<float, lang::Cpp>(const Tensor& in, const float x, Tensor* out, Context *ctx) { auto eltwisemult_lambda = [&x](float a) { - return (a*x); + return (a * x); }; TraverseUnary<float>(in, out, eltwisemult_lambda); } template <> -void EltwiseMult<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out, +void EltwiseMult<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out, Context *ctx) { auto eltwisemult_lambda_binary = [](float a, float b) { - return (a*b); + return (a * b); }; TraverseBinary<float>(in1, in2, out, eltwisemult_lambda_binary); } @@ -300,7 +301,7 @@ void Log<float, lang::Cpp>(const Tensor& in, Tensor* out, for (size_t i = 0; i < in.Size(); i++) { CHECK_GT(inPtr[traversal_info[in.shape().size()]], 0.f); outPtr[i] = log(inPtr[traversal_info[in.shape().size()]]); - traverse_next(in, shape_multipliers, traversal_info, i+1); + traverse_next(in, shape_multipliers, traversal_info, i + 1); } } @@ -325,21 +326,21 @@ void LT<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out, template <> void Pow<float, lang::Cpp>(const Tensor& in, const float x, Tensor *out, Context *ctx) { - TraverseUnary<float>(in, out, [x](float y) {return pow(y,x);}); + TraverseUnary<float>(in, out, [x](float y) {return pow(y, x);}); } template <> void Pow<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out, Context *ctx) { auto pow_lambda_binary = [](float a, float b) { - return pow(a,b); + return pow(a, b); }; TraverseBinary<float>(in1, in2, out, pow_lambda_binary); } template <> void ReLU<float, lang::Cpp>(const Tensor& in, Tensor* out, - Context *ctx) { + Context *ctx) { auto relu_lambda = [](float a) { return (a >= 0.f) ? 
a : 0.f; }; @@ -355,14 +356,14 @@ void Set<float, lang::Cpp>(const float x, Tensor* out, template <> void Set<int, lang::Cpp>(const int x, Tensor* out, - Context *ctx) { + Context *ctx) { int *outPtr = static_cast<int *>(out->block()->mutable_data()); for (size_t i = 0; i < out->Size(); i++) outPtr[i] = x; } template <> void Sigmoid<float, lang::Cpp>(const Tensor& in, Tensor* out, - Context *ctx) { + Context *ctx) { auto sigmoid_lambda = [](float a) { return 1.f / (1.f + exp(-a)); }; @@ -371,7 +372,7 @@ void Sigmoid<float, lang::Cpp>(const Tensor& in, Tensor* out, template <> void Sign<float, lang::Cpp>(const Tensor& in, Tensor* out, - Context *ctx) { + Context *ctx) { auto sign_lambda = [](float a) { return (a > 0) - (a < 0); }; @@ -389,7 +390,7 @@ void Sqrt<float, lang::Cpp>(const Tensor& in, Tensor* out, for (size_t i = 0; i < in.Size(); i++) { CHECK_GE(inPtr[traversal_info[in.shape().size()]], 0.f); outPtr[i] = sqrt(inPtr[traversal_info[in.shape().size()]]); - traverse_next(in, shape_multipliers, traversal_info, i+1); + traverse_next(in, shape_multipliers, traversal_info, i + 1); } } @@ -398,7 +399,7 @@ void Sub<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, Tensor* out, Context *ctx) { // CHECK_EQ(ctx->stream, nullptr); auto sub_lambda_binary = [](float a, float b) { - return (a-b); + return (a - b); }; TraverseBinary<float>(in1, in2, out, sub_lambda_binary); } @@ -418,7 +419,7 @@ void Sum<float, lang::Cpp>(const Tensor& in, float *out, template <> void Tanh<float, lang::Cpp>(const Tensor& in, Tensor* out, - Context *ctx) { + Context *ctx) { auto tanh_lambda = [](float a) { return tanh(a); }; @@ -475,7 +476,7 @@ void DGMM<float, lang::Cpp>(const bool side_right, size_t offset = r * ncol; for (size_t c = 0; c < ncol; c++) { outPtr[traversal_info[M.shape().size()]] = MPtr[traversal_info[M.shape().size()]] * vPtr[c]; - traverse_next(M, shape_multipliers, traversal_info, offset+c+1); + traverse_next(M, shape_multipliers, traversal_info, offset + c + 1); } } } else { @@ -483,7 +484,7 @@ void DGMM<float, lang::Cpp>(const bool side_right, size_t offset = r * ncol; for (size_t c = 0; c < ncol; c++) { outPtr[traversal_info[M.shape().size()]] = MPtr[traversal_info[M.shape().size()]] * vPtr[r]; - traverse_next(M, shape_multipliers, traversal_info, offset+c+1); + traverse_next(M, shape_multipliers, traversal_info, offset + c + 1); } } } @@ -509,7 +510,7 @@ template <> void Axpy<float, lang::Cpp>(const float alpha, const Tensor& in, Tensor *out, Context *ctx) { //check input tensor for strides first - if(in.strides() == out->strides()){ + if (in.strides() == out->strides()) { const float *inPtr = static_cast<const float *>(in.block()->data()); float *outPtr = static_cast<float *>(out->block()->mutable_data()); cblas_saxpy(in.Size(), alpha, inPtr, 1, outPtr, 1); @@ -522,7 +523,7 @@ template <> void Dot<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, float *out, Context *ctx) { //check input tensor for strides first - if(!(in1.transpose()) && !(in2.transpose())){ + if (!(in1.transpose()) && !(in2.transpose())) { const float *in1Ptr = static_cast<const float *>(in1.block()->data()); const float *in2Ptr = static_cast<const float *>(in2.block()->data()); *out = cblas_sdot(in1.Size(), in1Ptr, 1, in2Ptr, 1); @@ -580,10 +581,10 @@ void GEMM<float, lang::Cpp>(const float alpha, const float *BPtr = static_cast<const float *>(B.block()->data()); float *CPtr = static_cast<float *>(C->block()->mutable_data()); cblas_sgemm(CblasRowMajor, transa, transb, nrowA, ncolB, ncolA, alpha, APtr, - lda, 
BPtr, ldb, beta, CPtr, ldc); + lda, BPtr, ldb, beta, CPtr, ldc); } -#else +#else template <> void Amax<float, lang::Cpp>(const Tensor& in, size_t *out, @@ -636,9 +637,9 @@ void Axpy<float, lang::Cpp>(const float alpha, vector<int> traversal_info = generate_traversal_info(in); vector<int> shape_multipliers = generate_shape_multipliers(in); - for (size_t i = 0; i < in.Size(); i++) { + for (size_t i = 0; i < in.Size(); i++) { outPtr[i] += alpha * inPtr[traversal_info[in.shape().size()]]; - traverse_next(in, shape_multipliers, traversal_info, i+1); + traverse_next(in, shape_multipliers, traversal_info, i + 1); } } @@ -658,7 +659,7 @@ void Dot<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, // const float *in1Ptr = static_cast<const float *>(in1.data()); // const float *in2Ptr = static_cast<const float *>(in2.data()); // for (size_t i = 0; i < in.Size(); i++) { - // sum += in1Ptr[i] * in2Ptr[i]; + // sum += in1Ptr[i] * in2Ptr[i]; // } float *outPtr = static_cast<float *>(out->block()->mutable_data()); const float *in1Ptr = static_cast<const float *>(in1.block()->data()); @@ -670,8 +671,8 @@ void Dot<float, lang::Cpp>(const Tensor& in1, const Tensor& in2, for (size_t i = 0; i < in1.Size(); i++) { sum += in1Ptr[traversal_info_in1[in1.shape().size()]] * in2Ptr[traversal_info_in2[in2.shape().size()]]; - traverse_next(in1, shape_multipliers_in1, traversal_info_in1, i+1); - traverse_next(in2, shape_multipliers_in2, traversal_info_in2, i+1); + traverse_next(in1, shape_multipliers_in1, traversal_info_in1, i + 1); + traverse_next(in2, shape_multipliers_in2, traversal_info_in2, i + 1); } } @@ -697,10 +698,10 @@ void GEMV<float, lang::Cpp>(const float alpha, const Tensor& A, const Tensor& v, #endif // USE_CBLAS template <> void ComputeCrossEntropy<float, lang::Cpp>(bool int_target, - const size_t batchsize, - const size_t dim, const Block *p, - const Block *t, Block *loss, - Context *ctx) { + const size_t batchsize, + const size_t dim, const Block *p, + const Block *t, Block *loss, + Context *ctx) { const float *pPtr = static_cast<const float *>(p->data()); const int *tPtr = static_cast<const int *>(t->data()); float *lossPtr = static_cast<float *>(loss->mutable_data()); @@ -712,7 +713,7 @@ void ComputeCrossEntropy<float, lang::Cpp>(bool int_target, lossPtr[i] = -std::log((std::max)(prob_of_truth, FLT_MIN)); } } else { - for (size_t i = 0;i < batchsize; i++) { + for (size_t i = 0; i < batchsize; i++) { float sum = 0.f; for (size_t j = 0; j < dim; j++) { sum += tPtr[i * dim + j]; @@ -728,10 +729,10 @@ void ComputeCrossEntropy<float, lang::Cpp>(bool int_target, template <> void SoftmaxCrossEntropyBwd<float, lang::Cpp>(bool int_target, - const size_t batchsize, - const size_t dim, const Block *p, - const Block *t, Block *grad, - Context *ctx) { + const size_t batchsize, + const size_t dim, const Block *p, + const Block *t, Block *grad, + Context *ctx) { CHECK_EQ(p, grad) << "Use the same pointer to optimize performance"; // const float* pPtr = static_cast<const float*>(p->data()); const int *tPtr = static_cast<const int *>(t->data()); @@ -764,13 +765,13 @@ void RowMax<float, lang::Cpp>(const Tensor& in, Tensor *out, Context *ctx) { const size_t ncol = in.shape()[1]; vector<int> traversal_info = generate_traversal_info(in); vector<int> shape_multipliers = generate_shape_multipliers(in); - + for (size_t r = 0; r < nrow; r++) { int counter_offset = (r * ncol); float maxval = 0; - for (size_t c = 0; c < ncol; c++){ + for (size_t c = 0; c < ncol; c++) { maxval = (std::max)(maxval, 
inPtr[traversal_info[in.shape().size()]]); - traverse_next(in, shape_multipliers, traversal_info, counter_offset+c+1); + traverse_next(in, shape_multipliers, traversal_info, counter_offset + c + 1); } outPtr[r] = maxval; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3e2b75cb/src/core/tensor/tensor_math_cuda.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h index 6e86ca7..55d6a1b 100644 --- a/src/core/tensor/tensor_math_cuda.h +++ b/src/core/tensor/tensor_math_cuda.h @@ -34,45 +34,45 @@ namespace singa { // ===================== Helper Functions ============================= - /* - cudnn requires tensor dimensions to fulfill 1 requirement: - 1.) Dimensions to be set to a minimum of 4 for 4d and lower dimensional tensors - if input tensor is 5d, cudnn will take a 5d tensor as input. Beyond 5d, certain operations are not supported. - (cudnnOp supports up to 5d, cudnnReduce supports up to 8d) - - for e.g. Tensor A has shape {3,3}, cudnn requires shape of {1,1,3,3} to be the input - Tensor B has shape (2,3,4), cudnn requires shape of {1,2,3,4} to be the input - */ - vector<int> generate_shape_cuda(const Tensor& x) { - Shape shape_ = x.shape(); - vector<int> shape_arr; - if(shape_.size() <= 4){ - for (size_t n=0; n<4-shape_.size(); ++n) { - shape_arr.push_back(1); - } - for (size_t n=0; n<shape_.size(); ++n) { - shape_arr.push_back(shape_.at(n)); - } - return shape_arr; - } else if(shape_.size() == 5){ - for (size_t n=0; n<shape_.size(); ++n) { - shape_arr.push_back(shape_.at(n)); - } - return shape_arr; - } else { - LOG(FATAL) << "Dimensions (shape) beyond 5 are currently not supported" ; +/* +cudnn requires tensor dimensions to fulfill 1 requirement: + 1.) Dimensions to be set to a minimum of 4 for 4d and lower dimensional tensors + if input tensor is 5d, cudnn will take a 5d tensor as input. Beyond 5d, certain operations are not supported. + (cudnnOp supports up to 5d, cudnnReduce supports up to 8d) + + for e.g. Tensor A has shape {3,3}, cudnn requires shape of {1,1,3,3} to be the input + Tensor B has shape (2,3,4), cudnn requires shape of {1,2,3,4} to be the input +*/ +vector<int> generate_shape_cuda(const Tensor& x) { + Shape shape_ = x.shape(); + vector<int> shape_arr; + if (shape_.size() <= 4) { + for (size_t n = 0; n < 4 - shape_.size(); ++n) { + shape_arr.push_back(1); + } + for (size_t n = 0; n < shape_.size(); ++n) { + shape_arr.push_back(shape_.at(n)); } + return shape_arr; + } else if (shape_.size() == 5) { + for (size_t n = 0; n < shape_.size(); ++n) { + shape_arr.push_back(shape_.at(n)); + } + return shape_arr; + } else { + LOG(FATAL) << "Dimensions (shape) beyond 5 are currently not supported" ; } +} - int generate_dim_cuda(const Tensor& x) { - if(x.shape().size() <= 4){return 4;} - else if(x.shape().size() == 5){return 5;} - else{ - LOG(FATAL) << "Dimensions (shape) beyond 5 are currently not supported" ; - } +int generate_dim_cuda(const Tensor& x) { + if (x.shape().size() <= 4) {return 4;} + else if (x.shape().size() == 5) {return 5;} + else { + LOG(FATAL) << "Dimensions (shape) beyond 5 are currently not supported" ; } +} -/* +/* cudnn requires stride dimensions to conform to the format of the shape input as well 1.) Stride dimensions to be set to a minimum of 4 for 4d and lower dimensional tensors If input tensor is 5d, cudnn will take a 5d tensor as input. Beyond 5d, certain operations are not supported. @@ -81,51 +81,51 @@ namespace singa { for e.g. 
Tensor A has shape {3,3}, stride {3,1}, cudnn requires shape {1,1,3,3} and stride {9, 9, 3, 1} or {9, 9, 1, 3} to be the inputs */ - vector<int> generate_strides_cuda(const Tensor& x) { - Shape shape_ = x.shape(); - vector<int> strides_ = x.strides(); - vector<int> strides_arr; - int product = 1; - for (size_t n=0; n<(shape_.size()); ++n) { - product *= shape_[n]; +vector<int> generate_strides_cuda(const Tensor& x) { + Shape shape_ = x.shape(); + vector<int> strides_ = x.strides(); + vector<int> strides_arr; + int product = 1; + for (size_t n = 0; n < (shape_.size()); ++n) { + product *= shape_[n]; + } + if (shape_.size() <= 4) { + for (size_t n = 0; n < 4 - shape_.size(); ++n) { + strides_arr.push_back(product); + } + for (size_t n = 0; n < strides_.size(); ++n) { + strides_arr.push_back(strides_[n]); } - if(shape_.size() <= 4){ - for (size_t n=0; n<4-shape_.size(); ++n) { - strides_arr.push_back(product); - } - for (size_t n=0; n<strides_.size(); ++n) { - strides_arr.push_back(strides_[n]); - } - return strides_arr; - } else if(shape_.size() == 5){ - for (size_t n=0; n<strides_.size(); ++n) { - strides_arr.push_back(strides_[n]); - } - return strides_arr; - } else { - LOG(FATAL) << "Dimensions (strides) beyond 5 are currently not supported" ; + return strides_arr; + } else if (shape_.size() == 5) { + for (size_t n = 0; n < strides_.size(); ++n) { + strides_arr.push_back(strides_[n]); } + return strides_arr; + } else { + LOG(FATAL) << "Dimensions (strides) beyond 5 are currently not supported" ; } +} -cudnnTensorDescriptor_t generate_tensorND_desc(const Tensor& x){ +cudnnTensorDescriptor_t generate_tensorND_desc(const Tensor& x) { cudnnTensorDescriptor_t x_desc; cudnnCreateTensorDescriptor(&x_desc); cudnnSetTensorNdDescriptor(x_desc, CUDNN_DATA_FLOAT, generate_dim_cuda(x), generate_shape_cuda(x).data(), generate_strides_cuda(x).data() - ); + ); return x_desc; } -cudnnOpTensorDescriptor_t generate_Op_desc(cudnnOpTensorOp_t op){ +cudnnOpTensorDescriptor_t generate_Op_desc(cudnnOpTensorOp_t op) { cudnnOpTensorDescriptor_t op_desc; cudnnCreateOpTensorDescriptor(&op_desc); cudnnSetOpTensorDescriptor(op_desc, op, CUDNN_DATA_FLOAT, CUDNN_PROPAGATE_NAN - ); + ); return op_desc; } @@ -144,10 +144,10 @@ void Abs<float, lang::Cuda>(const Tensor& in, Tensor* out, float beta = 0.0; cudnnTensorDescriptor_t in_desc = generate_tensorND_desc(in); cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_MAX), - (void*)(&alpha1), in_desc, inPtr, + (void*)(&alpha1), in_desc, inPtr, (void*)(&alpha2), in_desc, inPtr, (void*)(&beta), generate_tensorND_desc(*out), outPtr - ); + ); cudnnDestroyTensorDescriptor(in_desc); } @@ -156,8 +156,8 @@ void Set<float, lang::Cuda>(const float x, Tensor* out, Context* ctx) { float* outPtr = static_cast<float*>(out->block()->mutable_data()); - cudnnSetTensor(ctx->cudnn_handle, generate_tensorND_desc(*out), - outPtr, (void*)(&x)); + cudnnSetTensor(ctx->cudnn_handle, generate_tensorND_desc(*out), + outPtr, (void*)(&x)); } template <> @@ -171,7 +171,7 @@ void Add<float, lang::Cuda>(const Tensor& in, const float x, cudnnAddTensor(ctx->cudnn_handle, (void*)(&alpha), generate_tensorND_desc(in), inPtr, (void*)(&beta), generate_tensorND_desc(*out), outPtr - ); + ); } /// out = in1 + in2 @@ -186,18 +186,18 @@ void Add<float, lang::Cuda>(const Tensor& in1, float alpha2 = 1.0; float beta = 0.0; - if((in1.nDim() == in2.nDim()) || (in2.nDim() == 1)){ + if ((in1.nDim() == in2.nDim()) || (in2.nDim() == 1)) { cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD), - 
(void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, - (void*)(&alpha2), generate_tensorND_desc(in2), inPtr2, - (void*)(&beta), generate_tensorND_desc(*out), outPtr - ); + (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, + (void*)(&alpha2), generate_tensorND_desc(in2), inPtr2, + (void*)(&beta), generate_tensorND_desc(*out), outPtr + ); } else { cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD), - (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, - (void*)(&alpha2), generate_tensorND_desc(in1), inPtr2, - (void*)(&beta), generate_tensorND_desc(*out), outPtr - ); + (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, + (void*)(&alpha2), generate_tensorND_desc(in1), inPtr2, + (void*)(&beta), generate_tensorND_desc(*out), outPtr + ); } } @@ -213,18 +213,18 @@ void Sub<float, lang::Cuda>(const Tensor& in1, float alpha2 = -1.0; float beta = 0.0; - if((in1.nDim() == in2.nDim()) || (in2.nDim() == 1)){ + if ((in1.nDim() == in2.nDim()) || (in2.nDim() == 1)) { cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD), - (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, - (void*)(&alpha2), generate_tensorND_desc(in2), inPtr2, - (void*)(&beta), generate_tensorND_desc(*out), outPtr - ); + (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, + (void*)(&alpha2), generate_tensorND_desc(in2), inPtr2, + (void*)(&beta), generate_tensorND_desc(*out), outPtr + ); } else { cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD), - (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, - (void*)(&alpha2), generate_tensorND_desc(in1), inPtr2, - (void*)(&beta), generate_tensorND_desc(*out), outPtr - ); + (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, + (void*)(&alpha2), generate_tensorND_desc(in1), inPtr2, + (void*)(&beta), generate_tensorND_desc(*out), outPtr + ); } } @@ -250,17 +250,17 @@ void Div<float, lang::Cuda>(const Tensor& in1, const size_t num = in1.Size(); //if both in1 and in2 strides are the same, we proceed to normal cuda::div - if(in1.strides() == in2.strides()){ - cuda::div(num, inPtr1, inPtr2, outPtr, ctx->stream); - out->set_strides(in1.strides()); + if (in1.strides() == in2.strides()) { + cuda::div(num, inPtr1, inPtr2, outPtr, ctx->stream); + out->set_strides(in1.strides()); } else { //else we transform in1 to out to store first float alpha = 1.0; float beta = 0.0; out->set_strides(in2.strides()); cudnnTransformTensor(ctx->cudnn_handle, - (void*)(&alpha), generate_tensorND_desc(in1), inPtr1, - (void*)(&beta), generate_tensorND_desc(*out), outPtr + (void*)(&alpha), generate_tensorND_desc(in1), inPtr1, + (void*)(&beta), generate_tensorND_desc(*out), outPtr ); cuda::div(num, outPtr, inPtr2, outPtr, ctx->stream); @@ -286,8 +286,8 @@ void EltwiseMult<float, lang::Cuda>(const Tensor& in, float alpha = x, beta = 0.0; cudnnAddTensor(ctx->cudnn_handle, - (void*)(&alpha), generate_tensorND_desc(in), inPtr, - (void*)(&beta), generate_tensorND_desc(*out), outPtr + (void*)(&alpha), generate_tensorND_desc(in), inPtr, + (void*)(&beta), generate_tensorND_desc(*out), outPtr ); } @@ -302,17 +302,17 @@ void EltwiseMult<float, lang::Cuda>(const Tensor& in1, const size_t num = in1.Size(); //if both in1 and in2 strides are the same, we proceed to normal cuda::mult - if(in1.strides() == in2.strides()){ - cuda::mult(num, inPtr1, inPtr2, outPtr, ctx->stream); - out->set_strides(in1.strides()); + if (in1.strides() == in2.strides()) { + cuda::mult(num, inPtr1, inPtr2, outPtr, ctx->stream); + out->set_strides(in1.strides()); } else { //else we 
transform in1 to out to store first
    float alpha = 1.0;
    float beta = 0.0;

    out->set_strides(in2.strides());
    cudnnTransformTensor(ctx->cudnn_handle,
-                        (void*)(&alpha), generate_tensorND_desc(in1), inPtr1,
-                        (void*)(&beta), generate_tensorND_desc(*out), outPtr
+                        (void*)(&alpha), generate_tensorND_desc(in1), inPtr1,
+                        (void*)(&beta), generate_tensorND_desc(*out), outPtr
                         );

    cuda::mult(num, outPtr, inPtr2, outPtr, ctx->stream);
@@ -443,17 +443,17 @@ void Pow<float, lang::Cuda>(const Tensor& in1,
  float* outPtr = static_cast<float*>(out->block()->mutable_data());
  const size_t num = in1.Size();

-  if(in1.strides() == in2.strides()){
-    cuda::pow(num, inPtr1, inPtr2, outPtr, ctx->stream);
-    out->set_strides(in1.strides());
+  if (in1.strides() == in2.strides()) {
+    cuda::pow(num, inPtr1, inPtr2, outPtr, ctx->stream);
+    out->set_strides(in1.strides());
  } else { //else we transform in1 to out to store first
    float alpha = 1.0;
    float beta = 0.0;

    out->set_strides(in2.strides());
    cudnnTransformTensor(ctx->cudnn_handle,
-                        (void*)(&alpha), generate_tensorND_desc(in1), inPtr1,
-                        (void*)(&beta), generate_tensorND_desc(*out), outPtr
+                        (void*)(&alpha), generate_tensorND_desc(in1), inPtr1,
+                        (void*)(&beta), generate_tensorND_desc(*out), outPtr
                         );

    cuda::pow(num, outPtr, inPtr2, outPtr, ctx->stream);
@@ -473,18 +473,18 @@ void Pow<float, lang::Cuda>(const Tensor& in1,
// double coef = 0.0; //only used for CLIPPED_RELU or ELU
// cudnnCreateActivationDescriptor(&act_desc);
// cudnnSetActivationDescriptor(act_desc, mode, cudnn_propagation, coef);
-
+
// float alpha[1] = {1.0};
// float beta[1] = {0.0};
// cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT;
// cudnnTensorDescriptor_t in_desc, out_desc;
// cudnnCreateTensorDescriptor(&in_desc);
// cudnnCreateTensorDescriptor(&out_desc);
-// cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in.generate_dim_cuda(),
+// cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in.generate_dim_cuda(),
//                            in.generate_shape_cuda().data(), in.generate_strides_cuda().data());
-// cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(),
+// cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(),
//                            out->generate_shape_cuda().data(), out->generate_strides_cuda().data());
-// cudnnActivationForward(ctx->cudnn_handle, act_desc, (void*)(&alpha), in_desc, inPtr,
+// cudnnActivationForward(ctx->cudnn_handle, act_desc, (void*)(&alpha), in_desc, inPtr,
//                        (void*)(&beta), out_desc, outPtr);

// cudnnDestroyTensorDescriptor(in_desc);
@@ -515,18 +515,18 @@ void ReLU<float, lang::Cuda>(const Tensor& in, Tensor* out,
// double coef = 0.0; //only used for CLIPPED_RELU or ELU
// cudnnCreateActivationDescriptor(&act_desc);
// cudnnSetActivationDescriptor(act_desc, mode, cudnn_propagation, coef);
-
+
// float alpha[1] = {1.0};
// float beta[1] = {0.0};
// cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT;
// cudnnTensorDescriptor_t in_desc, out_desc;
// cudnnCreateTensorDescriptor(&in_desc);
// cudnnCreateTensorDescriptor(&out_desc);
-// cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in.generate_dim_cuda(),
+// cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in.generate_dim_cuda(),
//                            in.generate_shape_cuda().data(), in.generate_strides_cuda().data());
-// cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(),
+// cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(),
//                            out->generate_shape_cuda().data(), out->generate_strides_cuda().data());
-// cudnnActivationForward(ctx->cudnn_handle, act_desc, (void*)(&alpha), in_desc, inPtr,
+// cudnnActivationForward(ctx->cudnn_handle, act_desc, (void*)(&alpha), in_desc, inPtr,
//                        (void*)(&beta), out_desc, outPtr);

// cudnnDestroyTensorDescriptor(in_desc);
@@ -562,16 +562,16 @@ void Sqrt<float, lang::Cuda>(const Tensor& in, Tensor* out,
                             Context* ctx) {
  const float* inPtr = static_cast<const float*>(in.block()->data());
  float* outPtr = static_cast<float*>(out->block()->mutable_data());
-
+
  float alpha1 = 1.0;
  float alpha2 = 0.0;
  float beta = 0.0;
  cudnnTensorDescriptor_t in_desc = generate_tensorND_desc(in);
  cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_SQRT),
-               (void*)(&alpha1), in_desc, inPtr,
+               (void*)(&alpha1), in_desc, inPtr,
                (void*)(&alpha2), in_desc, inPtr,
                (void*)(&beta), generate_tensorND_desc(*out), outPtr
-               );
+               );
}

/// Element-wise operation, out[i]=in[i]^2
@@ -598,15 +598,15 @@ void Sum<float, lang::Cuda>(const Tensor& in, float* out,
                            Context* ctx) {
  const float* inPtr = static_cast<const float*>(in.block()->data());

-  //reduce all axes to 1 for cudnnReduce, e.g. Tensor A with shape (2,4) will be reduced to (1)
-  Shape reduced_shape = {1};
-  Tensor t(reduced_shape, in.device(), in.data_type());
-  float* tPtr = static_cast<float*>(t.block()->mutable_data());
-  vector<int> reduce_all_axes = generate_shape_cuda(in);
-  for (size_t n=0; n<reduce_all_axes.size(); ++n) {
+  //reduce all axes to 1 for cudnnReduce, e.g. Tensor A with shape (2,4) will be reduced to (1)
+  Shape reduced_shape = {1};
+  Tensor t(reduced_shape, in.device(), in.data_type());
+  float* tPtr = static_cast<float*>(t.block()->mutable_data());
+  vector<int> reduce_all_axes = generate_shape_cuda(in);
+  for (size_t n = 0; n < reduce_all_axes.size(); ++n) {
    reduce_all_axes[n] = 1;
-  }
-
+  }
+
  //reduce_desc
  cudnnReduceTensorDescriptor_t reduce_desc;
  cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_ADD;
@@ -620,11 +620,11 @@ void Sum<float, lang::Cuda>(const Tensor& in, float* out,

  //instantiate 2 new tensors to use new blocks as memory instead of cudaMalloc
  size_t reduction_size_int = Product(in.shape());
-  Shape reduction_size = {reduction_size_int*100};
+  Shape reduction_size = {reduction_size_int * 100};
  Tensor indices(reduction_size, in.device(), in.data_type());
  Tensor workspace(reduction_size, in.device(), in.data_type());
-  size_t indices_bytes = indices.block()->size()*100;
-  size_t workspace_bytes = workspace.block()->size()*100;
+  size_t indices_bytes = indices.block()->size() * 100;
+  size_t workspace_bytes = workspace.block()->size() * 100;
  size_t* indicesPtr = static_cast<size_t*>(indices.block()->mutable_data());
  float* workspacePtr = static_cast<float*>(workspace.block()->mutable_data());
  //void* indicesPtr{nullptr}; void* workspacePtr{nullptr};
@@ -636,7 +636,7 @@ void Sum<float, lang::Cuda>(const Tensor& in, float* out,
                    indicesPtr, indices_bytes, workspacePtr, workspace_bytes,
                    (void*)(&alpha), generate_tensorND_desc(in), inPtr,
                    (void*)(&beta), generate_tensorND_desc(t), tPtr
-                   );
+                   );

  *out = tPtr[0];
}
@@ -655,18 +655,18 @@ void Sum<float, lang::Cuda>(const Tensor& in, float* out,
// double coef = 0.0; //only used for CLIPPED_RELU or ELU
// cudnnCreateActivationDescriptor(&act_desc);
// cudnnSetActivationDescriptor(act_desc, mode, cudnn_propagation, coef);
-
+
// float alpha[1] = {1.0};
// float beta[1] = {0.0};
// cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT;
// cudnnTensorDescriptor_t in_desc, out_desc;
// cudnnCreateTensorDescriptor(&in_desc);
// cudnnCreateTensorDescriptor(&out_desc);
-// cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in.generate_dim_cuda(),
+// cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in.generate_dim_cuda(),
//                            in.generate_shape_cuda().data(), in.generate_strides_cuda().data());
-// cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(),
+// cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(),
//                            out->generate_shape_cuda().data(), out->generate_strides_cuda().data());
-// cudnnActivationForward(ctx->cudnn_handle, act_desc, (void*)(&alpha), in_desc, inPtr,
+// cudnnActivationForward(ctx->cudnn_handle, act_desc, (void*)(&alpha), in_desc, inPtr,
//                        (void*)(&beta), out_desc, outPtr);

// cudnnDestroyTensorDescriptor(in_desc);
@@ -676,7 +676,7 @@ void Sum<float, lang::Cuda>(const Tensor& in, float* out,

template <>
void Tanh<float, lang::Cuda>(const Tensor& in, Tensor* out,
-                             Context* ctx) {
+                             Context* ctx) {
  const float* inPtr = static_cast<const float*>(in.block()->data());
  float* outPtr = static_cast<float*>(out->block()->mutable_data());
  const size_t num = in.Size();
@@ -856,22 +856,22 @@ void GEMM<float, lang::Cuda>(const float alpha,
template <>
void ComputeCrossEntropy<float, lang::Cuda>(bool int_target,
-                                            const size_t batchsize,
-                                            const size_t dim, const Block* p,
-                                            const Block* t, Block* loss,
-                                            Context* ctx) {
+                                            const size_t batchsize,
+                                            const size_t dim, const Block* p,
+                                            const Block* t, Block* loss,
+                                            Context* ctx) {
  const float* pPtr = static_cast<const float*>(p->data());
  const int* tPtr = static_cast<const int*>(t->data());
  float* lossPtr = static_cast<float*>(loss->mutable_data());
  cuda::ComputeCrossEntropy(int_target, batchsize, dim, pPtr, tPtr, lossPtr,
-                            ctx->stream);
+                            ctx->stream);
}

template <>
void SoftmaxCrossEntropyBwd<float, lang::Cuda>(bool int_target,
-                                               const size_t batchsize,
-                                               const size_t dim, const Block* p,
-                                               const Block* t, Block* grad,
-                                               Context* ctx) {
+                                               const size_t batchsize,
+                                               const size_t dim, const Block* p,
+                                               const Block* t, Block* grad,
+                                               Context* ctx) {
  CHECK_EQ(p, grad) << "Use the same pointer to optimize performance";
  const float* pPtr = static_cast<const float*>(p->data());
  const int* tPtr = static_cast<const int*>(t->data());
@@ -924,11 +924,11 @@ void SoftmaxCrossEntropyBwd<float, lang::Cuda>(bool int_target,
// cudnnTensorDescriptor_t in_desc, out_desc;
// cudnnCreateTensorDescriptor(&in_desc);
// cudnnCreateTensorDescriptor(&out_desc);
-// cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in.generate_dim_cuda(),
+// cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in.generate_dim_cuda(),
//                            in.generate_shape_cuda().data(), in.generate_strides_cuda().data());
-// //cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(),
+// //cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(),
//                              out->generate_shape_cuda().data(), out->generate_strides_cuda().data());
-// cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(),
+// cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(),
//                            reduce_row_axes_shape.data(), reduced_strides.data());
// cudnnReduceTensor(ctx->cudnn_handle, reduce_desc,
//                   indicesPtr, indices_bytes, workspacePtr, workspace_bytes,
@@ -946,7 +946,7 @@ void RowMax<float, lang::Cuda>(const Tensor& in, Tensor* out,
  const size_t nrow = in.shape()[0];
  const size_t ncol = in.shape()[1];

-  if(in.transpose()){
+  if (in.transpose()) {
    Tensor t(in.shape(), in.device(), in.data_type());
    float* tPtr = static_cast<float*>(t.block()->mutable_data());
@@ -954,8 +954,8 @@ void RowMax<float, lang::Cuda>(const Tensor& in, Tensor* out,
    float beta = 0.0;
    cudnnTransformTensor(ctx->cudnn_handle,
-                        (void*)(&alpha), generate_tensorND_desc(in), inPtr,
-                        (void*)(&beta), generate_tensorND_desc(t), tPtr
+                        (void*)(&alpha), generate_tensorND_desc(in), inPtr,
+                        (void*)(&beta), generate_tensorND_desc(t), tPtr
                         );

    const float* tPtr_const = static_cast<const float*>(t.block()->data());
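
Note (not part of the commit): the Mult/Pow/RowMax hunks above all follow the same pattern -- when an input has non-contiguous strides (e.g. a transposed view), cudnnTransformTensor is used with alpha=1, beta=0 to copy it into a contiguous destination layout, after which a flat element-wise kernel such as cuda::pow or cuda::mult can treat the data as a plain array. The standalone sketch below illustrates that idea with raw cuDNN calls; the buffer sizes, shapes, and stride values are made up for illustration, error checking is omitted, and it deliberately does not use SINGA helpers like generate_tensorND_desc.

#include <cuda_runtime.h>
#include <cudnn.h>
#include <vector>

int main() {
  // A 1x1x2x3 tensor stored with "transposed" strides on the last two axes,
  // to be rewritten into a contiguous row-major layout.
  const std::vector<int> dims    = {1, 1, 2, 3};
  const std::vector<int> src_str = {6, 6, 1, 2};  // strided / transposed layout
  const std::vector<int> dst_str = {6, 6, 3, 1};  // contiguous row-major layout
  const size_t n = 6;

  float *src = nullptr, *dst = nullptr;
  cudaMalloc(&src, n * sizeof(float));
  cudaMalloc(&dst, n * sizeof(float));

  cudnnHandle_t handle;
  cudnnCreate(&handle);

  cudnnTensorDescriptor_t src_desc, dst_desc;
  cudnnCreateTensorDescriptor(&src_desc);
  cudnnCreateTensorDescriptor(&dst_desc);
  // Same dimensions, different strides: exactly the situation handled above.
  cudnnSetTensorNdDescriptor(src_desc, CUDNN_DATA_FLOAT, 4, dims.data(), src_str.data());
  cudnnSetTensorNdDescriptor(dst_desc, CUDNN_DATA_FLOAT, 4, dims.data(), dst_str.data());

  // dst = 1.0 * layout_transform(src) + 0.0 * dst
  float alpha = 1.0f, beta = 0.0f;
  cudnnTransformTensor(handle, &alpha, src_desc, src, &beta, dst_desc, dst);

  // dst is now contiguous, so a flat element-wise kernel (the role played by
  // cuda::pow / cuda::mult in the diff) can index it as dst[0..n-1] directly.

  cudnnDestroyTensorDescriptor(src_desc);
  cudnnDestroyTensorDescriptor(dst_desc);
  cudnnDestroy(handle);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}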
