cjolivier01 closed pull request #9300: bugfix for parallel rand generator on
multi-gpu
URL: https://github.com/apache/incubator-mxnet/pull/9300
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/include/mxnet/resource.h b/include/mxnet/resource.h
index 773baf04c1..385573259f 100644
--- a/include/mxnet/resource.h
+++ b/include/mxnet/resource.h
@@ -97,7 +97,7 @@ struct Resource {
* \brief Get parallel random number generator.
* \tparam xpu the device type of random number generator.
* \tparam DType the return type.
- * \return the native random number generator. for gpu, it is allocated on
global memory.
+ * \return the parallel random number generator. for gpu, it is allocated on
global memory.
*/
template<typename xpu, typename DType>
inline common::random::RandGenerator<xpu, DType>* get_parallel_random()
const {
diff --git a/src/resource.cc b/src/resource.cc
index e195006c36..c2b260985a 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -90,26 +90,26 @@ class ResourceManagerImpl : public ResourceManager {
: global_seed_(0) {
cpu_temp_space_copy_ = dmlc::GetEnv("MXNET_CPU_TEMP_COPY", 4);
gpu_temp_space_copy_ = dmlc::GetEnv("MXNET_GPU_TEMP_COPY", 1);
- cpu_native_rand_copy_ = dmlc::GetEnv("MXNET_CPU_NATIVE_RAND_COPY", 1);
- gpu_native_rand_copy_ = dmlc::GetEnv("MXNET_GPU_NATIVE_RAND_COPY", 4);
+ cpu_native_rand_copy_ = dmlc::GetEnv("MXNET_CPU_PARALLEL_RAND_COPY", 1);
+ gpu_native_rand_copy_ = dmlc::GetEnv("MXNET_GPU_PARALLEL_RAND_COPY", 4);
engine_ref_ = Engine::_GetSharedRef();
storage_ref_ = Storage::_GetSharedRef();
cpu_rand_.reset(new ResourceRandom<cpu>(
Context::CPU(), global_seed_));
cpu_space_.reset(new ResourceTempSpace(
Context::CPU(), cpu_temp_space_copy_));
- cpu_native_rand_.reset(new ResourceNativeRandom<cpu>(
+ cpu_parallel_rand_.reset(new ResourceParallelRandom<cpu>(
Context::CPU(), cpu_native_rand_copy_, global_seed_));
}
~ResourceManagerImpl() {
// need explicit delete, before engine get killed
cpu_rand_.reset(nullptr);
cpu_space_.reset(nullptr);
- cpu_native_rand_.reset(nullptr);
+ cpu_parallel_rand_.reset(nullptr);
#if MXNET_USE_CUDA
gpu_rand_.Clear();
gpu_space_.Clear();
- gpu_native_rand_.Clear();
+ gpu_parallel_rand_.Clear();
#endif
if (engine_ref_ != nullptr) {
engine_ref_ = nullptr;
@@ -125,7 +125,7 @@ class ResourceManagerImpl : public ResourceManager {
switch (req.type) {
case ResourceRequest::kRandom: return cpu_rand_->resource;
case ResourceRequest::kTempSpace: return cpu_space_->GetNext();
- case ResourceRequest::kParallelRandom: return
cpu_native_rand_->GetNext();
+ case ResourceRequest::kParallelRandom: return
cpu_parallel_rand_->GetNext();
default: LOG(FATAL) << "Unknown supported type " << req.type;
}
} else {
@@ -143,8 +143,8 @@ class ResourceManagerImpl : public ResourceManager {
})->GetNext();
}
case ResourceRequest::kParallelRandom: {
- return gpu_native_rand_.Get(ctx.dev_id, [ctx, this]() {
- return new ResourceNativeRandom<gpu>(ctx, gpu_native_rand_copy_,
global_seed_);
+ return gpu_parallel_rand_.Get(ctx.dev_id, [ctx, this]() {
+ return new ResourceParallelRandom<gpu>(ctx, gpu_native_rand_copy_,
global_seed_);
})->GetNext();
}
default: LOG(FATAL) << "Unknown supported type " << req.type;
@@ -160,12 +160,12 @@ class ResourceManagerImpl : public ResourceManager {
void SeedRandom(uint32_t seed) override {
global_seed_ = seed;
cpu_rand_->Seed(global_seed_);
- cpu_native_rand_->Seed(global_seed_);
+ cpu_parallel_rand_->Seed(global_seed_);
#if MXNET_USE_CUDA
gpu_rand_.ForEach([seed](size_t i, ResourceRandom<gpu> *p) {
p->Seed(seed);
});
- gpu_native_rand_.ForEach([seed](size_t i, ResourceNativeRandom<gpu> *p) {
+ gpu_parallel_rand_.ForEach([seed](size_t i, ResourceParallelRandom<gpu>
*p) {
p->Seed(seed);
});
#endif
@@ -260,9 +260,10 @@ class ResourceManagerImpl : public ResourceManager {
}
};
- // the native random sampler resources
+ // the parallel random sampler resources
+ // it use device API for GPU
template<typename xpu>
- struct ResourceNativeRandom {
+ struct ResourceParallelRandom {
/*! \brief the context of the PRNG */
Context ctx;
/*! \brief pointers to sampler */
@@ -272,24 +273,24 @@ class ResourceManagerImpl : public ResourceManager {
/*! \brief current pointer to the round roubin allocator */
std::atomic<size_t> curr_ptr;
/*! \brief constructor */
- explicit ResourceNativeRandom(Context ctx, size_t ncopy, uint32_t
global_seed)
+ explicit ResourceParallelRandom(Context ctx, size_t ncopy, uint32_t
global_seed)
: ctx(ctx), sampler(ncopy), resource(ncopy), curr_ptr(0) {
for (size_t i = 0; i < sampler.size(); ++i) {
const uint32_t seed = ctx.dev_id + i * kMaxNumGPUs + global_seed *
kRandMagic;
resource[i].var = Engine::Get()->NewVariable();
common::random::RandGenerator<xpu> *r = new
common::random::RandGenerator<xpu>();
- common::random::RandGenerator<xpu>::AllocState(r);
Engine::Get()->PushSync(
[r, seed](RunContext rctx) {
+ common::random::RandGenerator<xpu>::AllocState(r);
r->Seed(rctx.get_stream<xpu>(), seed);
}, ctx, {}, {resource[i].var},
- FnProperty::kNormal, 0,
PROFILER_MESSAGE("ResourceNativeRandomSetSeed"));
+ FnProperty::kNormal, 0,
PROFILER_MESSAGE("ResourceParallelRandomSetSeed"));
sampler[i] = r;
resource[i].ptr_ = sampler[i];
resource[i].req = ResourceRequest(ResourceRequest::kParallelRandom);
}
}
- ~ResourceNativeRandom() {
+ ~ResourceParallelRandom() {
for (size_t i = 0; i < sampler.size(); ++i) {
common::random::RandGenerator<xpu> *r = sampler[i];
Engine::Get()->DeleteVariable(
@@ -345,15 +346,15 @@ class ResourceManagerImpl : public ResourceManager {
std::unique_ptr<ResourceRandom<cpu> > cpu_rand_;
/*! \brief CPU temp space resources */
std::unique_ptr<ResourceTempSpace> cpu_space_;
- /*! \brief CPU native random number resources */
- std::unique_ptr<ResourceNativeRandom<cpu> > cpu_native_rand_;
+ /*! \brief CPU parallel random number resources */
+ std::unique_ptr<ResourceParallelRandom<cpu> > cpu_parallel_rand_;
#if MXNET_USE_CUDA
/*! \brief random number generator for GPU */
common::LazyAllocArray<ResourceRandom<gpu> > gpu_rand_;
/*! \brief temp space for GPU */
common::LazyAllocArray<ResourceTempSpace> gpu_space_;
- /*! \brief GPU native (on device) random number resources */
- common::LazyAllocArray<ResourceNativeRandom<gpu> > gpu_native_rand_;
+ /*! \brief GPU parallel (on device) random number resources */
+ common::LazyAllocArray<ResourceParallelRandom<gpu> > gpu_parallel_rand_;
#endif
};
} // namespace resource
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services