areusch commented on a change in pull request #7952:
URL: https://github.com/apache/tvm/pull/7952#discussion_r638946971



##########
File path: src/ir/transform.cc
##########
@@ -164,34 +167,57 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde
 
 PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); }
 
-void PassContext::InstrumentSetUp() const {
+void PassContext::InstrumentEnterPassContext() {
   auto pass_ctx_node = this->operator->();
   if (pass_ctx_node->instruments.defined()) {
-    for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
-      pi->SetUp();
+    try {
+      for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+        pi->EnterPassContext();
+      }
+    } catch (const Error& e) {
+      LOG(INFO) << "Pass instrumentation entering pass context failed.";
+      LOG(INFO) << "Disable pass instrumentation.";
+      pass_ctx_node->instruments.clear();
+      throw e;
     }
   }
 }
 
-void PassContext::InstrumentTearDown() const {
+void PassContext::InstrumentExitPassContext() {
   auto pass_ctx_node = this->operator->();
   if (pass_ctx_node->instruments.defined()) {
-    for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
-      pi->TearDown();
+    try {
+      for (instrument::PassInstrument pi : pass_ctx_node->instruments) {
+        pi->ExitPassContext();
+      }
+    } catch (const Error& e) {
+      LOG(INFO) << "Pass instrumentation exiting pass context failed.";
+      pass_ctx_node->instruments.clear();
+      throw e;
     }
   }
 }
 
 bool PassContext::InstrumentBeforePass(const IRModule& ir_module, const PassInfo& pass_info) const {
   auto pass_ctx_node = this->operator->();
-  if (pass_ctx_node->instruments.defined()) {
+  if (!pass_ctx_node->instruments.defined()) {
+    return true;
+  }
+
+  const bool pass_required = PassArrayContains(pass_ctx_node->required_pass, pass_info->name);
+  bool should_run = true;
+  if (!pass_required) {
+    const Array<instrument::PassInstrument>& instruments = pass_ctx_node->instruments;
+    should_run &= std::all_of(instruments.begin(), instruments.end(),

Review comment:
       This is a bit counter to STL style, but I'd actually prefer code like this to be written out explicitly. Then it's clear what happens when any `pi->ShouldRun()` returns false (`std::all_of` stops at the first false and never calls the remaining instruments).
   
   Here, it might be slightly better to explicitly invoke every instrument's `ShouldRun()`, even if one of them returns false. That way, instruments that only implement some kind of logging or timing can still record entries for the passes that didn't run. I'm happy to be swayed either way on this, but it seems more useful to me that way.

##########
File path: tests/python/relay/test_pass_instrument.py
##########
@@ -168,3 +169,329 @@ def run_after_pass(self, mod, info):
     # Out of pass context scope, should be reset
     assert passes_counter.run_before_count == 0
     assert passes_counter.run_after_count == 0
+
+
+def test_enter_pass_ctx_exception(capsys):
+    @pass_instrument
+    class PI:
+        def __init__(self, id):
+            self.id = id
+
+        def enter_pass_ctx(self):
+            print(self.id + " enter ctx")
+
+        def exit_pass_ctx(self):
+            print(self.id + " exit ctx")
+
+    @pass_instrument
+    class PIBroken(PI):
+        def __init__(self, id):
+            super().__init__(id)
+
+        def enter_pass_ctx(self):
+            print(self.id + " enter ctx")
+            raise RuntimeError("Just a dummy error")
+
+    pass_ctx = tvm.transform.PassContext(instruments=[PI("%1"), PIBroken("%2"), PI("%3")])
+    with pytest.raises(tvm.error.TVMError):
+        with pass_ctx:
+            pass
+
+    assert "%1 enter ctx\n" "%2 enter ctx\n" == capsys.readouterr().out

Review comment:
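       Should this also contain "%1 exit ctx\n"? %1 entered the pass context successfully before %2 raised, so it seems like %1's `exit_pass_ctx` should still be called.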

##########
File path: tests/python/relay/test_pass_instrument.py
##########
@@ -14,77 +14,91 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+""" Instrument test cases.
+"""
+import pytest
 import tvm
 import tvm.relay
 from tvm.relay import op
 from tvm.ir.instrument import PassTimingInstrument, pass_instrument
 
 
-def test_pass_timing_instrument():
+def get_test_model():
     x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"]
     e1 = op.add(x, y)
     e2 = op.subtract(x, z)
     e3 = op.multiply(e1, e1 / e2)
-    mod = tvm.IRModule.from_expr(e3 + e2)
+    return tvm.IRModule.from_expr(e3 + e2)
+
 
+def test_pass_timing_instrument():
     pass_timing = PassTimingInstrument()
-    with tvm.transform.PassContext(instruments=[pass_timing]):
-        mod = tvm.relay.transform.AnnotateSpans()(mod)
-        mod = tvm.relay.transform.ToANormalForm()(mod)
-        mod = tvm.relay.transform.InferType()(mod)
 
-        profiles = pass_timing.render()
-        assert "AnnotateSpans" in profiles
-        assert "ToANormalForm" in profiles
-        assert "InferType" in profiles
+    # Override current PassContext's instruments
+    tvm.transform.PassContext.current().override_instruments([pass_timing])
 
+    mod = get_test_model()
+    mod = tvm.relay.transform.AnnotateSpans()(mod)
+    mod = tvm.relay.transform.ToANormalForm()(mod)
+    mod = tvm.relay.transform.InferType()(mod)
 
-def test_custom_instrument(capsys):
-    x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"]
-    e1 = op.add(x, y)
-    e2 = op.subtract(x, z)
-    e3 = op.multiply(e1, e1 / e2)
-    mod = tvm.IRModule.from_expr(e3 + e2)
+    profiles = pass_timing.render()
+    assert "AnnotateSpans" in profiles
+    assert "ToANormalForm" in profiles
+    assert "InferType" in profiles
+
+    # Reset current PassContext's instruments to None
+    tvm.transform.PassContext.current().override_instruments(None)
 
+    mod = get_test_model()
+    mod = tvm.relay.transform.AnnotateSpans()(mod)
+    mod = tvm.relay.transform.ToANormalForm()(mod)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    profiles = pass_timing.render()
+    assert profiles == ""
+
+
+def test_custom_instrument(capsys):
     @pass_instrument
     class MyTest:
-        def set_up(self):
-            print("set up")
+        def enter_pass_ctx(self):
+            print("enter ctx")

Review comment:
       Rather than calling print directly, can you call a helper like:
   
   ```
   def __init__(self):
       self._events = []
   
   def _log_event(self, msg):
       print(msg)  # or maybe logging.info
       self._events.append(msg)
   ```
   
   and then just assert on `self._events`? Otherwise, a print() somewhere else (e.g. in the middle of the FFI) would end up in the captured output and could make this test fail. Such a print wouldn't be committed, but it could make debugging harder.

##########
File path: tests/python/relay/test_pass_instrument.py
##########
@@ -205,6 +205,18 @@ def enter_pass_ctx(self):
     assert cur_pass_ctx.instruments == None
 
 
+def test_enter_pass_ctx_exception_global(capsys):
+    @pass_instrument
+    class PIBroken:
+        def enter_pass_ctx(self):
+            raise RuntimeError("Just a dummy error")
+
+    cur_pass_ctx = tvm.transform.PassContext.current()
+    with pytest.raises(tvm.error.TVMError):

Review comment:
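       I think you can do one better and assert on the specific TVMError, something like:
   
   ```
   with pytest.raises(tvm.error.TVMError) as cm:
       cur_pass_ctx.override_instruments([PIBroken()])
   assert "Just a dummy error" in str(cm.value)
   ```
   
   (The assert has to go after the `with` block, because any statement after the raising call inside the block never runs. Also, pytest's `ExceptionInfo` exposes the raised exception as `.value`, not `.exception` as in unittest.)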




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to