altanh commented on a change in pull request #7952:
URL: https://github.com/apache/tvm/pull/7952#discussion_r634812967
##########
File path: src/ir/transform.cc
##########
@@ -162,170 +164,64 @@ void PassContext::RegisterConfigOption(const char* key,
uint32_t value_type_inde
PassContext PassContext::Create() { return
PassContext(make_object<PassContextNode>()); }
-void PassContext::Trace(const IRModule& module, const PassInfo& info, bool
is_before) const {
+void PassContext::InstrumentSetUp() const {
auto pass_ctx_node = this->operator->();
- if (pass_ctx_node->trace_func != nullptr) {
- pass_ctx_node->trace_func(module, info, is_before);
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ pi->SetUp();
+ }
}
}
-class ModulePass;
-
-/*! \brief PassProfile stores profiling information for a given pass and its
sub-passes. */
-struct PassProfile {
- // TODO(@altanh): expose PassProfile through TVM Object API
- using Clock = std::chrono::steady_clock;
- using Duration = std::chrono::duration<double, std::micro>;
- using Time = std::chrono::time_point<Clock>;
-
- /*! \brief The name of the pass being profiled. */
- String name;
- /*! \brief The time when the pass was entered. */
- Time start;
- /*! \brief The time when the pass completed. */
- Time end;
- /*! \brief The total duration of the pass, i.e. end - start. */
- Duration duration;
- /*! \brief PassProfiles for all sub-passes invoked during the execution of
the pass. */
- std::vector<PassProfile> children;
-
- explicit PassProfile(String name)
- : name(name), start(Clock::now()), end(Clock::now()), children() {}
-
- /*! \brief Gets the PassProfile of the currently executing pass. */
- static PassProfile* Current();
- /*! \brief Pushes a new PassProfile with the given pass name. */
- static void EnterPass(String name);
- /*! \brief Pops the current PassProfile. */
- static void ExitPass();
-};
-
-struct PassProfileThreadLocalEntry {
- /*! \brief The placeholder top-level PassProfile. */
- PassProfile root;
- /*! \brief The stack of PassProfiles for nested passes currently running. */
- std::stack<PassProfile*> profile_stack;
- /*! \brief Whether or not pass profiling is active. */
- bool active;
-
- PassProfileThreadLocalEntry() : root("root"), active(false) {}
-};
-
-/*! \brief Thread local store to hold the pass profiling data. */
-typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry>
PassProfileThreadLocalStore;
-
-void PassProfile::EnterPass(String name) {
- if (!PassProfileThreadLocalStore::Get()->active) return;
- PassProfile* cur = PassProfile::Current();
- cur->children.emplace_back(name);
-
PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back());
+void PassContext::InstrumentTearDown() const {
+ auto pass_ctx_node = this->operator->();
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
Review comment:
Is there ever a use case where different pass instruments could depend
on each other?
##########
File path: include/tvm/ir/transform.h
##########
@@ -189,12 +182,32 @@ class PassContext : public ObjectRef {
TVM_DLL static PassContext Current();
/*!
- * \brief Apply the tracing functions of the context to the module, with the
info.
- * \param module The IRModule to trace.
+ * \brief Set up for all the instrument implementations.
+ */
+ TVM_DLL void InstrumentSetUp() const;
+
+ /*!
+ * \brief Clean up for all the instrument implementations.
+ */
+ TVM_DLL void InstrumentTearDown() const;
+
+ /*!
+ * \brief Call intrument implementatations before a pass run.
Review comment:
```suggestion
* \brief Call instrument implementations before a pass run.
```
##########
File path: python/tvm/ir/transform.py
##########
@@ -80,9 +88,13 @@ def __init__(
if not isinstance(disabled, (list, tuple)):
raise TypeError("disabled_pass is expected to be the type of " +
"list/tuple/set.")
+ instruments = list(instruments) if instruments else []
+ if not isinstance(instruments, (list, tuple)):
+ raise TypeError("disabled_pass is expected to be the type of " +
"list/tuple/set.")
Review comment:
Also correct the error message: it should refer to `instruments`, not `disabled_pass`.
##########
File path: src/ir/transform.cc
##########
@@ -162,170 +164,64 @@ void PassContext::RegisterConfigOption(const char* key,
uint32_t value_type_inde
PassContext PassContext::Create() { return
PassContext(make_object<PassContextNode>()); }
-void PassContext::Trace(const IRModule& module, const PassInfo& info, bool
is_before) const {
+void PassContext::InstrumentSetUp() const {
auto pass_ctx_node = this->operator->();
- if (pass_ctx_node->trace_func != nullptr) {
- pass_ctx_node->trace_func(module, info, is_before);
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ pi->SetUp();
+ }
}
}
-class ModulePass;
-
-/*! \brief PassProfile stores profiling information for a given pass and its
sub-passes. */
-struct PassProfile {
- // TODO(@altanh): expose PassProfile through TVM Object API
- using Clock = std::chrono::steady_clock;
- using Duration = std::chrono::duration<double, std::micro>;
- using Time = std::chrono::time_point<Clock>;
-
- /*! \brief The name of the pass being profiled. */
- String name;
- /*! \brief The time when the pass was entered. */
- Time start;
- /*! \brief The time when the pass completed. */
- Time end;
- /*! \brief The total duration of the pass, i.e. end - start. */
- Duration duration;
- /*! \brief PassProfiles for all sub-passes invoked during the execution of
the pass. */
- std::vector<PassProfile> children;
-
- explicit PassProfile(String name)
- : name(name), start(Clock::now()), end(Clock::now()), children() {}
-
- /*! \brief Gets the PassProfile of the currently executing pass. */
- static PassProfile* Current();
- /*! \brief Pushes a new PassProfile with the given pass name. */
- static void EnterPass(String name);
- /*! \brief Pops the current PassProfile. */
- static void ExitPass();
-};
-
-struct PassProfileThreadLocalEntry {
- /*! \brief The placeholder top-level PassProfile. */
- PassProfile root;
- /*! \brief The stack of PassProfiles for nested passes currently running. */
- std::stack<PassProfile*> profile_stack;
- /*! \brief Whether or not pass profiling is active. */
- bool active;
-
- PassProfileThreadLocalEntry() : root("root"), active(false) {}
-};
-
-/*! \brief Thread local store to hold the pass profiling data. */
-typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry>
PassProfileThreadLocalStore;
-
-void PassProfile::EnterPass(String name) {
- if (!PassProfileThreadLocalStore::Get()->active) return;
- PassProfile* cur = PassProfile::Current();
- cur->children.emplace_back(name);
-
PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back());
+void PassContext::InstrumentTearDown() const {
+ auto pass_ctx_node = this->operator->();
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ pi->TearDown();
+ }
+ }
}
-void PassProfile::ExitPass() {
- if (!PassProfileThreadLocalStore::Get()->active) return;
- PassProfile* cur = PassProfile::Current();
- ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling";
- cur->end = std::move(PassProfile::Clock::now());
- cur->duration = std::chrono::duration_cast<PassProfile::Duration>(cur->end -
cur->start);
- PassProfileThreadLocalStore::Get()->profile_stack.pop();
+bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const
PassInfo& pass_info) const {
+ auto pass_ctx_node = this->operator->();
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ if (!pi->RunBeforePass(ir_module, pass_info)) {
+ return false;
+ }
+ }
+ return true;
Review comment:
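Redundant return: this `return true` inside the `if` block duplicates the `return true` that immediately follows the block.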
##########
File path: src/tir/ir/transform.cc
##########
@@ -87,9 +87,11 @@ PrimFuncPass::PrimFuncPass(
// Perform Module -> Module optimizations at the PrimFunc level.
IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext&
pass_ctx) const {
- const PassInfo& pass_info = Info();
Review comment:
what's happening here?
##########
File path: python/tvm/ir/instrument.py
##########
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-argument
+"""Common pass instrumentation across IR variants."""
+import inspect
+import functools
+
+import tvm._ffi
+import tvm.runtime
+
+from . import _ffi_instrument_api
+
+
+@tvm._ffi.register_object("instrument.PassInstrument")
+class PassInstrument(tvm.runtime.Object):
+ """A pass instrument implementation.
+
+ Users don't need to interact with this class directly.
+ Instead, a `PassInstrument` instance should be created through
`pass_instrument`.
+
+ See Also
+ --------
+ `pass_instrument`
+ """
+
+
+def _wrap_class_pass_instrument(pi_cls):
+ """Wrap a python class as pass instrument"""
+
+ class PyPassInstrument(PassInstrument):
+ """Internal wrapper class to create a class instance."""
+
+ def __init__(self, *args, **kwargs):
+ # initialize handle in cass pi_cls creation failed.fg
+ self.handle = None
+ inst = pi_cls(*args, **kwargs)
+
+ # check method declartion within class, if found, wrap it.
+ def create_method(method):
+ if hasattr(inst, method) and inspect.ismethod(getattr(inst,
method)):
+
+ def func(*args):
+ return getattr(inst, method)(*args)
+
+ func.__name__ = "_" + method
+ return func
+ return None
+
+ # create runtime pass instrument object
+ # reister instance's run_before_pass, run_after_pass, set_up and
tear_down method
+ # to it if present.
+ self.__init_handle_by_constructor__(
+ _ffi_instrument_api.NamedPassInstrument,
+ pi_cls.__name__,
+ create_method("run_before_pass"),
+ create_method("run_after_pass"),
+ create_method("set_up"),
+ create_method("tear_down"),
+ )
+
+ self._inst = inst
+
+ def __getattr__(self, name):
+ # fall back to instance attribute if there is not any
+ return self._inst.__getattribute__(name)
+
+ functools.update_wrapper(PyPassInstrument.__init__, pi_cls.__init__)
+ PyPassInstrument.__name__ = pi_cls.__name__
+ PyPassInstrument.__doc__ = pi_cls.__doc__
+ PyPassInstrument.__module__ = pi_cls.__module__
+ return PyPassInstrument
+
+
+def pass_instrument(pi_cls=None):
+ """Decorate a pass instrument.
+
+ Parameters
+ ----------
+ pi_class :
+
+ Examples
+ --------
+ The following code block decorates a pass instrument class.
+
+ .. code-block:: python
+ @tvm.instrument.pass_instrument
+ class SkipPass:
+ def __init__(self, skip_pass_name):
+ self.skip_pass_name = skip_pass_name
+
+ # Uncomment to customize
+ # def set_up(self):
+ # pass
+
+ # Uncomment to customize
+ # def tear_down(self):
+ # pass
+
+ # If pass name contains keyword, skip it by return False. (return
True: not skip)
+ def run_before_pass(self, mod, pass_info):
+ if self.skip_pass_name in pass_info.name:
+ return False
+ return True
+
+ # Uncomment to customize
+ # def run_after_pass(self, mod, pass_info):
+ # pass
+
+ skip_annotate = SkipPass("AnnotateSpans")
+ with tvm.transform.PassContext(instruments=[skip_annotate]):
+ tvm.relay.build(mod, "llvm")
+ """
+
+ def create_pass_instrument(pi_cls):
+ if not inspect.isclass(pi_cls):
+ raise TypeError("pi_cls must be a class")
+
+ return _wrap_class_pass_instrument(pi_cls)
+
+ if pi_cls:
+ return create_pass_instrument(pi_cls)
+ return create_pass_instrument
+
+
+@tvm._ffi.register_object("instrument.PassInstrument")
+class PassesTimeInstrument(tvm.runtime.Object):
Review comment:
could we rename this to something like `PassTimingInstrument`?
##########
File path: src/ir/transform.cc
##########
@@ -162,170 +164,64 @@ void PassContext::RegisterConfigOption(const char* key,
uint32_t value_type_inde
PassContext PassContext::Create() { return
PassContext(make_object<PassContextNode>()); }
-void PassContext::Trace(const IRModule& module, const PassInfo& info, bool
is_before) const {
+void PassContext::InstrumentSetUp() const {
auto pass_ctx_node = this->operator->();
- if (pass_ctx_node->trace_func != nullptr) {
- pass_ctx_node->trace_func(module, info, is_before);
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ pi->SetUp();
+ }
}
}
-class ModulePass;
-
-/*! \brief PassProfile stores profiling information for a given pass and its
sub-passes. */
-struct PassProfile {
- // TODO(@altanh): expose PassProfile through TVM Object API
- using Clock = std::chrono::steady_clock;
- using Duration = std::chrono::duration<double, std::micro>;
- using Time = std::chrono::time_point<Clock>;
-
- /*! \brief The name of the pass being profiled. */
- String name;
- /*! \brief The time when the pass was entered. */
- Time start;
- /*! \brief The time when the pass completed. */
- Time end;
- /*! \brief The total duration of the pass, i.e. end - start. */
- Duration duration;
- /*! \brief PassProfiles for all sub-passes invoked during the execution of
the pass. */
- std::vector<PassProfile> children;
-
- explicit PassProfile(String name)
- : name(name), start(Clock::now()), end(Clock::now()), children() {}
-
- /*! \brief Gets the PassProfile of the currently executing pass. */
- static PassProfile* Current();
- /*! \brief Pushes a new PassProfile with the given pass name. */
- static void EnterPass(String name);
- /*! \brief Pops the current PassProfile. */
- static void ExitPass();
-};
-
-struct PassProfileThreadLocalEntry {
- /*! \brief The placeholder top-level PassProfile. */
- PassProfile root;
- /*! \brief The stack of PassProfiles for nested passes currently running. */
- std::stack<PassProfile*> profile_stack;
- /*! \brief Whether or not pass profiling is active. */
- bool active;
-
- PassProfileThreadLocalEntry() : root("root"), active(false) {}
-};
-
-/*! \brief Thread local store to hold the pass profiling data. */
-typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry>
PassProfileThreadLocalStore;
-
-void PassProfile::EnterPass(String name) {
- if (!PassProfileThreadLocalStore::Get()->active) return;
- PassProfile* cur = PassProfile::Current();
- cur->children.emplace_back(name);
-
PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back());
+void PassContext::InstrumentTearDown() const {
+ auto pass_ctx_node = this->operator->();
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ pi->TearDown();
+ }
+ }
}
-void PassProfile::ExitPass() {
- if (!PassProfileThreadLocalStore::Get()->active) return;
- PassProfile* cur = PassProfile::Current();
- ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling";
- cur->end = std::move(PassProfile::Clock::now());
- cur->duration = std::chrono::duration_cast<PassProfile::Duration>(cur->end -
cur->start);
- PassProfileThreadLocalStore::Get()->profile_stack.pop();
+bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const
PassInfo& pass_info) const {
+ auto pass_ctx_node = this->operator->();
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ if (!pi->RunBeforePass(ir_module, pass_info)) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return true;
}
-PassProfile* PassProfile::Current() {
- PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
- if (!entry->profile_stack.empty()) {
- return entry->profile_stack.top();
- } else {
- return &entry->root;
+void PassContext::InstrumentAfterPass(const IRModule& ir_module, const
PassInfo& pass_info) const {
+ auto pass_ctx_node = this->operator->();
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ pi->RunAfterPass(ir_module, pass_info);
+ }
}
}
IRModule Pass::operator()(IRModule mod) const {
const PassNode* node = operator->();
ICHECK(node != nullptr);
- PassProfile::EnterPass(node->Info()->name);
+ // PassProfile::EnterPass(node->Info()->name);
auto ret = node->operator()(std::move(mod));
- PassProfile::ExitPass();
+ // PassProfile::ExitPass();
return std::move(ret);
}
IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
const PassNode* node = operator->();
ICHECK(node != nullptr);
- PassProfile::EnterPass(node->Info()->name);
+ // PassProfile::EnterPass(node->Info()->name);
Review comment:
```suggestion
```
##########
File path: src/ir/transform.cc
##########
@@ -162,170 +164,64 @@ void PassContext::RegisterConfigOption(const char* key,
uint32_t value_type_inde
PassContext PassContext::Create() { return
PassContext(make_object<PassContextNode>()); }
-void PassContext::Trace(const IRModule& module, const PassInfo& info, bool
is_before) const {
+void PassContext::InstrumentSetUp() const {
auto pass_ctx_node = this->operator->();
- if (pass_ctx_node->trace_func != nullptr) {
- pass_ctx_node->trace_func(module, info, is_before);
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ pi->SetUp();
+ }
}
}
-class ModulePass;
-
-/*! \brief PassProfile stores profiling information for a given pass and its
sub-passes. */
-struct PassProfile {
- // TODO(@altanh): expose PassProfile through TVM Object API
- using Clock = std::chrono::steady_clock;
- using Duration = std::chrono::duration<double, std::micro>;
- using Time = std::chrono::time_point<Clock>;
-
- /*! \brief The name of the pass being profiled. */
- String name;
- /*! \brief The time when the pass was entered. */
- Time start;
- /*! \brief The time when the pass completed. */
- Time end;
- /*! \brief The total duration of the pass, i.e. end - start. */
- Duration duration;
- /*! \brief PassProfiles for all sub-passes invoked during the execution of
the pass. */
- std::vector<PassProfile> children;
-
- explicit PassProfile(String name)
- : name(name), start(Clock::now()), end(Clock::now()), children() {}
-
- /*! \brief Gets the PassProfile of the currently executing pass. */
- static PassProfile* Current();
- /*! \brief Pushes a new PassProfile with the given pass name. */
- static void EnterPass(String name);
- /*! \brief Pops the current PassProfile. */
- static void ExitPass();
-};
-
-struct PassProfileThreadLocalEntry {
- /*! \brief The placeholder top-level PassProfile. */
- PassProfile root;
- /*! \brief The stack of PassProfiles for nested passes currently running. */
- std::stack<PassProfile*> profile_stack;
- /*! \brief Whether or not pass profiling is active. */
- bool active;
-
- PassProfileThreadLocalEntry() : root("root"), active(false) {}
-};
-
-/*! \brief Thread local store to hold the pass profiling data. */
-typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry>
PassProfileThreadLocalStore;
-
-void PassProfile::EnterPass(String name) {
- if (!PassProfileThreadLocalStore::Get()->active) return;
- PassProfile* cur = PassProfile::Current();
- cur->children.emplace_back(name);
-
PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back());
+void PassContext::InstrumentTearDown() const {
+ auto pass_ctx_node = this->operator->();
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ pi->TearDown();
+ }
+ }
}
-void PassProfile::ExitPass() {
- if (!PassProfileThreadLocalStore::Get()->active) return;
- PassProfile* cur = PassProfile::Current();
- ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling";
- cur->end = std::move(PassProfile::Clock::now());
- cur->duration = std::chrono::duration_cast<PassProfile::Duration>(cur->end -
cur->start);
- PassProfileThreadLocalStore::Get()->profile_stack.pop();
+bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const
PassInfo& pass_info) const {
+ auto pass_ctx_node = this->operator->();
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ if (!pi->RunBeforePass(ir_module, pass_info)) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return true;
}
-PassProfile* PassProfile::Current() {
- PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
- if (!entry->profile_stack.empty()) {
- return entry->profile_stack.top();
- } else {
- return &entry->root;
+void PassContext::InstrumentAfterPass(const IRModule& ir_module, const
PassInfo& pass_info) const {
+ auto pass_ctx_node = this->operator->();
+ if (pass_ctx_node->instruments.defined()) {
+ for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+ pi->RunAfterPass(ir_module, pass_info);
+ }
}
}
IRModule Pass::operator()(IRModule mod) const {
const PassNode* node = operator->();
ICHECK(node != nullptr);
- PassProfile::EnterPass(node->Info()->name);
+ // PassProfile::EnterPass(node->Info()->name);
auto ret = node->operator()(std::move(mod));
- PassProfile::ExitPass();
+ // PassProfile::ExitPass();
return std::move(ret);
}
IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
const PassNode* node = operator->();
ICHECK(node != nullptr);
- PassProfile::EnterPass(node->Info()->name);
+ // PassProfile::EnterPass(node->Info()->name);
auto ret = node->operator()(std::move(mod), pass_ctx);
- PassProfile::ExitPass();
+ // PassProfile::ExitPass();
Review comment:
```suggestion
```
##########
File path: src/ir/transform.cc
##########
@@ -464,12 +360,19 @@ IRModule ModulePassNode::operator()(IRModule mod, const
PassContext& pass_ctx) c
<< "The diagnostic context was set at the top of this block this is a
bug.";
const PassInfo& pass_info = Info();
+ ICHECK(mod.defined()) << "The input module must be set.";
+
+ if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) {
+ DLOG(INFO) << "Skipping function pass : " << pass_info->name
Review comment:
```suggestion
DLOG(INFO) << "Skipping module pass : " << pass_info->name
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]