This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 57c8ca1 Fix engine stop/start (#10911) 57c8ca1 is described below commit 57c8ca1a0a6dae36dc27a9f054041ecce652e4c8 Author: Joshua Z. Zhang <cheungc...@gmail.com> AuthorDate: Tue May 15 11:22:39 2018 -0700 Fix engine stop/start (#10911) * fix engine start/stop * add tests * fix test * fix * fix tests --- python/mxnet/gluon/data/dataloader.py | 2 +- src/engine/naive_engine.cc | 6 +++ src/engine/threaded_engine_pooled.cc | 57 +++++++++++++++++++++-------- tests/cpp/engine/threaded_engine_test.cc | 17 +++++++++ tests/python/unittest/test_engine_import.py | 44 ++++++++++++++++++++++ 5 files changed, 109 insertions(+), 17 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 7ef18bd..d80a6bf 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -143,7 +143,7 @@ class _MultiWorkerIter(object): self._batchify_fn = batchify_fn self._batch_sampler = batch_sampler self._key_queue = Queue() - self._data_queue = SimpleQueue() + self._data_queue = Queue() if sys.version_info[0] <= 2 else SimpleQueue() self._data_buffer = {} self._rcvd_idx = 0 self._sent_idx = 0 diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index 1fa5306..8196af2 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -63,6 +63,12 @@ class NaiveEngine final : public Engine { #endif } + void Stop() override { + } + + void Start() override { + } + // new variables VarHandle NewVariable() override { size_t v = ++counter_; diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc index 074ea4e..574e832 100644 --- a/src/engine/threaded_engine_pooled.cc +++ b/src/engine/threaded_engine_pooled.cc @@ -27,6 +27,7 @@ #include <dmlc/logging.h> #include <dmlc/concurrency.h> #include <cassert> +#include <utility> #include "./threaded_engine.h" #include "./thread_pool.h" #include "./stream_manager.h" @@ -42,14 +43,38 @@ namespace engine { */ class ThreadedEnginePooled : public ThreadedEngine { public: - ThreadedEnginePooled() : - thread_pool_(kNumWorkingThreads, [this]() { ThreadWorker(&task_queue_); }), - io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {} + ThreadedEnginePooled() { + this->Start(); + } ~ThreadedEnginePooled() noexcept(false) { - streams_.Finalize(); - task_queue_.SignalForKill(); - io_task_queue_.SignalForKill(); + StopNoWait(); + } + + void StopNoWait() { + streams_->Finalize(); + task_queue_->SignalForKill(); + io_task_queue_->SignalForKill(); + task_queue_ = nullptr; + io_task_queue_ = nullptr; + thread_pool_ = nullptr; + io_thread_pool_ = nullptr; + streams_ = nullptr; + } + + void Stop() override { + WaitForAll(); + StopNoWait(); + } + + void Start() override { + streams_.reset(new StreamManager<kMaxNumGpus, kNumStreamsPerGpu>()); + task_queue_.reset(new dmlc::ConcurrentBlockingQueue<OprBlock*>()); + io_task_queue_.reset(new dmlc::ConcurrentBlockingQueue<OprBlock*>()); + thread_pool_.reset(new ThreadPool(kNumWorkingThreads, [this]() { + ThreadWorker(task_queue_); })); + io_thread_pool_.reset(new ThreadPool(1, [this]() { + ThreadWorker(io_task_queue_); })); } protected: @@ -71,24 +96,24 @@ class ThreadedEnginePooled : public ThreadedEngine { /*! * \brief Streams. */ - StreamManager<kMaxNumGpus, kNumStreamsPerGpu> streams_; + std::unique_ptr<StreamManager<kMaxNumGpus, kNumStreamsPerGpu>> streams_; /*! * \brief Task queues. */ - dmlc::ConcurrentBlockingQueue<OprBlock*> task_queue_; - dmlc::ConcurrentBlockingQueue<OprBlock*> io_task_queue_; + std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> task_queue_; + std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> io_task_queue_; /*! * \brief Thread pools. */ - ThreadPool thread_pool_; - ThreadPool io_thread_pool_; + std::unique_ptr<ThreadPool> thread_pool_; + std::unique_ptr<ThreadPool> io_thread_pool_; /*! * \brief Worker. * \param task_queue Queue to work on. * * The method to pass to thread pool to parallelize. */ - void ThreadWorker(dmlc::ConcurrentBlockingQueue<OprBlock*>* task_queue) { + void ThreadWorker(std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> task_queue) { OprBlock* opr_block; while (task_queue->Pop(&opr_block)) { DoExecute(opr_block); @@ -110,8 +135,8 @@ class ThreadedEnginePooled : public ThreadedEngine { bool is_copy = (opr_block->opr->prop == FnProperty::kCopyFromGPU || opr_block->opr->prop == FnProperty::kCopyToGPU); auto&& rctx = is_copy - ? streams_.GetIORunContext(opr_block->ctx) - : streams_.GetRunContext(opr_block->ctx); + ? streams_->GetIORunContext(opr_block->ctx) + : streams_->GetRunContext(opr_block->ctx); this->ExecuteOprBlock(rctx, opr_block); } /*! @@ -122,11 +147,11 @@ class ThreadedEnginePooled : public ThreadedEngine { switch (opr_block->opr->prop) { case FnProperty::kCopyFromGPU: case FnProperty::kCopyToGPU: { - io_task_queue_.Push(opr_block); + io_task_queue_->Push(opr_block); break; } default: { - task_queue_.Push(opr_block); + task_queue_->Push(opr_block); break; } } diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc index 945c083..92d0958 100644 --- a/tests/cpp/engine/threaded_engine_test.cc +++ b/tests/cpp/engine/threaded_engine_test.cc @@ -121,6 +121,23 @@ double EvaluateWorloads(const std::vector<Workload>& workloads, return dmlc::GetTime() - t; } +TEST(Engine, start_stop) { + const int num_engine = 3; + std::vector<mxnet::Engine*> engine(num_engine); + engine[0] = mxnet::engine::CreateNaiveEngine(); + engine[1] = mxnet::engine::CreateThreadedEnginePooled(); + engine[2] = mxnet::engine::CreateThreadedEnginePerDevice(); + std::string type_names[3] = {"NaiveEngine", "ThreadedEnginePooled", "ThreadedEnginePerDevice"}; + + for (int i = 0; i < num_engine; ++i) { + LOG(INFO) << "Stopping: " << type_names[i]; + engine[i]->Stop(); + LOG(INFO) << "Stopped: " << type_names[i] << " Starting..."; + engine[i]->Start(); + LOG(INFO) << "Started: " << type_names[i] << " Done..."; + } +} + TEST(Engine, RandSumExpr) { std::vector<Workload> workloads; int num_repeat = 5; diff --git a/tests/python/unittest/test_engine_import.py b/tests/python/unittest/test_engine_import.py new file mode 100644 index 0000000..bd34eff --- /dev/null +++ b/tests/python/unittest/test_engine_import.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys + +def test_engine_import(): + import mxnet + def test_import(): + version = sys.version_info + if version >= (3, 4): + import importlib + importlib.reload(mxnet) + elif version >= (3, ): + import imp + imp.reload(mxnet) + else: + reload(mxnet) + engine_types = ['', 'NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice'] + + for type in engine_types: + if not type: + os.environ.pop('MXNET_ENGINE_TYPE', None) + else: + os.environ['MXNET_ENGINE_TYPE'] = type + test_import() + +if __name__ == '__main__': + import nose + nose.runmodule() -- To stop receiving notification emails like this one, please contact j...@apache.org.