This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch pass_callback_via_cx in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/pass_callback_via_cx by this push: new 1f38493 Implement pass tracing API 1f38493 is described below commit 1f3849378884be54c16efcd23c78f168c5c99ecd Author: Jared Roesch <jroe...@octoml.ai> AuthorDate: Mon Jan 27 16:35:45 2020 -0800 Implement pass tracing API --- include/tvm/ir/transform.h | 19 +++++++++++++++++++ python/tvm/relay/transform.py | 13 +++++++++---- src/ir/transform.cc | 8 ++++++++ src/relay/ir/transform.cc | 3 ++- tests/python/relay/test_pass_manager.py | 30 ++++++++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 5 deletions(-) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index c606b34..03aba40 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -65,6 +65,14 @@ namespace tvm { namespace transform { +// Forward declare for TraceFunc. +class PassInfo; + +/*! \brief A callback for tracing passes, useful for debugging and logging. + * + */ +using TraceFunc = runtime::TypedPackedFunc<void(const IRModule& ir_module, const PassInfo& ctx, bool is_before)>; + /*! * \brief PassContextNode contains the information that a pass can rely on, * such as analysis results. @@ -88,6 +96,8 @@ class PassContextNode : public Object { /*! \brief The list of disabled passes. */ Array<PrimExpr> disabled_pass; + TraceFunc trace_func; + PassContextNode() = default; void VisitAttrs(AttrVisitor* v) { @@ -101,6 +111,7 @@ class PassContextNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); }; + /*! * \brief PassContext that is used to configure the pass behavior. * @@ -146,6 +157,14 @@ class PassContext : public ObjectRef { */ TVM_DLL static PassContext Current(); + /*! + * \brief Apply the tracing functions of the context to the module, with the info. + * \param module The IRModule to trace. + * \param info The pass information. + * \param is_before Indicated whether the tracing is before or after a pass. + */ + TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + // accessor. using ContainerType = PassContextNode; class Internal; diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index c4fbde6..26b20e0 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -78,7 +78,8 @@ class PassContext(RelayNode): opt_level=2, fallback_device=_nd.cpu(), required_pass=None, - disabled_pass=None): + disabled_pass=None, + trace=None): if isinstance(fallback_device, str): fallback_device = _nd.context(fallback_device).device_type elif isinstance(fallback_device, TVMContext): @@ -99,7 +100,7 @@ class PassContext(RelayNode): self.__init_handle_by_constructor__(_transform.PassContext, opt_level, fallback_device, required, - disabled) + disabled, trace) def __enter__(self): _transform.EnterPassContext(self) @@ -117,7 +118,8 @@ class PassContext(RelayNode): def build_config(opt_level=2, fallback_device=_nd.cpu(), required_pass=None, - disabled_pass=None): + disabled_pass=None, + trace=None): """Configure the build behavior by setting config variables. Parameters @@ -151,13 +153,16 @@ def build_config(opt_level=2, disabled_pass: set of str, optional Optimization passes to be disabled during optimization. + trace: Callable[[IRModule, PassInfo, bool], None] + A tracing function for debugging or introspection. + Returns ------- pass_context: PassContext The pass context for optimizations. """ return PassContext(opt_level, fallback_device, required_pass, - disabled_pass) + disabled_pass, trace) @register_relay_node diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 1da010c..d14a5b4 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -84,6 +84,10 @@ PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); } +void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { + this->operator->()->trace_func(module, info, is_before); +} + class ModulePass; /*! @@ -231,8 +235,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod, << " with opt level: " << pass_info->opt_level; CHECK(mod.defined()); + pass_ctx.Trace(mod, pass_info, true); IRModule updated_mod = pass_func(mod, pass_ctx); CHECK(updated_mod.defined()); + pass_ctx.Trace(updated_mod, pass_info, true); return updated_mod; } @@ -414,10 +420,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext") int fallback_device = args[1]; tvm::Array<tvm::PrimExpr> required = args[2]; tvm::Array<tvm::PrimExpr> disabled = args[3]; + TraceFunc trace_func = args[4]; pctx->opt_level = opt_level; pctx->fallback_device = fallback_device; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); + pctx->trace_func = std::move(trace_func); *ret = pctx; }); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index ac0f36c..516103f 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, << pass_info->name << " with opt level: " << pass_info->opt_level; - + pass_ctx.Trace(mod, pass_info, true); // Execute the pass function and return a new module. IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports()); std::vector<std::pair<GlobalVar, Function> > updates; @@ -134,6 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, for (const auto& pair : updates) { updated_mod->Add(pair.first, pair.second, true); } + pass_ctx.Trace(updated_mod, pass_info, true); return updated_mod; } diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index e02e917..d9e17a3 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -522,6 +522,36 @@ def test_print_ir(capfd): assert "Dumping the module IR" in out assert "multiply" in out +__TRACE_COUNTER__ = 0 + +def _tracer(module, info, is_before): + global __TRACE_COUNTER__ + if is_before: + __TRACE_COUNTER__ += 1 + +def test_print_debug_callback(): + global __TRACE_COUNTER__ + shape = (1, 2, 3) + tp = relay.TensorType(shape, "float32") + x = relay.var("x", tp) + y = relay.add(x, x) + y = relay.multiply(y, relay.const(2, "float32")) + func = relay.Function([x], y) + + seq = _transform.Sequential([ + relay.transform.InferType(), + relay.transform.FoldConstant(), + relay.transform.DeadCodeElimination() + ]) + + assert __TRACE_COUNTER__ == 0 + mod = relay.Module({"main": func}) + + with relay.build_config(opt_level=3, trace=_tracer): + mod = seq(mod) + + assert __TRACE_COUNTER__ == 4 + if __name__ == "__main__": pytest.main()