SINGA-171 - Create CppDevice and CudaDevice

Implement CudaDevice.
(zhongle) Fix errors for cudnn and cuda by adding the cuda & cudnn libs to
SINGA_LINKER_LIBS.

NOTE: set the cudnn include path before the cuda include path, as some
platforms ship a cudnn.h inside cuda/include, but that cudnn.h may not be
the one users configured via CMAKE_XXX_PATH.

Pass the test for cudnn dropout. NOTE: make sure all data used by cudnn
layers is allocated on the device (not on the CPU); you can check memory
errors with cuda-memcheck ./program.

Pass cpplint.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/0b4b2e20
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/0b4b2e20
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/0b4b2e20

Branch: refs/heads/master
Commit: 0b4b2e20f803d1b890f24e6047912282092c156f
Parents: 282712c
Author: Wei Wang <[email protected]>
Authored: Wed May 18 20:10:45 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Thu May 19 14:19:36 2016 +0800

----------------------------------------------------------------------
 CMakeLists.txt                   |   3 +-
 cmake/Cuda.cmake                 |   7 +-
 include/singa/core/tensor.h      |   5 +-
 include/singa/model/layer.h      |   7 +-
 src/core/device/cpp_device.cc    |   2 +-
 src/core/device/cuda_device.cc   |  15 ++--
 src/core/device/device.cc        |   8 ++-
 src/core/tensor/tensor.cc        |   8 +--
 src/model/layer/cudnn_dropout.cc |  30 +++++---
 src/model/layer/cudnn_dropout.h  |   8 ++-
 test/CMakeLists.txt              |   3 +-
 test/singa/test_cpp_device.cc    |   2 +-
 test/singa/test_cudnn_dropout.cc | 127 ++++++++++++++++++++++++++++++++++
 test/singa/test_dropout.cc       |  16 ++---
 test/singa/test_tensor_math.cc   |  12 ++--
 15 files changed, 201 insertions(+), 52 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8457bf2..2d1a1e6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,7 @@
 CMAKE_MINIMUM_REQUIRED(VERSION 2.6)
 PROJECT(singa)
-SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -std=c++11 -DUSE_CUDA -DUSE_CUDNN")
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -std=c++11")
+#message(STATUS "${CMAKE_CXX_FLAGS}")
 list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Thirdparty)
 #message(STATUS "module path: ${CMAKE_MODULE_PATH}")

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/cmake/Cuda.cmake
----------------------------------------------------------------------
diff --git a/cmake/Cuda.cmake b/cmake/Cuda.cmake
index e3338af..8780fc6 100644
--- a/cmake/Cuda.cmake
+++ b/cmake/Cuda.cmake
@@ -7,8 +7,9 @@ endif()
 set(HAVE_CUDA TRUE)
 message(STATUS "Found cuda_v${CUDA_VERSION}")
-include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
-list(APPEND SINGA_LINKER_LIBS ${CUDA_CUDART_LIBRARY} ${CUDA_curand_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
+add_definitions(-DUSE_CUDA)
+#message(STATUS "linking: ${CUDA_CUDART_LIBRARY} ${CUDA_curand_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}")
+
 #if(USE_CUDNN)
 #include(cmake/Modules/Cudnn.cmake)
@@ -18,3 +19,5 @@ list(APPEND SINGA_LINKER_LIBS ${CUDA_CUDART_LIBRARY} ${CUDA_curand_LIBRARY} ${CU
 add_definitions(-DUSE_CUDNN)
 #endif()
+include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
+list(APPEND SINGA_LINKER_LIBS ${CUDA_CUDART_LIBRARY} ${CUDA_curand_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
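
For reference, the include-path pitfall in the NOTE above can be caught at
run time: the CUDNN_VERSION macro comes from whichever cudnn.h the compiler
found first, while cudnnGetVersion() reports the libcudnn that was actually
linked. A minimal standalone check (not part of this commit; assumes a
cudnn 5 style header):

    // check_cudnn_version.cc -- compile with the project's include path and
    // link against libcudnn; a mismatch means the wrong cudnn.h was found.
    #include <cudnn.h>
    #include <cstdio>

    int main() {
      std::printf("cudnn.h (compile time): %d\n", CUDNN_VERSION);
      std::printf("libcudnn (run time):    %d\n",
                  static_cast<int>(cudnnGetVersion()));
      return CUDNN_VERSION == static_cast<int>(cudnnGetVersion()) ? 0 : 1;
    }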
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/include/singa/core/tensor.h
----------------------------------------------------------------------
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 03bf443..359f1ee 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -88,8 +88,8 @@ class Tensor {
   /// Return immutable Tensor values with given type.
   template <typename DType>
-  const DType* data() const {
-    return static_cast<const DType*> (blob()->data());
+  DType data() const {
+    return static_cast<DType> (blob()->data());
   }

   /// data type, including kFloat16, kFloat32, kInt
@@ -111,6 +111,7 @@ class Tensor {

   /// Return number of total elements
   size_t Size() const {
+    CHECK_EQ(blob_->size() % SizeOf(data_type_), 0u);
     return blob_->size() / SizeOf(data_type_);
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index a4c4630..050236a 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -16,12 +16,13 @@
  * limitations under the License.
  */

-#ifndef SINGA_LAYER_H_
-#define SINGA_LAYER_H_
+#ifndef SINGA_MODEL_LAYER_H_
+#define SINGA_MODEL_LAYER_H_

 #include <vector>
 #include <string>
 #include <stack>
+#include <utility>
 #include "singa/core/tensor.h"
 #include "singa/proto/layer.pb.h"

@@ -191,4 +192,4 @@ class Layer {
 };
 }  // namespace singa

-#endif  // SINGA_LAYER_H_
+#endif  // SINGA_MODEL_LAYER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/src/core/device/cpp_device.cc
----------------------------------------------------------------------
diff --git a/src/core/device/cpp_device.cc b/src/core/device/cpp_device.cc
index d0e051e..763156c 100644
--- a/src/core/device/cpp_device.cc
+++ b/src/core/device/cpp_device.cc
@@ -44,4 +44,4 @@ void CppDevice::CopyToFrom(void* dst, const void* src, size_t nBytes,
                            CopyDirection direction, Context* ctx) {
   memcpy(dst, src, nBytes);
 }
-}
+}  // namespace singa
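
The tensor.h hunk above changes the contract of Tensor::data<DType>(): the
template argument is now the full (pointer) return type instead of the
element type, which is why the tests below switch from data<float>() to
data<const float*>(). A minimal before/after sketch, assuming the default
kFloat32 tensor on the host device:

    #include "singa/core/tensor.h"

    void ReadTensorValues() {
      singa::Tensor t(singa::Shape{4});  // kFloat32 on the default device
      // before this commit: const float* p = t.data<float>();
      const float* p = t.data<const float*>();  // after this commit
      (void)p;  // read-only view of the underlying blob
    }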
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/src/core/device/cuda_device.cc
----------------------------------------------------------------------
diff --git a/src/core/device/cuda_device.cc b/src/core/device/cuda_device.cc
index 1f6de60..9be1a6e 100644
--- a/src/core/device/cuda_device.cc
+++ b/src/core/device/cuda_device.cc
@@ -16,11 +16,11 @@
  * limitations under the License.
  */
 #ifdef USE_CUDA
-#include <chrono>
 #include <cublas_v2.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <curand.h>
+#include <chrono>

 #include "singa/core/device.h"
 #include "singa/utils/cuda.h"
@@ -47,10 +47,10 @@ CudaDevice::CudaDevice(int id, int num_executors,
                        string scheduler, string vm)
     : Device(id, num_executors, scheduler, vm) {
   device_type_ = kCuda;
-  host_ = nullptr; // TODO(wangwei) add host device
-  ctx_.stream = NULL; // use the default sync stream
+  host_ = nullptr;     // TODO(wangwei) add host device
+  ctx_.stream = NULL;  // use the default sync stream
   // TODO(wangwei) create one handle for each steam?
-  CUBLAS_CHECK(cublasCreate(&ctx_.cublas_handle));
+  CUDA_CHECK(cudaSetDevice(FindDevice(0)));
   // use curandCreateGeneratorHost for CudaHost device
   CURAND_CHECK(
       curandCreateGenerator(&ctx_.curand_generator, CURAND_RNG_PSEUDO_DEFAULT));
@@ -58,6 +58,7 @@ CudaDevice::CudaDevice(int id, int num_executors,
   SetRandSeed(seed);
   // TODO(wangwei) if one generator per stream, then need diff offset per gen?
   CURAND_CHECK(curandSetGeneratorOffset(ctx_.curand_generator, 0));
+  CUBLAS_CHECK(cublasCreate(&(ctx_.cublas_handle)));

 #ifdef USE_CUDNN
   // TODO(wangwei) create one handle for each stream?
@@ -86,14 +87,14 @@ void CudaDevice::CopyToFrom(void* dst, const void* src, size_t nBytes,

 /// Allocate cpu memory.
 void* CudaDevice::Malloc(int size) {
   void* ptr = nullptr;
-  cudaMalloc(&ptr, size);
+  CUDA_CHECK(cudaMalloc(&ptr, size));
   return ptr;
 }

 /// Free cpu memory.
 void CudaDevice::Free(void* ptr) {
   CHECK_NE(ptr, nullptr);
-  cudaFree(ptr);
+  CUDA_CHECK(cudaFree(ptr));
 }

@@ -152,5 +153,5 @@ int CudaDevice::FindDevice(const int start_id) {
 }

-}
+}  // namespace singa
 #endif  // USE_CUDA

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/src/core/device/device.cc
----------------------------------------------------------------------
diff --git a/src/core/device/device.cc b/src/core/device/device.cc
index 153637c..73bb5c1 100644
--- a/src/core/device/device.cc
+++ b/src/core/device/device.cc
@@ -54,8 +54,10 @@ void Device::CopyDataToFrom(Blob* dst, Blob* src, size_t nBytes,
                             int src_offset) {
   this->Exec(
       [this, dst, src, nBytes, direct, dst_offset, src_offset](Context* ctx) {
-        this->CopyToFrom((Byte*)dst->mutable_data() + dst_offset,
-                         (Byte*)src->data() + src_offset, nBytes, direct, ctx);
+        this->CopyToFrom(
+            reinterpret_cast<char*>(dst->mutable_data()) + dst_offset,
+            reinterpret_cast<char*>(src->data()) + src_offset, nBytes,
+            direct, ctx);
       },
       {src}, {dst});
 }
@@ -63,7 +65,7 @@ void Device::CopyDataToFrom(Blob* dst, Blob* src, size_t nBytes,
 void Device::CopyDataFromHostPtr(Blob* dst, const void* src, size_t nBytes,
                                  size_t dst_offset) {
   auto direct = device_type_ == kCpp ? kHostToHost : kHostToDevice;
-  void* dstptr = (Byte*)dst->mutable_data() + dst_offset;
+  void* dstptr = reinterpret_cast<char*>(dst->mutable_data()) + dst_offset;
   Exec([this, dstptr, src, nBytes, direct](Context* ctx) {
     CopyToFrom(dstptr, src, nBytes, direct, ctx);
   }, {}, {dst});
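
CudaDevice::Malloc and Free above now route every CUDA runtime call through
CUDA_CHECK from singa/utils/cuda.h. That macro's definition is not part of
this diff; the usual shape of such a wrapper, sketched here under a
deliberately hypothetical name, is:

    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    // Hypothetical stand-in for CUDA_CHECK: run the call, and abort with
    // file/line context if it did not return cudaSuccess.
    #define MY_CUDA_CHECK(expr)                                        \
      do {                                                             \
        cudaError_t err = (expr);                                      \
        if (err != cudaSuccess) {                                      \
          std::fprintf(stderr, "CUDA error '%s' at %s:%d\n",           \
                       cudaGetErrorString(err), __FILE__, __LINE__);   \
          std::abort();                                                \
        }                                                              \
      } while (0)

    // usage, mirroring CudaDevice::Malloc above
    void* Alloc(size_t size) {
      void* ptr = nullptr;
      MY_CUDA_CHECK(cudaMalloc(&ptr, size));
      return ptr;
    }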
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 339262e..fac846c 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -71,12 +71,12 @@ Tensor::Tensor(Tensor&& t)

 void Tensor::ResetLike(const Tensor& t) {
-  if (blob_ == nullptr || blob_->size() != t.MemSize()) {
+  if (blob_ == nullptr || device_ != t.device_ || MemSize() != t.MemSize()) {
     if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_);
     shape_ = t.shape_;
     device_ = t.device_;
     data_type_ = t.data_type_;
-    blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_));
+    blob_ = device_->NewBlob(t.MemSize());
   }
 }

@@ -121,8 +121,7 @@ void Tensor::CopyDataFromHostPtr(const DType* src, size_t num) {
       << "data_type is " << DataType_Name(data_type_)
       << " user given type is of size " << sizeof(DType);
   if (src != nullptr) {
-    auto direction = device_->type() == kCpp ? kHostToHost : kHostToDevice;
-    device_->CopyDataFromHostPtr(blob(), src, sizeof(DType) * num, direction);
+    device_->CopyDataFromHostPtr(blob(), src, sizeof(DType) * num, 0);
   } else {
     LOG(WARNING) << "Copy data from null host ptr";
   }
@@ -169,6 +168,7 @@ Tensor& Tensor::operator=(Tensor&& t) {
   if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_);
   transpose_ = t.transpose_;
+  data_type_ = t.data_type_;
   shape_ = std::move(t.shape_);
   device_ = t.device_;
   blob_ = t.blob_;
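
The ResetLike change above matters when an output tensor is reused across
calls: the old size-only test would keep a blob that lives on a different
device, which is exactly the kind of host/device mix-up the commit message
warns about. A short sketch of the new behaviour, using the same classes as
the tests below:

    #include "singa/core/device.h"
    #include "singa/core/tensor.h"

    void ResetLikeAcrossDevices() {
      singa::CudaDevice cuda(0, 1);
      singa::Tensor in(singa::Shape{8}, &cuda);  // 32 bytes on the GPU
      singa::Tensor out;                         // no blob allocated yet
      out.ResetLike(in);  // now also compares devices, so this allocates on
                          // &cuda; a same-sized blob on another device is
                          // no longer silently reused
    }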
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/src/model/layer/cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.cc b/src/model/layer/cudnn_dropout.cc
index 4d5f5d5..e049ade 100644
--- a/src/model/layer/cudnn_dropout.cc
+++ b/src/model/layer/cudnn_dropout.cc
@@ -18,9 +18,14 @@
 #ifdef USE_CUDNN
 // cudnn dropout is added in cudnn 5
 #if CUDNN_MAJOR_VERSION >= 5
+
 #include "./cudnn_dropout.h"
+#include <cudnn.h>
+#include <chrono>
+
 #include "./cudnn_utils.h"
 #include "singa/utils/logging.h"
+
 namespace singa {
 CudnnDropout::~CudnnDropout() {
   if (drop_desc_ != nullptr)
@@ -29,7 +34,8 @@ CudnnDropout::~CudnnDropout() {
   if (y_desc_ != nullptr)
     CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_desc_));
 }
-void CudnnDropout::InitCudnn(int size, DataType dtype, Context* ctx) {
+void CudnnDropout::InitCudnn(int size, DataType dtype, Device* dev,
+                             Context* ctx) {
   CHECK(!has_init_cudnn_);
   CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
   CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
@@ -41,10 +47,17 @@ void CudnnDropout::InitCudnn(int size, DataType dtype, Context* ctx) {
       y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), 1, 1, 1, size));

   cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size_);
+  state_ = Tensor(Shape{state_size_}, dev, kChar);
   cudnnDropoutGetReserveSpaceSize(x_desc_, &reserve_size_);
+  mask_ = Tensor(Shape{reserve_size_}, dev, kChar);
+  // TODO(wangwei) update for async running,
+  // where reserve_size_ may not available
+  CHECK_EQ(reserve_size_, mask_.MemSize());
+
+  // TODO(wangwei) get seed from ctx or user config?
+  auto seed = std::chrono::system_clock::now().time_since_epoch().count();
   cudnnSetDropoutDescriptor(drop_desc_, ctx->cudnn_handle, 1 - dropout_ratio_,
-                            state_.blob()->mutable_data(), state_size_,
-                            ctx->seed);
+                            state_.blob()->mutable_data(), state_size_, seed);
   has_init_cudnn_ = true;
 }

@@ -52,16 +65,13 @@ const Tensor CudnnDropout::Forward(int flag, const Tensor& input) {
   if (flag & kTrain) {
     auto size = input.Size();
     DataType dtype = input.data_type();
+    Device* dev = input.device();
     if (!has_init_cudnn_) {
       input.device()->Exec(
-          [size, dtype, this](Context* ctx) {
-            this->InitCudnn(size, dtype, ctx);
+          [size, dtype, this, dev](Context* ctx) {
+            this->InitCudnn(size, dtype, dev, ctx);
           },
           {}, {this->state_.blob()});
-      mask_.ResetLike(input);
-      // TODO(wangwei) update for async running,
-      // where reserve_size_ may not available
-      CHECK_EQ(reserve_size_, mask_.MemSize());
     }
     Tensor output;
     output.ResetLike(input);
@@ -71,7 +81,7 @@ const Tensor CudnnDropout::Forward(int flag, const Tensor& input) {
           *mblob = mask_.blob();
           cudnnDropoutForward(ctx->cudnn_handle, this->drop_desc_, this->x_desc_,
                               inblob->data(), this->y_desc_,
-                              outblob->mutable_data(), mblob,
+                              outblob->mutable_data(), mblob->mutable_data(),
                               this->reserve_size_);
         },
         {input.blob()}, {output.blob(), mask_.blob()});
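
For readers new to the cudnn 5 dropout API used above, InitCudnn plus
Forward reduces to the call sequence below (a bare sketch with error
checking and buffer allocation elided; in the layer, state_ and mask_ are
the device Tensors backing the two buffers):

    #include <cudnn.h>

    void DropoutForwardSketch(cudnnHandle_t handle,
                              cudnnTensorDescriptor_t x_desc,
                              cudnnTensorDescriptor_t y_desc,
                              const void* x, void* y, float drop_prob,
                              void* states, size_t state_size,
                              void* reserve, size_t reserve_size) {
      // state_size and reserve_size were obtained beforehand via
      // cudnnDropoutGetStatesSize(handle, &state_size) and
      // cudnnDropoutGetReserveSpaceSize(x_desc, &reserve_size).
      cudnnDropoutDescriptor_t desc;
      cudnnCreateDropoutDescriptor(&desc);
      // seeds the RNG state kept in `states`; the commit derives the seed
      // from the system clock
      cudnnSetDropoutDescriptor(desc, handle, drop_prob, states, state_size,
                                /*seed=*/1234ULL);
      // `reserve` records which elements were dropped; the layer keeps it
      // in mask_ so that the backward pass can replay the same pattern
      cudnnDropoutForward(handle, desc, x_desc, x, y_desc, y,
                          reserve, reserve_size);
      cudnnDestroyDropoutDescriptor(desc);
    }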
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/src/model/layer/cudnn_dropout.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_dropout.h b/src/model/layer/cudnn_dropout.h
index db0aa15..647eed2 100644
--- a/src/model/layer/cudnn_dropout.h
+++ b/src/model/layer/cudnn_dropout.h
@@ -21,9 +21,11 @@
 #ifdef USE_CUDNN
 // cudnn dropout is added in cudnn 5
 #if CUDNN_MAJOR_VERSION >= 5
+#include <cudnn.h>
 #include <utility>
 #include <string>
 #include <vector>
+
 #include "./dropout.h"
 #include "singa/core/common.h"
 #include "singa/model/layer.h"
@@ -41,12 +43,12 @@ class CudnnDropout : public Dropout {
                                                    const Tensor& grad) override;

   /// Init cudnn related data structures.
-  void InitCudnn(int size, DataType dtype, Context* ctx);
+  void InitCudnn(int size, DataType dtype, Device* dev, Context* ctx);

  private:
   bool has_init_cudnn_ = false;
-  cudnnDropoutDescriptor_t drop_desc_;
-  cudnnTensorDescriptor_t x_desc_, y_desc_;
+  cudnnDropoutDescriptor_t drop_desc_ = nullptr;
+  cudnnTensorDescriptor_t x_desc_ = nullptr, y_desc_ = nullptr;
   size_t state_size_, reserve_size_;
   Tensor state_;
 };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/test/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index f362968..de64abd 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -6,5 +6,6 @@ AUX_SOURCE_DIRECTORY(singa singa_test_source)
 ADD_EXECUTABLE(test_singa "gtest/gtest_main.cc" ${singa_test_source})
 ADD_DEPENDENCIES(test_singa singa_core singa_utils)
 MESSAGE(STATUS "link libs" ${singa_linker_libs})
-TARGET_LINK_LIBRARIES(test_singa gtest singa_core singa_utils proto protobuf)
+TARGET_LINK_LIBRARIES(test_singa gtest singa_core singa_utils proto protobuf
+    ${SINGA_LINKER_LIBS})
 SET_TARGET_PROPERTIES(test_singa PROPERTIES LINK_FLAGS "${LINK_FLAGS} -pthread")

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/test/singa/test_cpp_device.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cpp_device.cc b/test/singa/test_cpp_device.cc
index d2c0149..c302206 100644
--- a/test/singa/test_cpp_device.cc
+++ b/test/singa/test_cpp_device.cc
@@ -34,7 +34,7 @@ TEST(CppDevice, MemoryMallocFree) {
   CppDevice dev(0, 1);
   Blob* b = dev.NewBlob(4);
   EXPECT_NE(nullptr, b);
-  EXPECT_EQ(4, b->size());
+  EXPECT_EQ(4u, b->size());
   dev.FreeBlob(b);
 }
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/test/singa/test_cudnn_dropout.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_dropout.cc b/test/singa/test_cudnn_dropout.cc
new file mode 100644
index 0000000..9913074
--- /dev/null
+++ b/test/singa/test_cudnn_dropout.cc
@@ -0,0 +1,127 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+#ifdef USE_CUDNN
+// cudnn dropout is added in cudnn 5
+//#if CUDNN_MAJOR_VERSION >= 5
+
+#include "../src/model/layer/cudnn_dropout.h"
+#include "gtest/gtest.h"
+
+bool inline GetBitValue(const char* x, int pos) {
+  const unsigned char BitMask[] = {1, 2, 4, 8, 16, 32, 64, 128};
+  int idx = pos / 8;
+  int offset = pos % 8;
+  return x[idx] & BitMask[offset];
+}
+
+using singa::CudnnDropout;
+TEST(CudnnDropout, Setup) {
+  CudnnDropout drop;
+  EXPECT_EQ("CudnnDropout", drop.layer_type());
+
+  singa::LayerConf conf;
+  singa::DropoutConf* dropconf = conf.mutable_dropout_conf();
+  dropconf->set_dropout_ratio(0.8);
+
+  drop.Setup(conf);
+  EXPECT_EQ(0.8f, drop.dropout_ratio());
+}
+
+TEST(CudnnDropout, Forward) {
+  const float x[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
+  size_t n = sizeof(x) / sizeof(float);
+  singa::CudaDevice cuda(0, 1);
+  singa::Tensor in(singa::Shape{n}, &cuda);
+  in.CopyDataFromHostPtr(x, n);
+
+  float pdrop = 0.5;
+  CudnnDropout drop;
+  singa::LayerConf conf;
+  singa::DropoutConf* dropconf = conf.mutable_dropout_conf();
+  dropconf->set_dropout_ratio(pdrop);
+  drop.Setup(conf);
+
+  singa::Tensor out1 = drop.Forward(singa::kTrain, in);
+
+  singa::Tensor mask(drop.mask().shape(), drop.mask().data_type());
+  mask.CopyData(drop.mask());
+  const char* mptr = mask.data<const char*>();
+  for (size_t i = 0; i < n; i++)
+    EXPECT_FLOAT_EQ(0, GetBitValue(mptr, i) * (GetBitValue(mptr, i) - 1));
+
+  singa::CppDevice host(0, 1);
+  out1.ToDevice(&host);
+  const float* outptr1 = out1.data<const float*>();
+  EXPECT_EQ(n, out1.Size());
+  float scale = 1.0f / (1.0f - pdrop);
+  // the output value should be 0 or the same as the input
+  EXPECT_EQ(0.f, outptr1[0] * (outptr1[0] - scale * x[0]));
+  EXPECT_EQ(0.f, outptr1[1] * (outptr1[1] - scale * x[1]));
+  EXPECT_EQ(0.f, outptr1[7] * (outptr1[7] - scale * x[7]));
+
+  singa::Tensor out2 = drop.Forward(singa::kEval, in);
+  out2.ToDevice(&host);
+  EXPECT_EQ(n, out2.Size());
+  const float* outptr2 = out2.data<const float*>();
+  // the output value should be the same as the input
+  EXPECT_EQ(x[0], outptr2[0]);
+  EXPECT_EQ(x[1], outptr2[1]);
+  EXPECT_EQ(x[7], outptr2[7]);
+}
+
+TEST(CudnnDropout, Backward) {
+  const float x[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
+  size_t n = sizeof(x) / sizeof(float);
+  singa::CudaDevice cuda(0, 1);
+  singa::Tensor in(singa::Shape{n}, &cuda);
+  in.CopyDataFromHostPtr(x, n);
+
+  float pdrop = 0.5;
+  float scale = 1.0f / (1.0f - pdrop);
+
+  CudnnDropout drop;
+  singa::LayerConf conf;
+  singa::DropoutConf* dropconf = conf.mutable_dropout_conf();
+  dropconf->set_dropout_ratio(pdrop);
+  drop.Setup(conf);
+  singa::Tensor out1 = drop.Forward(singa::kTrain, in);
+
+  const float dy[] = {4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 2.0f, 3.0f};
+  singa::Tensor grad(singa::Shape{n}, &cuda);
+  grad.CopyDataFromHostPtr(dy, n);
+
+  const auto ret = drop.Backward(singa::kTrain, grad);
+  singa::CppDevice host(0, 1);
+  singa::Tensor in_grad = ret.first;
+  in_grad.ToDevice(&host);
+  const float* dx = in_grad.data<const float*>();
+
+  singa::Tensor mask(drop.mask().shape(), drop.mask().data_type());
+  mask.CopyData(drop.mask());
+  const char* mptr = mask.data<const char*>();
+
+  EXPECT_FLOAT_EQ(dx[0], dy[0] * GetBitValue(mptr, 0) * scale);
+  EXPECT_FLOAT_EQ(dx[1], dy[1] * GetBitValue(mptr, 1) * scale);
+  EXPECT_FLOAT_EQ(dx[7], dy[7] * GetBitValue(mptr, 7) * scale);
+}
+//#endif  // CUDNN_VERSION_MAJOR>=5
+#endif  // USE_CUDNN
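
The Forward test above cannot know which elements cudnn will drop, so it
checks a product instead: every output must be either 0 or scale * x[i],
and exactly those two values make out * (out - scale * x[i]) vanish. The
same trick in isolation, with pdrop = 0.5 so scale = 2:

    #include <cassert>
    #include <initializer_list>

    int main() {
      const float x = 3.0f;
      const float scale = 2.0f;  // 1 / (1 - pdrop) with pdrop = 0.5
      for (float out : {0.0f, scale * x})  // the only two legal outputs
        assert(out * (out - scale * x) == 0.0f);
      return 0;
    }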
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/test/singa/test_dropout.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_dropout.cc b/test/singa/test_dropout.cc
index 3190ecd..d648ff8 100644
--- a/test/singa/test_dropout.cc
+++ b/test/singa/test_dropout.cc
@@ -23,7 +23,7 @@
 #include "gtest/gtest.h"

 using singa::Dropout;
-TEST(DropoutLayer, Setup) {
+TEST(Dropout, Setup) {
   Dropout drop;
   EXPECT_EQ("Dropout", drop.layer_type());

@@ -35,7 +35,7 @@ TEST(DropoutLayer, Setup) {
   EXPECT_EQ(0.8f, drop.dropout_ratio());
 }

-TEST(DropoutLayer, Forward) {
+TEST(Dropout, Forward) {
   const float x[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
   size_t n = sizeof(x) / sizeof(float);
   singa::Tensor in(singa::Shape{n});
@@ -51,11 +51,11 @@ TEST(DropoutLayer, Forward) {

   singa::Tensor out1 = drop.Forward(singa::kTrain, in);

-  const float* mptr = static_cast<const float*>(drop.mask().blob()->data());
+  const float* mptr = drop.mask().data<const float*>();
   for (size_t i = 0; i < n; i++)
     EXPECT_FLOAT_EQ(0, mptr[i] * (mptr[i] - scale));

-  const float* outptr1 = static_cast<const float*>(out1.blob()->data());
+  const float* outptr1 = out1.data<const float*>();
   EXPECT_EQ(n, out1.Size());
   // the output value should be 0 or the same as the input
   EXPECT_EQ(0.f, outptr1[0] * (outptr1[0] - scale * x[0]));
@@ -64,14 +64,14 @@ TEST(DropoutLayer, Forward) {

   singa::Tensor out2 = drop.Forward(singa::kEval, in);
   EXPECT_EQ(n, out2.Size());
-  const float* outptr2 = static_cast<const float*>(out2.blob()->data());
+  const float* outptr2 = out2.data<const float*>();
   // the output value should be the same as the input
   EXPECT_EQ(x[0], outptr2[0]);
   EXPECT_EQ(x[1], outptr2[1]);
   EXPECT_EQ(x[7], outptr2[7]);
 }

-TEST(DropoutLayer, Backward) {
+TEST(Dropout, Backward) {
   const float x[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
   size_t n = sizeof(x) / sizeof(float);
   singa::Tensor in(singa::Shape{n});
@@ -91,9 +91,9 @@ TEST(DropoutLayer, Backward) {
   singa::Tensor grad(singa::Shape{n});
   grad.CopyDataFromHostPtr(dy, n);

-  const float* mptr = static_cast<const float*>(drop.mask().blob()->data());
+  const float* mptr = drop.mask().data<const float*>();
   const auto ret = drop.Backward(singa::kTrain, grad);
-  const float* dx = static_cast<const float*>(ret.first.blob()->data());
+  const float* dx = ret.first.data<const float*>();
   EXPECT_FLOAT_EQ(dx[0], dy[0] * (mptr[0] > 0 ? 1.0f : 0.0f) * scale);
   EXPECT_FLOAT_EQ(dx[1], dy[1] * (mptr[1] > 0) * scale);
   EXPECT_FLOAT_EQ(dx[7], dy[7] * (mptr[7] > 0) * scale);
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0b4b2e20/test/singa/test_tensor_math.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc
index ccd91a0..eee18ec 100644
--- a/test/singa/test_tensor_math.cc
+++ b/test/singa/test_tensor_math.cc
@@ -23,7 +23,7 @@ class TestTensorMath : public ::testing::Test {
 TEST_F(TestTensorMath, MemberAddTensor) {
   Tensor aa = a.Clone();
   aa += a;
-  const float* dptr = aa.data<float>();
+  const float* dptr = aa.data<const float*>();
   EXPECT_FLOAT_EQ(2.0f, dptr[0]);
   EXPECT_FLOAT_EQ(4.0f, dptr[1]);
   EXPECT_FLOAT_EQ(6.0f, dptr[2]);
@@ -31,13 +31,13 @@ TEST_F(TestTensorMath, MemberAddTensor) {
   // check p is initialized to 0
   Tensor p(Shape{6});
   p += aa;
-  const float* dptr1 = p.data<float>();
+  const float* dptr1 = p.data<const float*>();
   EXPECT_FLOAT_EQ(2.0f, dptr1[0]);
   EXPECT_FLOAT_EQ(4.0f, dptr1[1]);
   EXPECT_FLOAT_EQ(6.0f, dptr1[2]);

   a += b;
-  const float* dptr2 = a.data<float>();
+  const float* dptr2 = a.data<const float*>();
   EXPECT_FLOAT_EQ(2.1f, dptr2[0]);
   EXPECT_FLOAT_EQ(4.1f, dptr2[1]);
   EXPECT_FLOAT_EQ(6.1f, dptr2[2]);
@@ -48,21 +48,21 @@ TEST_F(TestTensorMath, MemberAddTensor) {
 TEST_F(TestTensorMath, AddTensors) {
   Tensor ret(a.shape(), a.device(), a.data_type());
   Add(a, b, &ret);
-  const float* dptr = ret.data<float>();
+  const float* dptr = ret.data<const float*>();
   EXPECT_FLOAT_EQ(2.1f, dptr[0]);
   EXPECT_FLOAT_EQ(4.1f, dptr[1]);
   EXPECT_FLOAT_EQ(6.1f, dptr[2]);
   EXPECT_FLOAT_EQ(12.1f, dptr[5]);

   const Tensor d = a + b;
-  const float* dptr2 = d.data<float>();
+  const float* dptr2 = d.data<const float*>();
   EXPECT_FLOAT_EQ(2.1f, dptr2[0]);
   EXPECT_FLOAT_EQ(4.1f, dptr2[1]);
   EXPECT_FLOAT_EQ(6.1f, dptr2[2]);
   EXPECT_FLOAT_EQ(12.1f, dptr2[5]);

   Add(a, b, &a);
-  const float* dptr1 = a.data<float>();
+  const float* dptr1 = a.data<const float*>();
   EXPECT_FLOAT_EQ(2.1f, dptr1[0]);
   EXPECT_FLOAT_EQ(4.1f, dptr1[1]);
   EXPECT_FLOAT_EQ(6.1f, dptr1[2]);
