SINGA-196 Rename class Blob to Block

Rename Blob (blob) to Block (block). Block represents a block of memory (on device or host).
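For orientation, the standalone sketch below (not the SINGA source itself) shows how the renamed Block type is used after this commit: the Device hands out Block objects, Tensors share them through reference counting, and math kernels reach the raw memory via data()/mutable_data(). The bodies of IncRefCount()/DecRefCount() and the malloc/free-based NewBlock()/FreeBlock() helpers are assumptions based only on the call sites visible in the diff below.

// Hedged sketch of the post-rename Block usage pattern; not the actual SINGA code.
#include <cstddef>
#include <cstdlib>
#include <iostream>

class Block {
 public:
  Block(void* ptr, size_t size) : data_(ptr), size_(size), ref_count_(1) {}
  void* mutable_data() const { return data_; }
  const void* data() const { return data_; }
  size_t size() const { return size_; }
  void IncRefCount() { ++ref_count_; }        // assumed implementation
  int DecRefCount() { return --ref_count_; }  // assumed implementation

 private:
  void* data_ = nullptr;
  size_t size_ = 0;
  int ref_count_ = 0;
};

// Mirrors Device::NewBlock()/FreeBlock() from the patch, using plain malloc/free
// in place of the device memory manager.
Block* NewBlock(int size) {
  return size > 0 ? new Block(std::malloc(size), size) : nullptr;
}

void FreeBlock(Block* block) {
  if (block != nullptr) {
    std::free(block->mutable_data());
    delete block;
  }
}

int main() {
  Block* block = NewBlock(4 * sizeof(float));  // e.g. backing a 4-element tensor
  block->IncRefCount();                        // a second Tensor shares the same Block
  static_cast<float*>(block->mutable_data())[0] = 1.0f;
  std::cout << "bytes: " << block->size() << "\n";
  if (block->DecRefCount() == 0) FreeBlock(block);  // first owner releases its reference
  if (block->DecRefCount() == 0) FreeBlock(block);  // last owner frees the memory
  return 0;
}

This is the same pattern the renamed Tensor code follows: each copy constructor or assignment calls IncRefCount(), and destruction or reassignment calls DecRefCount() and frees the Block only when the count reaches zero.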
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/9c2869b9 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/9c2869b9 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/9c2869b9 Branch: refs/heads/master Commit: 9c2869b9ab5da4affa294b4b23c88aec0b226984 Parents: 272100a Author: Wei Wang <[email protected]> Authored: Mon Jun 13 19:15:32 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Mon Jun 13 19:15:32 2016 +0800 ---------------------------------------------------------------------- include/singa/core/common.h | 6 +- include/singa/core/device.h | 18 ++-- include/singa/core/tensor.h | 32 +++--- include/singa/model/layer.h | 2 +- src/core/device/device.cc | 24 ++--- src/core/tensor/tensor.cc | 172 ++++++++++++++++-------------- src/core/tensor/tensor_math.h | 118 ++++++++++---------- src/core/tensor/tensor_math_cpp.h | 151 +++++++++++++------------- src/core/tensor/tensor_math_cuda.h | 135 +++++++++++------------ src/model/layer/cudnn_activation.cc | 26 ++--- src/model/layer/cudnn_batchnorm.cc | 126 +++++++++++----------- src/model/layer/cudnn_convolution.cc | 104 ++++++++---------- src/model/layer/cudnn_dropout.cc | 46 ++++---- src/model/layer/cudnn_lrn.cc | 78 +++++--------- src/model/layer/cudnn_pooling.cc | 42 ++++---- src/model/layer/cudnn_softmax.cc | 22 ++-- test/singa/test_cpp_cpu.cc | 16 +-- test/singa/test_tensor.cc | 14 +-- 18 files changed, 548 insertions(+), 584 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9c2869b9/include/singa/core/common.h ---------------------------------------------------------------------- diff --git a/include/singa/core/common.h b/include/singa/core/common.h index e6f4c90..b556750 100644 --- a/include/singa/core/common.h +++ b/include/singa/core/common.h @@ -42,10 +42,10 @@ typedef struct _Cuda { } Cuda; typedef struct _Opencl { } Opencl; } // namespace lang -/// Blob represent a chunk of memory (on device or host) managed by VirtualMemory. -class Blob { +/// Block represent a chunk of memory (on device or host). +class Block { public: - Blob(void* ptr, size_t size) : data_(ptr), size_(size), ref_count_(1) {} + Block(void* ptr, size_t size) : data_(ptr), size_(size), ref_count_(1) {} void* mutable_data() const { return data_; } const void* data() const { return data_; } size_t size() const { return size_; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9c2869b9/include/singa/core/device.h ---------------------------------------------------------------------- diff --git a/include/singa/core/device.h b/include/singa/core/device.h index 56eda70..f69e4c6 100644 --- a/include/singa/core/device.h +++ b/include/singa/core/device.h @@ -47,21 +47,21 @@ class Device { virtual void SetRandSeed(unsigned seed) = 0; /// Called by Tensor. - Blob* NewBlob(int size); + Block* NewBlock(int size); /// Called by Tensor. - void FreeBlob(Blob* blob); + void FreeBlock(Block* block); /// Copy data within or across devices. 
- void CopyDataToFrom(Blob* dst, Blob* src, size_t nBytes, + void CopyDataToFrom(Block* dst, Block* src, size_t nBytes, CopyDirection direction, int dst_offset, int src_offset); - void CopyDataFromHostPtr(Blob* dst, const void* src, size_t nBytes, + void CopyDataFromHostPtr(Block* dst, const void* src, size_t nBytes, size_t dst_offset = 0); /// Submit the operation to the device, which may execute it right now or /// delay it depending on the scheduler. - void Exec(function<void(Context*)>&& fn, const vector<Blob*> read_blobs, - const vector<Blob*> write_blobs, + void Exec(function<void(Context*)>&& fn, const vector<Block*> read_blocks, + const vector<Block*> write_blocks, bool use_rand_generator = false); // Wait for one event. @@ -205,11 +205,11 @@ class CallbackArg { /// Type of callback functions for executing tensor ops. typedef function<void(CallbackArg*)> CallbackFn; public: - /// Operation has a function, and read/write blobs. + /// Operation has a function, and read/write blocks. typedef struct _Operation { function<void(Context*)> fn; - const vector<Blob*> read_blobs; - const vector<Blob*> write_blobs; + const vector<Block*> read_blocks; + const vector<Block*> write_blocks; } Operation; */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9c2869b9/include/singa/core/tensor.h ---------------------------------------------------------------------- diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index 8cfa705..48a8c8f 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -64,17 +64,17 @@ class Tensor { /// Copy Tensor to share the internal data. No deep copy. Tensor(Tensor &&from); - /// For functions in xx_math.cc to access the blob. - /// Users should not operate against Blob directly. - /// blob_ is allocated in constructors. - Blob *blob() const { return blob_; } + /// For functions in xx_math.cc to access the block. + /// Users should not operate against Block directly. + /// block_ is allocated in constructors. + Block *block() const { return block_; } Device *device() const { return device_; } /// return immutable Tensor values with given type. template <typename SType> SType data() const { - return static_cast<SType>(blob()->data()); + return static_cast<SType>(block()->data()); } /// data type, including kFloat16, kFloat32, kInt @@ -93,23 +93,23 @@ class Tensor { /// return number of total elements size_t Size() const { - CHECK_EQ(blob_->size() % SizeOf(data_type_), 0u); - return blob_->size() / SizeOf(data_type_); + CHECK_EQ(block_->size() % SizeOf(data_type_), 0u); + return block_->size() / SizeOf(data_type_); } /// return memory size (i.e., Bytes) - size_t MemSize() const { return blob_->size(); } + size_t MemSize() const { return block_->size(); } - /// Reset the tensor shape, it may reallocate blob, if MemSize() changes. + /// Reset the tensor shape, it may reallocate block, if MemSize() changes. void Reshape(const Shape &shape); void Reshape(Shape &&shape); /// Reset the shape, device, and data type as given tensor. - /// If blob size changes, then reallocate a new blob. The previous blob would + /// If block size changes, then reallocate a new block. The previous block would /// be deleted. void ResetLike(const Tensor &t); - /// Reset the data type, it would reallocate blob if type changes. + /// Reset the data type, it would reallocate block if type changes. void AsType(const DataType type); /// Reset the device. 
@@ -140,10 +140,10 @@ class Tensor { /// No data copy, just set the transpose_ filed of the returned tensor. Tensor T() const; - /// Copy the meta info with data blob shared. + /// Copy the meta info with data block shared. Tensor &operator=(const Tensor &in); - /// Copy the meta info with data blob shared. + /// Copy the meta info with data block shared. Tensor &operator=(Tensor &&in); Tensor &operator+=(const Tensor &in); @@ -179,9 +179,9 @@ class Tensor { bool transpose_ = false; DataType data_type_ = kFloat32; Device *device_ = nullptr; - /// Note: blob_ is allocated in lazy manner to avoid frequent malloc/free. - /// If you want to get an allocated Blob, use blob() instead of blob_. - Blob *blob_ = nullptr; + /// Note: block_ is allocated in lazy manner to avoid frequent malloc/free. + /// If you want to get an allocated Block, use block() instead of block_. + Block *block_ = nullptr; Shape shape_ = {}; }; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9c2869b9/include/singa/model/layer.h ---------------------------------------------------------------------- diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h index 82c8edc..2addc98 100644 --- a/include/singa/model/layer.h +++ b/include/singa/model/layer.h @@ -61,7 +61,7 @@ class Layer { virtual void Setup(const LayerConf& conf) { name_ = conf.name(); // for (const auto& spec : conf.param()) param_specs_.push_back(spec); - // TODO(wangwei) load param values from checkpoint blobs. + // TODO(wangwei) load param values from checkpoint files. } /// Do feature transformation for the given 'input' tensor (denoted as x). http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9c2869b9/src/core/device/device.cc ---------------------------------------------------------------------- diff --git a/src/core/device/device.cc b/src/core/device/device.cc index 1d3c446..36381e4 100644 --- a/src/core/device/device.cc +++ b/src/core/device/device.cc @@ -25,31 +25,31 @@ Device::Device(int id, int num_executors, string scheduler, string vm) host_ = &defaultDevice; } -void Device::Exec(function<void(Context*)>&& fn, const vector<Blob*> read_blobs, - const vector<Blob*> write_blobs, bool use_rand_generator) { +void Device::Exec(function<void(Context*)>&& fn, const vector<Block*> read_blocks, + const vector<Block*> write_blocks, bool use_rand_generator) { // TODO(wangwei) execute operations scheduled by the scheduler. 
DoExec(std::move(fn), 0); } -// TODO(wangwei) get Blob from the memory manager -Blob* Device::NewBlob(int size) { +// TODO(wangwei) get Block from the memory manager +Block* Device::NewBlock(int size) { if (size > 0) { void* ptr = Malloc(size); - return new Blob(ptr, size); + return new Block(ptr, size); } else { return nullptr; } } -// TODO(wangwei) return Blob to the memory manager -void Device::FreeBlob(Blob* blob) { - if (blob != nullptr) { - Free(blob->mutable_data()); - delete blob; +// TODO(wangwei) return Block to the memory manager +void Device::FreeBlock(Block* block) { + if (block != nullptr) { + Free(block->mutable_data()); + delete block; } } -void Device::CopyDataToFrom(Blob* dst, Blob* src, size_t nBytes, +void Device::CopyDataToFrom(Block* dst, Block* src, size_t nBytes, CopyDirection direct, int dst_offset, int src_offset) { this->Exec( @@ -62,7 +62,7 @@ void Device::CopyDataToFrom(Blob* dst, Blob* src, size_t nBytes, {src}, {dst}); } -void Device::CopyDataFromHostPtr(Blob* dst, const void* src, size_t nBytes, +void Device::CopyDataFromHostPtr(Block* dst, const void* src, size_t nBytes, size_t dst_offset) { auto direct = lang_ == kCpp ? kHostToHost : kHostToDevice; void* dstptr = reinterpret_cast<char*>(dst->mutable_data()) + dst_offset; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9c2869b9/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index 4e0d94b..8afc17c 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -26,8 +26,9 @@ namespace singa { Tensor::~Tensor() { // LOG(ERROR) << "~"; - if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); - blob_ = nullptr; + if (block_ != nullptr && block_->DecRefCount() == 0) + device_->FreeBlock(block_); + block_ = nullptr; } Tensor::Tensor() { device_ = &defaultDevice; } @@ -35,28 +36,28 @@ Tensor::Tensor() { device_ = &defaultDevice; } Tensor::Tensor(const Shape &shape, const DataType dtype) : data_type_(dtype), device_(&defaultDevice), shape_(shape) { device_ = &defaultDevice; - blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_)); + block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_)); } Tensor::Tensor(Shape &&shape, const DataType dtype) : data_type_(dtype), device_(&defaultDevice), shape_(shape) { device_ = &defaultDevice; - blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_)); + block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_)); } Tensor::Tensor(const Shape &shape, Device *device, const DataType dtype) : data_type_(dtype), device_(device), shape_(shape) { - blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_)); + block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_)); } Tensor::Tensor(Shape &&shape, Device *device, const DataType dtype) : data_type_(dtype), device_(device), shape_(shape) { - blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_)); + block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_)); } Tensor::Tensor(const Tensor &in) : transpose_(in.transpose_), data_type_(in.data_type_), device_(in.device_), - blob_(in.blob()), + block_(in.block()), shape_(in.shape_) { - blob_->IncRefCount(); + block_->IncRefCount(); } Tensor::Tensor(Tensor &&in) @@ -64,40 +65,44 @@ Tensor::Tensor(Tensor &&in) data_type_(in.data_type_), device_(in.device_), shape_(std::move(in.shape_)) { - blob_ = in.blob_; - in.blob_ = nullptr; + block_ = in.block_; + in.block_ = nullptr; } void 
Tensor::ResetLike(const Tensor &in) { - if (blob_ == nullptr || device_ != in.device_ || MemSize() != in.MemSize()) { - if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); + if (block_ == nullptr || device_ != in.device_ || MemSize() != in.MemSize()) { + if (block_ != nullptr && block_->DecRefCount() == 0) + device_->FreeBlock(block_); shape_ = in.shape_; device_ = in.device_; data_type_ = in.data_type_; - blob_ = device_->NewBlob(in.MemSize()); + block_ = device_->NewBlock(in.MemSize()); } } void Tensor::Reshape(const Shape &shape) { if (Product(shape_) != Product(shape)) { - if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); - blob_ = device_->NewBlob(Product(shape) * SizeOf(data_type_)); + if (block_ != nullptr && block_->DecRefCount() == 0) + device_->FreeBlock(block_); + block_ = device_->NewBlock(Product(shape) * SizeOf(data_type_)); } shape_ = shape; } void Tensor::Reshape(Shape &&shape) { if (Product(shape_) != Product(shape)) { - if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); - blob_ = device_->NewBlob(Product(shape) * SizeOf(data_type_)); + if (block_ != nullptr && block_->DecRefCount() == 0) + device_->FreeBlock(block_); + block_ = device_->NewBlock(Product(shape) * SizeOf(data_type_)); } shape_ = std::move(shape); } void Tensor::AsType(const DataType type) { if (data_type_ != type) { - if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); - blob_ = device_->NewBlob(Product(shape_) * SizeOf(type)); + if (block_ != nullptr && block_->DecRefCount() == 0) + device_->FreeBlock(block_); + block_ = device_->NewBlock(Product(shape_) * SizeOf(type)); data_type_ = type; } } @@ -107,9 +112,10 @@ void Tensor::ToDevice(Device *dst) { if (device_ != dst) { Tensor tmp(shape_, dst, data_type_); tmp.CopyData(*this); - if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); - blob_ = tmp.blob_; - tmp.blob_ = nullptr; + if (block_ != nullptr && block_->DecRefCount() == 0) + device_->FreeBlock(block_); + block_ = tmp.block_; + tmp.block_ = nullptr; device_ = dst; } } @@ -122,7 +128,7 @@ void Tensor::CopyDataFromHostPtr(const DType *src, const size_t num) { << "data_type is " << DataType_Name(data_type_) << " user given type is of size " << sizeof(DType); if (src != nullptr) { - device_->CopyDataFromHostPtr(blob(), src, sizeof(DType) * num, 0); + device_->CopyDataFromHostPtr(block(), src, sizeof(DType) * num, 0); } else { LOG(WARNING) << "Copy data from null host ptr"; } @@ -132,9 +138,9 @@ template void Tensor::CopyDataFromHostPtr(const int *src, const size_t num); void Tensor::CopyData(const Tensor &src) { CHECK_EQ(Size(), src.Size()); - CHECK(blob_ != nullptr); - // Do copy only if the src's blob is already initialized. - if (src.blob_ != nullptr) { + CHECK(block_ != nullptr); + // Do copy only if the src's block is already initialized. 
+ if (src.block_ != nullptr) { singa::CopyDataToFrom(this, src, Size(), 0, 0); } } @@ -154,32 +160,34 @@ Tensor Tensor::T() const { t.transpose_ = ~transpose_; t.shape_.push_back(shape_[1]); t.shape_.push_back(shape_[0]); - t.blob_ = blob_; - blob_->IncRefCount(); + t.block_ = block_; + block_->IncRefCount(); return t; } Tensor &Tensor::operator=(const Tensor &in) { // LOG(ERROR) << "= const &"; - if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); + if (block_ != nullptr && block_->DecRefCount() == 0) + device_->FreeBlock(block_); transpose_ = in.transpose_; data_type_ = in.data_type_; shape_ = in.shape_; device_ = in.device_; - blob_ = in.blob(); - blob_->IncRefCount(); + block_ = in.block(); + block_->IncRefCount(); return *this; } Tensor &Tensor::operator=(Tensor &&in) { // LOG(ERROR) << "= &&"; - if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); + if (block_ != nullptr && block_->DecRefCount() == 0) + device_->FreeBlock(block_); transpose_ = in.transpose_; data_type_ = in.data_type_; shape_ = std::move(in.shape_); device_ = in.device_; - blob_ = in.blob_; - in.blob_ = nullptr; + block_ = in.block_; + in.block_ = nullptr; return *this; } @@ -233,7 +241,7 @@ void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num, CHECK_GE(dst->MemSize(), d_offset + nBytes); Device *src_dev = src.device(), *dst_dev = dst->device(); - Blob *from = src.blob(), *to = dst->blob(); + Block *from = src.block(), *to = dst->block(); if (dst_dev->lang() != src_dev->lang()) { // let the none cpp device conduct copy op if (dst_dev->lang() == kCpp) { @@ -317,9 +325,9 @@ float Tensor::L2() const { TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, { device_->Exec([&nrm, this](Context *ctx) { DType ret; - Nrm2<DType, Lang>(this->Size(), this->blob(), &ret, ctx); + Nrm2<DType, Lang>(this->Size(), this->block(), &ret, ctx); nrm = TypeCast<DType, float>(ret); - }, {this->blob()}, {}); + }, {this->block()}, {}); }); return nrm; } @@ -327,7 +335,7 @@ template <typename SType> void Tensor::SetValue(const SType x) { CHECK_EQ(sizeof(SType), SizeOf(data_type_)); auto size = Size(); - auto ptr = blob_; + auto ptr = block_; TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, { // cast x to DType device_->Exec([size, x, ptr](Context *ctx) { @@ -341,8 +349,8 @@ template void Tensor::SetValue<float>(const float x); do { \ TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { \ ret->device()->Exec([t, ret](Context * ctx) { \ - fn<DType, Lang>(t.Size(), t.blob(), ret->blob(), ctx); \ - }, {t.blob()}, {ret->blob()}); \ + fn<DType, Lang>(t.Size(), t.block(), ret->block(), ctx); \ + }, {t.block()}, {ret->block()}); \ }); \ } while (0) @@ -365,14 +373,15 @@ GenUnaryTensorFn(Sqrt); GenUnaryTensorFn(Square); GenUnaryTensorFn(Tanh); -#define EltwiseBinaryTensorFn(fn, lhs, rhs, ret) \ - do { \ - TYPE_LANG_SWITCH(lhs.data_type(), DType, lhs.device()->lang(), Lang, { \ - CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type())); \ - ret->device()->Exec([lhs, rhs, ret](Context * ctx) { \ - fn<DType, Lang>(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(), ctx); \ - }, {lhs.blob(), rhs.blob()}, {ret->blob()}); \ - }); \ +#define EltwiseBinaryTensorFn(fn, lhs, rhs, ret) \ + do { \ + TYPE_LANG_SWITCH(lhs.data_type(), DType, lhs.device()->lang(), Lang, { \ + CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type())); \ + ret->device()->Exec([lhs, rhs, ret](Context * ctx) { \ + fn<DType, Lang>(lhs.Size(), lhs.block(), rhs.block(), ret->block(), \ + ctx); \ + }, {lhs.block(), 
rhs.block()}, {ret->block()}); \ + }); \ } while (0) #define GenBinaryTensorFn(op, fn) \ @@ -397,8 +406,8 @@ GenBinaryTensorFn(Pow, Pow); static_assert(std::is_same<SType, DType>::value, \ "The Scalar type must match the Tensor data type"); \ ret->device()->Exec([t, x, ret](Context * ctx) { \ - fn<DType, Lang>(t.Size(), t.blob(), x, ret->blob(), ctx); \ - }, {t.blob()}, {ret->blob()}); \ + fn<DType, Lang>(t.Size(), t.block(), x, ret->block(), ctx); \ + }, {t.block()}, {ret->block()}); \ }); \ } while (0) @@ -440,8 +449,8 @@ void Div(const SType alpha, const Tensor &in, Tensor *out) { TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, { // TODO(wangwei) type cast SType to DType; in.device()->Exec([alpha, in, out](Context *ctx) { - Div<DType, Lang>(in.Size(), alpha, in.blob(), out->blob(), ctx); - }, {in.blob()}, {out->blob()}); + Div<DType, Lang>(in.Size(), alpha, in.block(), out->block(), ctx); + }, {in.block()}, {out->block()}); }); } template void Div<float>(const float, const Tensor &, Tensor *); @@ -474,8 +483,8 @@ float Sum<float>(const Tensor &in) { float s = 0.0f; TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, { in.device()->Exec([in, &s](Context *ctx) { - Sum<DType, Lang>(in.Size(), in.blob(), &s, ctx); - }, {in.blob()}, {}); + Sum<DType, Lang>(in.Size(), in.block(), &s, ctx); + }, {in.block()}, {}); }); return s; } @@ -582,9 +591,9 @@ void MultColumn(const Tensor &v, Tensor *M) { CheckDataTypeAndLang(*M, v); TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, { v.device()->Exec([M, v](Context *ctx) { - DGMM<DType, Lang>(false, M->shape(0), M->shape(1), M->blob(), v.blob(), - M->blob(), ctx); - }, {M->blob(), v.blob()}, {M->blob()}); + DGMM<DType, Lang>(false, M->shape(0), M->shape(1), M->block(), v.block(), + M->block(), ctx); + }, {M->block(), v.block()}, {M->block()}); }); } @@ -597,9 +606,9 @@ void MultRow(const Tensor &v, Tensor *M) { CheckDataTypeAndLang(*M, v); TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, { v.device()->Exec([M, v](Context *ctx) { - DGMM<DType, Lang>(true, M->shape(0), M->shape(1), M->blob(), v.blob(), - M->blob(), ctx); - }, {M->blob(), v.blob()}, {M->blob()}); + DGMM<DType, Lang>(true, M->shape(0), M->shape(1), M->block(), v.block(), + M->block(), ctx); + }, {M->block(), v.block()}, {M->block()}); }); } @@ -644,8 +653,8 @@ void Bernoulli(const SType p, Tensor *out) { TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, { auto prob = TypeCast<SType, DType>(p); out->device()->Exec([prob, out](Context *ctx) { - Bernoulli<DType, Lang>(out->Size(), prob, out->blob(), ctx); - }, {}, {out->blob()}, true); + Bernoulli<DType, Lang>(out->Size(), prob, out->block(), ctx); + }, {}, {out->block()}, true); }); } template void Bernoulli<float>(const float p, Tensor *out); @@ -656,8 +665,8 @@ void Uniform(const SType low, const SType high, Tensor *out) { auto l = TypeCast<SType, DType>(low); auto h = TypeCast<SType, DType>(high); out->device()->Exec([l, h, out](Context *ctx) { - Uniform<DType, Lang>(out->Size(), l, h, out->blob(), ctx); - }, {}, {out->blob()}, true); + Uniform<DType, Lang>(out->Size(), l, h, out->block(), ctx); + }, {}, {out->block()}, true); }); } template void Uniform<float>(const float low, const float high, Tensor *out); @@ -668,8 +677,8 @@ void Gaussian(const SType mean, const SType std, Tensor *out) { auto m = TypeCast<SType, DType>(mean); auto s = TypeCast<SType, DType>(std); out->device()->Exec([m, s, out](Context *ctx) { - Gaussian<DType, Lang>(out->Size(), m, 
s, out->blob(), ctx); - }, {}, {out->blob()}, true); + Gaussian<DType, Lang>(out->Size(), m, s, out->block(), ctx); + }, {}, {out->block()}, true); }); } template void Gaussian<float>(const float mean, const float std, Tensor *out); @@ -680,8 +689,8 @@ void Axpy(const SType alpha, const Tensor &in, Tensor *out) { TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, { auto a = TypeCast<SType, DType>(alpha); out->device()->Exec([a, in, out](Context *ctx) { - Axpy<DType, Lang>(in.Size(), a, in.blob(), out->blob(), ctx); - }, {in.blob(), out->blob()}, {out->blob()}); + Axpy<DType, Lang>(in.Size(), a, in.block(), out->block(), ctx); + }, {in.block(), out->block()}, {out->block()}); }); } template void Axpy(const float alpha, const Tensor &in, Tensor *out); @@ -708,9 +717,9 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta, auto a = TypeCast<SType, DType>(alpha); auto b = TypeCast<SType, DType>(beta); C->device()->Exec([a, A, b, B, C](Context *ctx) { - GEMV<DType, Lang>(A.transpose(), A.shape(0), A.shape(1), a, A.blob(), - B.blob(), b, C->blob(), ctx); - }, {A.blob(), B.blob()}, {C->blob()}); + GEMV<DType, Lang>(A.transpose(), A.shape(0), A.shape(1), a, A.block(), + B.block(), b, C->block(), ctx); + }, {A.block(), B.block()}, {C->block()}); }); } else { CHECK(!C->transpose()); @@ -719,13 +728,13 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta, auto b = TypeCast<SType, DType>(beta); C->device()->Exec([a, A, b, B, C](Context *ctx) { GEMM<DType, Lang>(A.transpose(), B.transpose(), A.shape(0), B.shape(1), - A.shape(1), a, A.blob(), B.blob(), b, C->blob(), ctx); - }, {A.blob(), B.blob()}, {C->blob()}); + A.shape(1), a, A.block(), B.block(), b, C->block(), + ctx); + }, {A.block(), B.block()}, {C->block()}); }); } } - // ************************ // Misc. // *********************** @@ -737,23 +746,22 @@ void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss) { size_t dim = p.Size() / batchsize; TYPE_LANG_SWITCH(p.data_type(), DType, p.device()->lang(), Lang, { p.device()->Exec([batchsize, dim, t, p, loss](Context *ctx) { - ComputeCrossEntropy<DType, Lang>(batchsize, dim, p.blob(), t.blob(), - loss->blob(), ctx); - }, {p.blob(), t.blob()}, {loss->blob()}); + ComputeCrossEntropy<DType, Lang>(batchsize, dim, p.block(), t.block(), + loss->block(), ctx); + }, {p.block(), t.block()}, {loss->block()}); }); } void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p) { CHECK_LE(p->nDim(), 2u); CHECK_LE(t.nDim(), 2u); // TODO(wangwei) consider multi-labels. 
size_t batchsize = 1; - if (p->nDim() == 2u) - batchsize = p->shape(0); + if (p->nDim() == 2u) batchsize = p->shape(0); size_t dim = p->Size() / batchsize; TYPE_LANG_SWITCH(p->data_type(), DType, p->device()->lang(), Lang, { p->device()->Exec([batchsize, dim, t, p](Context *ctx) { - SoftmaxCrossEntropyBwd<DType, Lang>(batchsize, dim, p->blob(), t.blob(), - p->blob(), ctx); - }, {p->blob(), t.blob()}, {p->blob()}); + SoftmaxCrossEntropyBwd<DType, Lang>(batchsize, dim, p->block(), t.block(), + p->block(), ctx); + }, {p->block(), t.block()}, {p->block()}); }); } } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9c2869b9/src/core/tensor/tensor_math.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h index 12490d1..57ccb88 100644 --- a/src/core/tensor/tensor_math.h +++ b/src/core/tensor/tensor_math.h @@ -33,20 +33,20 @@ namespace singa { /// first /// letter. /// 2. Order functions based on function name in alphabetical order. -/// 3. Function arguments order is [const basic type] [const Blob] [mutable -/// Blob]. +/// 3. Function arguments order is [const basic type] [const Block] [mutable +/// Block]. /// 4. Function argument names, use 'num' for total number of elements in -/// elementwise operations; use 'in1' 'in2' for in blobs; use 'out' for -/// output blob or value. With exceptions for some functions, e.g., -/// Scale(const float alpha, const Blob* in, Blob* out); +/// elementwise operations; use 'in1' 'in2' for in blocks; use 'out' for +/// output block or value. With exceptions for some functions, e.g., +/// Scale(const float alpha, const Block* in, Block* out); /// For such cases, use x, v, alpha, etc for scalar types. /// For blas functions, follow the blas style for argument names. /// Use 'M' and 'v' for matrix and vector tensors in functions involving both /// matrix and vectors. -/// 5. For Blob argument xxx, name its raw pointer as xxxPtr. +/// 5. For Block argument xxx, name its raw pointer as xxxPtr. /// 6. Pass the 'cudaStream_t s' to every function in math_kernel.h /// 7. Use size_t for the number of elements, rows or columns. -/// 8. Use the same name for the Tensor and Blob level math functions. +/// 8. Use the same name for the Tensor and Block level math functions. // ************************************** // Element-wise functions @@ -54,41 +54,41 @@ namespace singa { /// out[i] = |in[i]| template <typename DType, typename Lang> -void Abs(const size_t num, const Blob *in, Blob *out, Context *ctx) { +void Abs(const size_t num, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "Abs Not Implemented"; } /// out[i] = in[i] + x template <typename DType, typename Lang> -void Add(const size_t num, const Blob *in, const DType x, Blob *out, +void Add(const size_t num, const Block *in, const DType x, Block *out, Context *ctx) { LOG(FATAL) << "Add Not Implemented"; } /// out[i] = in1[i] + in2[i] template <typename DType, typename Lang> -void Add(const size_t num, const Blob *in1, const Blob *in2, Blob *out, +void Add(const size_t num, const Block *in1, const Block *in2, Block *out, Context *ctx) { LOG(FATAL) << "Add-Pair Not Implemented"; } /// Clamp every element into [low, high] /// if in[i]>high, then out[i]=high; if in[i]<low, then out[i]=low. 
template <typename DType, typename Lang> -void Clamp(const size_t num, const DType low, const DType high, const Blob *in, - Blob *out, Context *ctx) { +void Clamp(const size_t num, const DType low, const DType high, const Block *in, + Block *out, Context *ctx) { LOG(FATAL) << "Clamp Not Implemented"; } /// out[i] = x / in[i] template <typename DType, typename Lang> -void Div(const size_t num, const DType x, const Blob *in, Blob *out, +void Div(const size_t num, const DType x, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "Div Not Implemented"; } /// out[i] = in[i] / x template <typename DType, typename Lang> -void Div(const size_t num, const Blob *in, const DType x, Blob *out, +void Div(const size_t num, const Block *in, const DType x, Block *out, Context *ctx) { CHECK_NE(x, 0.f); EltwiseMult<DType, Lang>(num, in, DType(1) / x, out, ctx); @@ -96,131 +96,131 @@ void Div(const size_t num, const Blob *in, const DType x, Blob *out, /// out[i] = in1[i] / in2[i] template <typename DType, typename Lang> -void Div(const size_t num, const Blob *in1, const Blob *in2, Blob *out, +void Div(const size_t num, const Block *in1, const Block *in2, Block *out, Context *ctx) { LOG(FATAL) << "Div-Pair Not Implemented"; } /// out[i] = in[i] * x template <typename DType, typename Lang> -void EltwiseMult(const size_t num, const Blob *in, const DType x, Blob *out, +void EltwiseMult(const size_t num, const Block *in, const DType x, Block *out, Context *ctx) { LOG(FATAL) << "EltwiseMult Not Implemented"; } /// out[i] = in1[i] * in2[i] template <typename DType, typename Lang> -void EltwiseMult(const size_t num, const Blob *in1, const Blob *in2, Blob *out, +void EltwiseMult(const size_t num, const Block *in1, const Block *in2, Block *out, Context *ctx) { LOG(FATAL) << "EltwiseMult-Pair Not Implemented"; } /// Base is e, Neper number. out[i]=exp(in[i]) template <typename DType, typename Lang> -void Exp(const size_t num, const Blob *in, Blob *out, Context *ctx) { +void Exp(const size_t num, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "Exp Not Implemented"; } /// out[i]=(in[i]<=x)?1.f:0.f template <typename DType, typename Lang> -void LE(const size_t num, const Blob *in, const DType x, Blob *out, +void LE(const size_t num, const Block *in, const DType x, Block *out, Context *ctx) { LOG(FATAL) << "LE Not Implemented"; } /// Natual logarithm, the base is e, Neper number out[i]=log(in[i]). 
template <typename DType, typename Lang> -void Log(const size_t num, const Blob *in, Blob *out, Context *ctx) { +void Log(const size_t num, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "Log Not Implemented"; } /// out[i]=(in[i]<x)?1.f:0.f template <typename DType, typename Lang> -void LT(const size_t num, const Blob *in, const DType x, Blob *out, +void LT(const size_t num, const Block *in, const DType x, Block *out, Context *ctx) { LOG(FATAL) << "LT Not Implemented"; } /// out[i]=(in[i]>=x)?1.f:0.f template <typename DType, typename Lang> -void GE(const size_t num, const Blob *in, const DType x, Blob *out, +void GE(const size_t num, const Block *in, const DType x, Block *out, Context *ctx) { LOG(FATAL) << "GE Not Implemented"; } /// out[i]=(in[i]>x)?1.f:0.f template <typename DType, typename Lang> -void GT(const size_t num, const Blob *in, const DType x, Blob *out, +void GT(const size_t num, const Block *in, const DType x, Block *out, Context *ctx) { LOG(FATAL) << "GT Not Implemented"; } /// out[i] = pow(in[i], x) template <typename DType, typename Lang> -void Pow(const size_t num, const Blob *in, const DType x, Blob *out, +void Pow(const size_t num, const Block *in, const DType x, Block *out, Context *ctx) { LOG(FATAL) << "Pow Not Implemented"; } /// out[i]=pow(in1[i], in2[i]) template <typename DType, typename Lang> -void Pow(const size_t num, const Blob *in1, const Blob *in2, Blob *out, +void Pow(const size_t num, const Block *in1, const Block *in2, Block *out, Context *ctx) { LOG(FATAL) << "Pow-Pair Not Implemented"; } /// out[i]=max(0, in[i]) template <typename DType, typename Lang> -void ReLU(const size_t num, const Blob *in, Blob *out, Context *ctx) { +void ReLU(const size_t num, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "ReLU Not Implemented"; } /// out[i] = x template <typename DType, typename Lang> -void Set(const size_t num, const DType x, Blob *out, Context *ctx) { +void Set(const size_t num, const DType x, Block *out, Context *ctx) { LOG(FATAL) << "Set Not Implemented"; } /// out[i]=sigmoid(in[i]) template <typename DType, typename Lang> -void Sigmoid(const size_t num, const Blob *in, Blob *out, Context *ctx) { +void Sigmoid(const size_t num, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "Sigmoid Not Implemented"; } /// out[i] = sign(in[i]) template <typename DType, typename Lang> -void Sign(const size_t num, const Blob *in, Blob *out, Context *ctx) { +void Sign(const size_t num, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "Sign Not Implemented"; } /// out[i]=sqrt(in[i]) template <typename DType, typename Lang> -void Sqrt(const size_t num, const Blob *in, Blob *out, Context *ctx) { +void Sqrt(const size_t num, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "Sqrt Not Implemented"; } /// out[i]=square(in[i]) template <typename DType, typename Lang> -void Square(const size_t num, const Blob *in, Blob *out, Context *ctx) { +void Square(const size_t num, const Block *in, Block *out, Context *ctx) { EltwiseMult<DType, Lang>(num, in, in, out, ctx); } /// out[i] = in[i] - x template <typename DType, typename Lang> -void Sub(const size_t num, const Blob *in, const DType x, Blob *out, +void Sub(const size_t num, const Block *in, const DType x, Block *out, Context *ctx) { Add<DType, Lang>(num, in, -x, out, ctx); } /// out[i] = in1[i] - in2[i] template <typename DType, typename Lang> -void Sub(const size_t num, const Blob *in1, const Blob *in2, Blob *out, +void Sub(const size_t num, const Block *in1, const Block *in2, 
Block *out, Context *ctx) { LOG(FATAL) << "Sub-Pair Not Implemented"; } /// sum all elements of in into out template <typename DType, typename Lang> -void Sum(const size_t num, const Blob *in, DType *out, Context *ctx) { +void Sum(const size_t num, const Block *in, DType *out, Context *ctx) { LOG(FATAL) << "Sum Not Implemented"; } /// out[i]=tanh(in[i]) template <typename DType, typename Lang> -void Tanh(const size_t num, const Blob *in, Blob *out, Context *ctx) { +void Tanh(const size_t num, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "Tanh Not Implemented"; } @@ -231,20 +231,20 @@ void Tanh(const size_t num, const Blob *in, Blob *out, Context *ctx) { // Get the random generator from 'ctx' // If DType is not float, then convert the threshold to DType template <typename DType, typename Lang> -void Bernoulli(const size_t num, const float p, Blob *out, Context *ctx) { +void Bernoulli(const size_t num, const float p, Block *out, Context *ctx) { LOG(FATAL) << "Bernoulli Not Implemented"; } // The random generator should be extracted from ctx. // If DType is not float, then convert the mean and std to DType template <typename DType, typename Lang> -void Gaussian(const size_t num, const float mean, const float std, Blob *out, +void Gaussian(const size_t num, const float mean, const float std, Block *out, Context *ctx) { LOG(FATAL) << "Gaussian Not Implemented"; } // The random generator should be extracted from ctx. // If DType is not float, then convert the low and high to DType template <typename DType, typename Lang> -void Uniform(const size_t num, const float low, const float high, Blob *out, +void Uniform(const size_t num, const float low, const float high, Block *out, Context *ctx) { LOG(FATAL) << "Uniform Not Implemented"; } @@ -255,43 +255,43 @@ void Uniform(const size_t num, const float low, const float high, Blob *out, /// outurn the index of the element with the max value. template <typename DType, typename Lang> -void Amax(const size_t num, const Blob *in, size_t *out, Context *ctx) { +void Amax(const size_t num, const Block *in, size_t *out, Context *ctx) { LOG(FATAL) << "Amax Not Implemented"; } /// outurn the index of the element with the min value. template <typename DType, typename Lang> -void Amin(const size_t num, const Blob *in, size_t *out, Context *ctx) { +void Amin(const size_t num, const Block *in, size_t *out, Context *ctx) { LOG(FATAL) << "Amin Not Implemented"; } /// out = sum |x| for all x in in template <typename DType, typename Lang> -void Asum(const size_t num, const Blob *in, DType *out, Context *ctx) { +void Asum(const size_t num, const Block *in, DType *out, Context *ctx) { LOG(FATAL) << "Asum Not Implemented"; } /// out = alpha * in + out template <typename DType, typename Lang> -void Axpy(const size_t num, const DType alpha, const Blob *in, Blob *out, +void Axpy(const size_t num, const DType alpha, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "Axpy Not Implemented"; } /// out = ||in||_2^2, i.e, L2 norm. 
template <typename DType, typename Lang> -void Nrm2(const size_t num, const Blob *in, float *out, Context *ctx) { +void Nrm2(const size_t num, const Block *in, float *out, Context *ctx) { LOG(FATAL) << "Nrm2 Not Implemented"; } /// out *= x template <typename DType, typename Lang> -void Scale(const size_t num, const DType x, Blob *out, Context *ctx) { +void Scale(const size_t num, const DType x, Block *out, Context *ctx) { LOG(FATAL) << "Scale Not Implemented"; } /// inner product of array in1 and in2 template <typename DType, typename Lang> -void Dot(const size_t num, const Blob *in1, const Blob *in2, DType *out, +void Dot(const size_t num, const Block *in1, const Block *in2, DType *out, Context *ctx) { LOG(FATAL) << "Dot Not Implemented"; } @@ -300,7 +300,7 @@ void Dot(const size_t num, const Blob *in1, const Blob *in2, DType *out, /// transA indicates if the internal data layout is transposed of A template <typename DType, typename Lang> void GEMV(bool trans, const size_t m, const size_t n, const DType alpha, - const Blob *A, const Blob *v, const DType beta, Blob *out, + const Block *A, const Block *v, const DType beta, Block *out, Context *ctx) { LOG(FATAL) << "GEMV Not Implemented"; } @@ -309,7 +309,7 @@ void GEMV(bool trans, const size_t m, const size_t n, const DType alpha, /// if matrix_lef_side is true, do M*v; else do v*M template <typename DType, typename Lang> void DGMM(const bool side_right, const size_t nrow, const size_t ncol, - const Blob *M, const Blob *v, Blob *out, Context *ctx) { + const Block *M, const Block *v, Block *out, Context *ctx) { LOG(FATAL) << "DGMM Not Implemented"; } @@ -318,7 +318,7 @@ void DGMM(const bool side_right, const size_t nrow, const size_t ncol, template <typename DType, typename Lang> void GEMM(const bool transA, const bool transB, const size_t nrowA, const size_t ncolB, const size_t ncolA, const DType alpha, - const Blob *A, const Blob *B, const DType beta, Blob *C, + const Block *A, const Block *B, const DType beta, Block *C, Context *ctx) { LOG(FATAL) << "GEMM Not Implemented"; } @@ -327,14 +327,14 @@ void GEMM(const bool transA, const bool transB, const size_t nrowA, // following the consistency guide. template <typename DType, typename Lang> void ComputeCrossEntropy(const size_t batchsize, const size_t dim, - const Blob *p, const Blob *t, Blob *loss, + const Block *p, const Block *t, Block *loss, Context *ctx) { LOG(FATAL) << "Not Implemented"; } template <typename DType, typename Lang> void SoftmaxCrossEntropyBwd(const size_t batchsize, const size_t dim, - const Blob *p, const Blob *t, Blob *grad, + const Block *p, const Block *t, Block *grad, Context *ctx) { LOG(FATAL) << "Not Implemented"; } @@ -345,40 +345,40 @@ void SoftmaxCrossEntropyBwd(const size_t batchsize, const size_t dim, /* /// Add the vector v to every column of A as the column of out template <typename DType, typename Lang> -void AddCol(const size_t nrow, const size_t ncol, const Blob *A, const Blob *v, - Blob *out, Context *ctx) { +void AddCol(const size_t nrow, const size_t ncol, const Block *A, const Block *v, + Block *out, Context *ctx) { LOG(FATAL) << "AddCol Not Implemented"; } // TODO(wangwei) unify AddRow and AddCol. 
/// Add the vector v to every row of A as the row of out template <typename DType, typename Lang> -void AddRow(const size_t nrow, const size_t ncol, const Blob *A, const Blob *v, - Blob *out, Context *ctx) { +void AddRow(const size_t nrow, const size_t ncol, const Block *A, const Block *v, + Block *out, Context *ctx) { LOG(FATAL) << "AddRow Not Implemented"; } /// outer-product. /// in1 and in2 are vectors of len m and n. out is matrix of shape m * n template <typename DType, typename Lang> -void Outer(const size_t m, const size_t n, const Blob *in1, const Blob *in2, - Blob *out, Context *ctx) { +void Outer(const size_t m, const size_t n, const Block *in1, const Block *in2, + Block *out, Context *ctx) { LOG(FATAL) << "Outer Not Implemented"; } /// Sum the columns of the in matrix into a vector template <typename DType, typename Lang> -void SumColumns(const size_t nrow, const size_t ncol, const Blob *in, Blob *out, +void SumColumns(const size_t nrow, const size_t ncol, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "SumColumns Not Implemented"; } template <typename DType, typename Lang> -void Set(const size_t num, const DType x, Blob *out, Context *ctx) { +void Set(const size_t num, const DType x, Block *out, Context *ctx) { LOG(FATAL) << "Not Implemented"; } // TODO(wangwei) unify SumRow and SumCol. /// Sum the rows of the in matrix into a vector template <typename DType, typename Lang> -void SumRows(const size_t nrow, const size_t ncol, const Blob *in, Blob *out, +void SumRows(const size_t nrow, const size_t ncol, const Block *in, Block *out, Context *ctx) { LOG(FATAL) << "SumRows Not Implemented"; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9c2869b9/src/core/tensor/tensor_math_cpp.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h index c5d092b..4717b5f 100644 --- a/src/core/tensor/tensor_math_cpp.h +++ b/src/core/tensor/tensor_math_cpp.h @@ -30,7 +30,7 @@ namespace singa { template <> -void Abs<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, +void Abs<float, lang::Cpp>(const size_t num, const Block *in, Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); @@ -40,8 +40,8 @@ void Abs<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, } template <> -void Add<float, lang::Cpp>(const size_t num, const Blob *in, const float x, - Blob *out, Context *ctx) { +void Add<float, lang::Cpp>(const size_t num, const Block *in, const float x, + Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); for (size_t i = 0; i < num; i++) { @@ -50,8 +50,8 @@ void Add<float, lang::Cpp>(const size_t num, const Blob *in, const float x, } template <> -void Add<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2, - Blob *out, Context *ctx) { +void Add<float, lang::Cpp>(const size_t num, const Block *in1, const Block *in2, + Block *out, Context *ctx) { // CHECK_EQ(ctx->stream, nullptr); float *outPtr = static_cast<float *>(out->mutable_data()); const float *in1Ptr = static_cast<const float *>(in1->data()); @@ -63,7 +63,7 @@ void Add<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2, template <> void Clamp<float, lang::Cpp>(const size_t num, const float low, - const float high, const Blob *in, Blob *out, + const float high, const Block 
*in, Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); @@ -79,8 +79,8 @@ void Clamp<float, lang::Cpp>(const size_t num, const float low, } template <> -void Div<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2, - Blob *out, Context *ctx) { +void Div<float, lang::Cpp>(const size_t num, const Block *in1, const Block *in2, + Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *in1Ptr = static_cast<const float *>(in1->data()); const float *in2Ptr = static_cast<const float *>(in2->data()); @@ -91,8 +91,8 @@ void Div<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2, } template <> -void Div<float, lang::Cpp>(const size_t num, const float x, const Blob *in, - Blob *out, Context *ctx) { +void Div<float, lang::Cpp>(const size_t num, const float x, const Block *in, + Block *out, Context *ctx) { const float *inPtr = static_cast<const float *>(in->data()); float *outPtr = static_cast<float *>(out->mutable_data()); for (size_t i = 0; i < num; i++) { @@ -102,8 +102,8 @@ void Div<float, lang::Cpp>(const size_t num, const float x, const Blob *in, } template <> -void EltwiseMult<float, lang::Cpp>(const size_t num, const Blob *in, - const float x, Blob *out, Context *ctx) { +void EltwiseMult<float, lang::Cpp>(const size_t num, const Block *in, + const float x, Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); for (size_t i = 0; i < num; i++) { @@ -112,8 +112,8 @@ void EltwiseMult<float, lang::Cpp>(const size_t num, const Blob *in, } template <> -void EltwiseMult<float, lang::Cpp>(const size_t num, const Blob *in1, - const Blob *in2, Blob *out, Context *ctx) { +void EltwiseMult<float, lang::Cpp>(const size_t num, const Block *in1, + const Block *in2, Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *in1Ptr = static_cast<const float *>(in1->data()); const float *in2Ptr = static_cast<const float *>(in2->data()); @@ -122,7 +122,7 @@ void EltwiseMult<float, lang::Cpp>(const size_t num, const Blob *in1, } } template <> -void Exp<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, +void Exp<float, lang::Cpp>(const size_t num, const Block *in, Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); @@ -132,8 +132,8 @@ void Exp<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, } template <> -void GE<float, lang::Cpp>(const size_t num, const Blob *in, const float x, - Blob *out, Context *ctx) { +void GE<float, lang::Cpp>(const size_t num, const Block *in, const float x, + Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); for (size_t i = 0; i < num; i++) { @@ -142,8 +142,8 @@ void GE<float, lang::Cpp>(const size_t num, const Blob *in, const float x, } template <> -void GT<float, lang::Cpp>(const size_t num, const Blob *in, const float x, - Blob *out, Context *ctx) { +void GT<float, lang::Cpp>(const size_t num, const Block *in, const float x, + Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); for (size_t i = 0; i < num; i++) { @@ -151,8 +151,8 @@ void GT<float, lang::Cpp>(const size_t 
num, const Blob *in, const float x, } } template <> -void LE<float, lang::Cpp>(const size_t num, const Blob *in, const float x, - Blob *out, Context *ctx) { +void LE<float, lang::Cpp>(const size_t num, const Block *in, const float x, + Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); for (size_t i = 0; i < num; i++) { @@ -160,7 +160,7 @@ void LE<float, lang::Cpp>(const size_t num, const Blob *in, const float x, } } template <> -void Log<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, +void Log<float, lang::Cpp>(const size_t num, const Block *in, Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); @@ -170,8 +170,8 @@ void Log<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, } } template <> -void LT<float, lang::Cpp>(const size_t num, const Blob *in, const float x, - Blob *out, Context *ctx) { +void LT<float, lang::Cpp>(const size_t num, const Block *in, const float x, + Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); for (size_t i = 0; i < num; i++) { @@ -179,8 +179,8 @@ void LT<float, lang::Cpp>(const size_t num, const Blob *in, const float x, } } template <> -void Pow<float, lang::Cpp>(const size_t num, const Blob *in, const float x, - Blob *out, Context *ctx) { +void Pow<float, lang::Cpp>(const size_t num, const Block *in, const float x, + Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); for (size_t i = 0; i < num; i++) { @@ -189,8 +189,8 @@ void Pow<float, lang::Cpp>(const size_t num, const Blob *in, const float x, } template <> -void Pow<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2, - Blob *out, Context *ctx) { +void Pow<float, lang::Cpp>(const size_t num, const Block *in1, const Block *in2, + Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *in1Ptr = static_cast<const float *>(in1->data()); const float *in2Ptr = static_cast<const float *>(in2->data()); @@ -199,7 +199,7 @@ void Pow<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2, } } template <> -void ReLU<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, +void ReLU<float, lang::Cpp>(const size_t num, const Block *in, Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); @@ -208,13 +208,13 @@ void ReLU<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, } } template <> -void Set<float, lang::Cpp>(const size_t num, const float x, Blob *out, +void Set<float, lang::Cpp>(const size_t num, const float x, Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); for (size_t i = 0; i < num; i++) outPtr[i] = x; } template <> -void Sigmoid<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, +void Sigmoid<float, lang::Cpp>(const size_t num, const Block *in, Block *out, Context *ctx) { float *outPtr = static_cast<float *>(out->mutable_data()); const float *inPtr = static_cast<const float *>(in->data()); @@ -224,7 +224,7 @@ void Sigmoid<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, } template <> -void Sign<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out, 
+void Sign<float, lang::Cpp>(const size_t num, const Block *in, Block *out,
                             Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
@@ -234,7 +234,7 @@ void Sign<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
 }
 template <>
-void Sqrt<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
+void Sqrt<float, lang::Cpp>(const size_t num, const Block *in, Block *out,
                             Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
@@ -245,7 +245,7 @@ void Sqrt<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
 }
 /*
 template <>
-void Square<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
+void Square<float, lang::Cpp>(const size_t num, const Block *in, Block *out,
                               Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
@@ -256,8 +256,8 @@ void Square<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
 */
 template <>
-void Sub<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2,
-                           Blob *out, Context *ctx) {
+void Sub<float, lang::Cpp>(const size_t num, const Block *in1, const Block *in2,
+                           Block *out, Context *ctx) {
   // CHECK_EQ(ctx->stream, nullptr);
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *in1Ptr = static_cast<const float *>(in1->data());
@@ -270,7 +270,7 @@ void Sub<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2,
 // sum all elements of input into out
 // TODO(wangwei) optimize using omp
 template <>
-void Sum<float, lang::Cpp>(const size_t num, const Blob *in, float *out,
+void Sum<float, lang::Cpp>(const size_t num, const Block *in, float *out,
                            Context *ctx) {
   float s = 0.f;
   const float *inPtr = static_cast<const float *>(in->data());
@@ -281,7 +281,7 @@ void Sum<float, lang::Cpp>(const size_t num, const Blob *in, float *out,
 }
 template <>
-void Tanh<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
+void Tanh<float, lang::Cpp>(const size_t num, const Block *in, Block *out,
                             Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
@@ -292,7 +292,7 @@ void Tanh<float, lang::Cpp>(const size_t num, const Blob *in, Blob *out,
 // ===============Random operations==========================================
 template <>
-void Bernoulli<float, lang::Cpp>(const size_t num, const float p, Blob *out,
+void Bernoulli<float, lang::Cpp>(const size_t num, const float p, Block *out,
                                  Context *ctx) {
   std::bernoulli_distribution distribution(p);
   float *outPtr = static_cast<float *>(out->mutable_data());
@@ -303,7 +303,7 @@ void Bernoulli<float, lang::Cpp>(const size_t num, const float p, Blob *out,
 template <>
 void Gaussian<float, lang::Cpp>(const size_t num, const float mean,
-                                const float std, Blob *out, Context *ctx) {
+                                const float std, Block *out, Context *ctx) {
   std::normal_distribution<float> distribution(mean, std);
   float *outPtr = static_cast<float *>(out->mutable_data());
   for (size_t i = 0; i < num; i++) {
@@ -312,7 +312,7 @@ void Gaussian<float, lang::Cpp>(const size_t num, const float mean,
 }
 template <>
 void Uniform<float, lang::Cpp>(const size_t num, const float low,
-                               const float high, Blob *out, Context *ctx) {
+                               const float high, Block *out, Context *ctx) {
   std::uniform_real_distribution<float> distribution(low, high);
   float *outPtr = static_cast<float *>(out->mutable_data());
   for (size_t i = 0; i < num; i++) {
@@ -324,8 +324,8 @@ void Uniform<float, lang::Cpp>(const size_t num, const float low,
 template <>
 void DGMM<float, lang::Cpp>(const bool side_right, const size_t nrow,
-                            const size_t ncol, const Blob *M, const Blob *v,
-                            Blob *out, Context *ctx) {
+                            const size_t ncol, const Block *M, const Block *v,
+                            Block *out, Context *ctx) {
   const float *MPtr = static_cast<const float *>(M->data());
   const float *vPtr = static_cast<const float *>(v->data());
   float *outPtr = static_cast<float *>(out->mutable_data());
@@ -348,42 +348,42 @@ void DGMM<float, lang::Cpp>(const bool side_right, const size_t nrow,
 #ifdef USE_CBLAS
 template <>
-void Amax<float, lang::Cpp>(const size_t num, const Blob *in, size_t *out,
+void Amax<float, lang::Cpp>(const size_t num, const Block *in, size_t *out,
                             Context *ctx) {
   const float *inPtr = static_cast<const float *>(in->data());
   *out = cblas_isamax(num, inPtr, 1);
 }
 template <>
-void Asum<float, lang::Cpp>(const size_t num, const Blob *in, float *out,
+void Asum<float, lang::Cpp>(const size_t num, const Block *in, float *out,
                             Context *ctx) {
   const float *inPtr = static_cast<const float *>(in->data());
   *out = cblas_sasum(num, inPtr, 1);
 }
 template <>
-void Axpy<float, lang::Cpp>(const size_t num, const float alpha, const Blob *in,
-                            Blob *out, Context *ctx) {
+void Axpy<float, lang::Cpp>(const size_t num, const float alpha,
+                            const Block *in, Block *out, Context *ctx) {
   const float *inPtr = static_cast<const float *>(in->data());
   float *outPtr = static_cast<float *>(out->mutable_data());
   cblas_saxpy(num, alpha, inPtr, 1, outPtr, 1);
 }
 template <>
-void Dot<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2,
+void Dot<float, lang::Cpp>(const size_t num, const Block *in1, const Block *in2,
                            float *out, Context *ctx) {
   const float *in1Ptr = static_cast<const float *>(in1->data());
   const float *in2Ptr = static_cast<const float *>(in2->data());
   *out = cblas_sdot(num, in1Ptr, 1, in2Ptr, 1);
 }
 template <>
-void Scale<float, lang::Cpp>(const size_t num, const float x, Blob *out,
+void Scale<float, lang::Cpp>(const size_t num, const float x, Block *out,
                              Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   cblas_sscal(num, x, outPtr, 1);
 }
 template <>
-void Nrm2<float, lang::Cpp>(const size_t num, const Blob *in, float *out,
+void Nrm2<float, lang::Cpp>(const size_t num, const Block *in, float *out,
                             Context *ctx) {
   const float *inPtr = static_cast<const float *>(in->data());
   *out = cblas_snrm2(num, inPtr, 1);
@@ -391,8 +391,8 @@ void Nrm2<float, lang::Cpp>(const size_t num, const Blob *in, float *out,
 template <>
 void GEMV<float, lang::Cpp>(bool trans, const size_t m, const size_t n,
-                            const float alpha, const Blob *A, const Blob *v,
-                            const float beta, Blob *out, Context *ctx) {
+                            const float alpha, const Block *A, const Block *v,
+                            const float beta, Block *out, Context *ctx) {
   const float *APtr = static_cast<const float *>(A->data());
   const float *vPtr = static_cast<const float *>(v->data());
   float *outPtr = static_cast<float *>(out->mutable_data());
@@ -409,8 +409,8 @@ template <>
 void GEMM<float, lang::Cpp>(const bool transA, const bool transB,
                             const size_t nrowA, const size_t ncolB,
                             const size_t ncolA, const float alpha,
-                            const Blob *A, const Blob *B, const float beta,
-                            Blob *C, Context *ctx) {
+                            const Block *A, const Block *B, const float beta,
+                            Block *C, Context *ctx) {
   auto transa = transA ? CblasTrans : CblasNoTrans;
   auto transb = transB ? CblasTrans : CblasNoTrans;
   auto lda = transA ? nrowA : ncolA;
@@ -426,7 +426,7 @@ void GEMM<float, lang::Cpp>(const bool transA, const bool transB,
 #else
 template <>
-void Amax<float, lang::Cpp>(const size_t num, const Blob *in, size_t *out,
+void Amax<float, lang::Cpp>(const size_t num, const Block *in, size_t *out,
                             Context *ctx) {
   size_t maxPos = 0;
   float maxVal = 0;
@@ -442,7 +442,7 @@ void Amax<float, lang::Cpp>(const size_t num, const Blob *in, size_t *out,
   *out = maxPos;
 }
 template <>
-void Amin<float, lang::Cpp>(const size_t num, const Blob *in, size_t *out,
+void Amin<float, lang::Cpp>(const size_t num, const Block *in, size_t *out,
                             Context *ctx) {
   size_t minPos = 0;
   float minVal = 0;
@@ -459,7 +459,7 @@ void Amin<float, lang::Cpp>(const size_t num, const Blob *in, size_t *out,
 }
 template <>
-void Asum<float, lang::Cpp>(const size_t num, const Blob *in, float *out,
+void Asum<float, lang::Cpp>(const size_t num, const Block *in, float *out,
                             Context *ctx) {
   float sum = 0;
   const float *inPtr = static_cast<const float *>(in->data());
@@ -469,8 +469,8 @@ void Asum<float, lang::Cpp>(const size_t num, const Blob *in, float *out,
 }
 template <>
-void Axpy<float, lang::Cpp>(const size_t num, const float alpha, const Blob *in,
-                            Blob *out, Context *ctx) {
+void Axpy<float, lang::Cpp>(const size_t num, const float alpha,
+                            const Block *in, Block *out, Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
   for (size_t i = 0; i < num; i++) {
@@ -479,7 +479,7 @@ void Axpy<float, lang::Cpp>(const size_t num, const float alpha, const Blob *in,
 }
 template <>
-void Scale<float, lang::Cpp>(const size_t num, const float x, Blob *out,
+void Scale<float, lang::Cpp>(const size_t num, const float x, Block *out,
                              Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   for (size_t i = 0; i < num; i++) {
@@ -488,7 +488,7 @@ void Scale<float, lang::Cpp>(const size_t num, const float x, Blob *out,
 }
 template <>
-void Dot<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2,
+void Dot<float, lang::Cpp>(const size_t num, const Block *in1, const Block *in2,
                            float *out, Context *ctx) {
   float sum = 0;
   const float *in1Ptr = static_cast<const float *>(in1->data());
@@ -500,8 +500,8 @@ void Dot<float, lang::Cpp>(const size_t num, const Blob *in1, const Blob *in2,
 template <>
 void GEMV<float, lang::Cpp>(bool trans, const size_t m, const size_t n,
-                            const float alpha, const Blob *A, const Blob *v,
-                            const float beta, Blob *out, Context *ctx) {
+                            const float alpha, const Block *A, const Block *v,
+                            const float beta, Block *out, Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *APtr = static_cast<const float *>(A->data());
   const float *vPtr = static_cast<const float *>(v->data());
@@ -518,8 +518,8 @@ void GEMV<float, lang::Cpp>(bool trans, const size_t m, const size_t n,
 #endif // USE_CBLAS
 template <>
 void ComputeCrossEntropy<float, lang::Cpp>(const size_t batchsize,
-                                           const size_t dim, const Blob *p,
-                                           const Blob *t, Blob *loss,
+                                           const size_t dim, const Block *p,
+                                           const Block *t, Block *loss,
                                            Context *ctx) {
   const float *pPtr = static_cast<const float *>(p->data());
   const int *tPtr = static_cast<const int *>(t->data());
@@ -534,9 +534,9 @@ void ComputeCrossEntropy<float, lang::Cpp>(const size_t batchsize,
 template <>
 void SoftmaxCrossEntropyBwd<float, lang::Cpp>(const size_t batchsize,
-                                              const size_t dim, const Blob *p,
-                                              const Blob *t,
-                                              Blob *grad, Context *ctx) {
+                                              const size_t dim, const Block *p,
+                                              const Block *t, Block *grad,
+                                              Context *ctx) {
   CHECK_EQ(p, grad) << "Use the same pointer to optimize performance";
   // const float* pPtr = static_cast<const float*>(p->data());
   const int *tPtr = static_cast<const int *>(t->data());
@@ -549,12 +549,11 @@ void SoftmaxCrossEntropyBwd<float, lang::Cpp>(const size_t batchsize,
   }
 }
-
 // =========Matrix operations ================================================
 /*
 template <>
 void AddCol<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                              const Blob *A, const Blob *v, Blob *out,
+                              const Block *A, const Block *v, Block *out,
                               Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *APtr = static_cast<const float *>(A->data());
@@ -569,7 +568,7 @@ void AddCol<float, lang::Cpp>(const size_t nrow, const size_t ncol,
 template <>
 void AddRow<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                              const Blob *A, const Blob *v, Blob *out,
+                              const Block *A, const Block *v, Block *out,
                               Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *APtr = static_cast<const float *>(A->data());
@@ -582,8 +581,8 @@ void AddRow<float, lang::Cpp>(const size_t nrow, const size_t ncol,
   }
 }
 template <>
-void Outer<float, lang::Cpp>(const size_t m, const size_t n, const Blob *in1,
-                             const Blob *in2, Blob *out, Context *ctx) {
+void Outer<float, lang::Cpp>(const size_t m, const size_t n, const Block *in1,
+                             const Block *in2, Block *out, Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *in1Ptr = static_cast<const float *>(in1->data());
   const float *in2Ptr = static_cast<const float *>(in2->data());
@@ -596,7 +595,7 @@ void Outer<float, lang::Cpp>(const size_t m, const size_t n, const Blob *in1,
 }
 template <>
 void Softmax<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                               const Blob *in, Blob *out, Context *ctx) {
+                               const Block *in, Block *out, Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
   float *bPtr = new float[ncol];
@@ -617,7 +616,7 @@ void Softmax<float, lang::Cpp>(const size_t nrow, const size_t ncol,
 template <>
 void SumColumns<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                                  const Blob *in, Blob *out, Context *ctx) {
+                                  const Block *in, Block *out, Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
   for (size_t c = 0; c < ncol; c++) {
@@ -633,7 +632,7 @@ void SumColumns<float, lang::Cpp>(const size_t nrow, const size_t ncol,
 template <>
 void SumRows<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                               const Blob *in, Blob *out, Context *ctx) {
+                               const Block *in, Block *out, Context *ctx) {
   float *outPtr = static_cast<float *>(out->mutable_data());
   const float *inPtr = static_cast<const float *>(in->data());
   for (size_t r = 0; r < nrow; r++) {
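The hunks above change only the handle type the CPU kernels receive; the access pattern itself, reading inputs through data() and writing results through mutable_data(), is unchanged. For illustration, a minimal, self-contained sketch of that pattern after the rename follows. The Block and Context types below are simplified stand-ins (a wrapped raw pointer and an empty struct), not the real singa classes, and this Sub only mirrors the element-wise kernel shown in the diff.

// Minimal stand-ins, NOT the real singa classes: just enough to mirror the
// data()/mutable_data() accessors the kernels above rely on.
#include <cstddef>
#include <iostream>
#include <vector>

struct Block {                       // simplified stand-in for a memory block handle
  explicit Block(void *ptr) : data_(ptr) {}
  void *mutable_data() const { return data_; }
  const void *data() const { return data_; }
 private:
  void *data_;
};

struct Context {};                   // placeholder for the per-device context

// Same element-wise pattern as Sub<float, lang::Cpp> in the hunks above:
// read through Block::data(), write through Block::mutable_data().
void Sub(const size_t num, const Block *in1, const Block *in2, Block *out,
         Context * /*ctx*/) {
  float *outPtr = static_cast<float *>(out->mutable_data());
  const float *in1Ptr = static_cast<const float *>(in1->data());
  const float *in2Ptr = static_cast<const float *>(in2->data());
  for (size_t i = 0; i < num; i++) outPtr[i] = in1Ptr[i] - in2Ptr[i];
}

int main() {
  std::vector<float> a{3.f, 5.f, 7.f}, b{1.f, 2.f, 3.f}, c(3, 0.f);
  Block ba(a.data()), bb(b.data()), bc(c.data());
  Context ctx;
  Sub(c.size(), &ba, &bb, &bc, &ctx);
  for (float v : c) std::cout << v << " ";   // prints: 2 3 4
  std::cout << std::endl;
  return 0;
}

Keeping the kernels as template specializations on a language tag (lang::Cpp here) is what lets the same call sites pick a backend at compile time; the commit changes only the Block handle that all of them share.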
