zackcquic commented on a change in pull request #7952: URL: https://github.com/apache/tvm/pull/7952#discussion_r634817313
########## File path: tests/python/relay/test_pass_manager.py ########## @@ -536,10 +537,12 @@ def test_print_ir(capfd): __TRACE_COUNTER__ = 0 -def _tracer(module, info, is_before): - global __TRACE_COUNTER__ - if bool(is_before): +@tvm.instrument.pass_instrument +class MyInstrument: + def run_before_pass(self, module, info): + global __TRACE_COUNTER__ Review comment: Done ########## File path: src/ir/transform.cc ########## @@ -162,170 +164,64 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); } -void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { +void PassContext::InstrumentSetUp() const { auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->trace_func != nullptr) { - pass_ctx_node->trace_func(module, info, is_before); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->SetUp(); + } } } -class ModulePass; - -/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */ -struct PassProfile { - // TODO(@altanh): expose PassProfile through TVM Object API - using Clock = std::chrono::steady_clock; - using Duration = std::chrono::duration<double, std::micro>; - using Time = std::chrono::time_point<Clock>; - - /*! \brief The name of the pass being profiled. */ - String name; - /*! \brief The time when the pass was entered. */ - Time start; - /*! \brief The time when the pass completed. */ - Time end; - /*! \brief The total duration of the pass, i.e. end - start. */ - Duration duration; - /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ - std::vector<PassProfile> children; - - explicit PassProfile(String name) - : name(name), start(Clock::now()), end(Clock::now()), children() {} - - /*! \brief Gets the PassProfile of the currently executing pass. 
*/ - static PassProfile* Current(); - /*! \brief Pushes a new PassProfile with the given pass name. */ - static void EnterPass(String name); - /*! \brief Pops the current PassProfile. */ - static void ExitPass(); -}; - -struct PassProfileThreadLocalEntry { - /*! \brief The placeholder top-level PassProfile. */ - PassProfile root; - /*! \brief The stack of PassProfiles for nested passes currently running. */ - std::stack<PassProfile*> profile_stack; - /*! \brief Whether or not pass profiling is active. */ - bool active; - - PassProfileThreadLocalEntry() : root("root"), active(false) {} -}; - -/*! \brief Thread local store to hold the pass profiling data. */ -typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry> PassProfileThreadLocalStore; - -void PassProfile::EnterPass(String name) { - if (!PassProfileThreadLocalStore::Get()->active) return; - PassProfile* cur = PassProfile::Current(); - cur->children.emplace_back(name); - PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); +void PassContext::InstrumentTearDown() const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->TearDown(); + } + } } -void PassProfile::ExitPass() { - if (!PassProfileThreadLocalStore::Get()->active) return; - PassProfile* cur = PassProfile::Current(); - ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling"; - cur->end = std::move(PassProfile::Clock::now()); - cur->duration = std::chrono::duration_cast<PassProfile::Duration>(cur->end - cur->start); - PassProfileThreadLocalStore::Get()->profile_stack.pop(); +bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + if (!pi->RunBeforePass(ir_module, pass_info)) { + return false; + 
} + } + return true; + } + return true; } -PassProfile* PassProfile::Current() { - PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); - if (!entry->profile_stack.empty()) { - return entry->profile_stack.top(); - } else { - return &entry->root; +void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->RunAfterPass(ir_module, pass_info); + } } } IRModule Pass::operator()(IRModule mod) const { const PassNode* node = operator->(); ICHECK(node != nullptr); - PassProfile::EnterPass(node->Info()->name); + // PassProfile::EnterPass(node->Info()->name); auto ret = node->operator()(std::move(mod)); - PassProfile::ExitPass(); + // PassProfile::ExitPass(); Review comment: Done ########## File path: src/ir/transform.cc ########## @@ -162,170 +164,64 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); } -void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { +void PassContext::InstrumentSetUp() const { auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->trace_func != nullptr) { - pass_ctx_node->trace_func(module, info, is_before); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->SetUp(); + } } } -class ModulePass; - -/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */ -struct PassProfile { - // TODO(@altanh): expose PassProfile through TVM Object API - using Clock = std::chrono::steady_clock; - using Duration = std::chrono::duration<double, std::micro>; - using Time = std::chrono::time_point<Clock>; - - /*! \brief The name of the pass being profiled. */ - String name; - /*! 
\brief The time when the pass was entered. */ - Time start; - /*! \brief The time when the pass completed. */ - Time end; - /*! \brief The total duration of the pass, i.e. end - start. */ - Duration duration; - /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ - std::vector<PassProfile> children; - - explicit PassProfile(String name) - : name(name), start(Clock::now()), end(Clock::now()), children() {} - - /*! \brief Gets the PassProfile of the currently executing pass. */ - static PassProfile* Current(); - /*! \brief Pushes a new PassProfile with the given pass name. */ - static void EnterPass(String name); - /*! \brief Pops the current PassProfile. */ - static void ExitPass(); -}; - -struct PassProfileThreadLocalEntry { - /*! \brief The placeholder top-level PassProfile. */ - PassProfile root; - /*! \brief The stack of PassProfiles for nested passes currently running. */ - std::stack<PassProfile*> profile_stack; - /*! \brief Whether or not pass profiling is active. */ - bool active; - - PassProfileThreadLocalEntry() : root("root"), active(false) {} -}; - -/*! \brief Thread local store to hold the pass profiling data. 
*/ -typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry> PassProfileThreadLocalStore; - -void PassProfile::EnterPass(String name) { - if (!PassProfileThreadLocalStore::Get()->active) return; - PassProfile* cur = PassProfile::Current(); - cur->children.emplace_back(name); - PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); +void PassContext::InstrumentTearDown() const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->TearDown(); + } + } } -void PassProfile::ExitPass() { - if (!PassProfileThreadLocalStore::Get()->active) return; - PassProfile* cur = PassProfile::Current(); - ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling"; - cur->end = std::move(PassProfile::Clock::now()); - cur->duration = std::chrono::duration_cast<PassProfile::Duration>(cur->end - cur->start); - PassProfileThreadLocalStore::Get()->profile_stack.pop(); +bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + if (!pi->RunBeforePass(ir_module, pass_info)) { + return false; + } + } + return true; + } + return true; } -PassProfile* PassProfile::Current() { - PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); - if (!entry->profile_stack.empty()) { - return entry->profile_stack.top(); - } else { - return &entry->root; +void PassContext::InstrumentAfterPass(const IRModule& ir_module, const PassInfo& pass_info) const { + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->instruments.defined()) { + for (instrument::PassInstrument pi : pass_ctx_node->instruments) { + pi->RunAfterPass(ir_module, pass_info); + } } } IRModule Pass::operator()(IRModule mod) const { const PassNode* node = operator->(); 
 ICHECK(node != nullptr); - PassProfile::EnterPass(node->Info()->name); + // PassProfile::EnterPass(node->Info()->name); Review comment: Done -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
