This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch pass_callback_via_cx
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/pass_callback_via_cx by this
push:
new 1f38493 Implement pass tracing API
1f38493 is described below
commit 1f3849378884be54c16efcd23c78f168c5c99ecd
Author: Jared Roesch <[email protected]>
AuthorDate: Mon Jan 27 16:35:45 2020 -0800
Implement pass tracing API
---
include/tvm/ir/transform.h | 19 +++++++++++++++++++
python/tvm/relay/transform.py | 13 +++++++++----
src/ir/transform.cc | 8 ++++++++
src/relay/ir/transform.cc | 3 ++-
tests/python/relay/test_pass_manager.py | 30 ++++++++++++++++++++++++++++++
5 files changed, 68 insertions(+), 5 deletions(-)
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index c606b34..03aba40 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -65,6 +65,14 @@
namespace tvm {
namespace transform {
+// Forward declare for TraceFunc.
+class PassInfo;
+
+/*! \brief A callback for tracing passes, useful for debugging and logging.
+ *  Invoked with the module being transformed, the pass's info, and a flag
+ *  that is true before the pass runs and false after it has run. */
+using TraceFunc = runtime::TypedPackedFunc<void(const IRModule& ir_module,
const PassInfo& ctx, bool is_before)>;
+
/*!
* \brief PassContextNode contains the information that a pass can rely on,
* such as analysis results.
@@ -88,6 +96,8 @@ class PassContextNode : public Object {
/*! \brief The list of disabled passes. */
Array<PrimExpr> disabled_pass;
+ TraceFunc trace_func;
+
PassContextNode() = default;
void VisitAttrs(AttrVisitor* v) {
@@ -101,6 +111,7 @@ class PassContextNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};
+
/*!
* \brief PassContext that is used to configure the pass behavior.
*
@@ -146,6 +157,14 @@ class PassContext : public ObjectRef {
*/
TVM_DLL static PassContext Current();
+ /*!
+   * \brief Apply the trace function of the context to the module, along with
the given pass info.
+ * \param module The IRModule to trace.
+ * \param info The pass information.
+   * \param is_before Indicates whether the tracing is before or after a pass.
+ */
+ TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool
is_before) const;
+
// accessor.
using ContainerType = PassContextNode;
class Internal;
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py
index c4fbde6..26b20e0 100644
--- a/python/tvm/relay/transform.py
+++ b/python/tvm/relay/transform.py
@@ -78,7 +78,8 @@ class PassContext(RelayNode):
opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
- disabled_pass=None):
+ disabled_pass=None,
+ trace=None):
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext):
@@ -99,7 +100,7 @@ class PassContext(RelayNode):
self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
fallback_device, required,
- disabled)
+ disabled, trace)
def __enter__(self):
_transform.EnterPassContext(self)
@@ -117,7 +118,8 @@ class PassContext(RelayNode):
def build_config(opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
- disabled_pass=None):
+ disabled_pass=None,
+ trace=None):
"""Configure the build behavior by setting config variables.
Parameters
@@ -151,13 +153,16 @@ def build_config(opt_level=2,
disabled_pass: set of str, optional
Optimization passes to be disabled during optimization.
+ trace: Callable[[IRModule, PassInfo, bool], None]
+ A tracing function for debugging or introspection.
+
Returns
-------
pass_context: PassContext
The pass context for optimizations.
"""
return PassContext(opt_level, fallback_device, required_pass,
- disabled_pass)
+ disabled_pass, trace)
@register_relay_node
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 1da010c..d14a5b4 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -84,6 +84,10 @@ PassContext PassContext::Create() {
return PassContext(make_object<PassContextNode>());
}
+void PassContext::Trace(const IRModule& module, const PassInfo& info, bool
is_before) const {
+  if (this->operator->()->trace_func != nullptr) this->operator->()->trace_func(module, info, is_before);
+}
+
class ModulePass;
/*!
@@ -231,8 +235,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod,
<< " with opt level: "
<< pass_info->opt_level;
CHECK(mod.defined());
+ pass_ctx.Trace(mod, pass_info, true);
IRModule updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
+  pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
}
@@ -414,10 +420,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
int fallback_device = args[1];
tvm::Array<tvm::PrimExpr> required = args[2];
tvm::Array<tvm::PrimExpr> disabled = args[3];
+ TraceFunc trace_func = args[4];
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
+ pctx->trace_func = std::move(trace_func);
*ret = pctx;
});
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index ac0f36c..516103f 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;
-
+ pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions,
mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
@@ -134,6 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
}
+  pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
}
diff --git a/tests/python/relay/test_pass_manager.py
b/tests/python/relay/test_pass_manager.py
index e02e917..d9e17a3 100644
--- a/tests/python/relay/test_pass_manager.py
+++ b/tests/python/relay/test_pass_manager.py
@@ -522,6 +522,36 @@ def test_print_ir(capfd):
assert "Dumping the module IR" in out
assert "multiply" in out
+__TRACE_COUNTER__ = 0
+
+def _tracer(module, info, is_before):
+ global __TRACE_COUNTER__
+ if is_before:
+ __TRACE_COUNTER__ += 1
+
+def test_print_debug_callback():
+ global __TRACE_COUNTER__
+ shape = (1, 2, 3)
+ tp = relay.TensorType(shape, "float32")
+ x = relay.var("x", tp)
+ y = relay.add(x, x)
+ y = relay.multiply(y, relay.const(2, "float32"))
+ func = relay.Function([x], y)
+
+ seq = _transform.Sequential([
+ relay.transform.InferType(),
+ relay.transform.FoldConstant(),
+ relay.transform.DeadCodeElimination()
+ ])
+
+ assert __TRACE_COUNTER__ == 0
+ mod = relay.Module({"main": func})
+
+ with relay.build_config(opt_level=3, trace=_tracer):
+ mod = seq(mod)
+
+ assert __TRACE_COUNTER__ == 4
+
if __name__ == "__main__":
pytest.main()