zackcquic commented on a change in pull request #7952:
URL: https://github.com/apache/tvm/pull/7952#discussion_r639027780



##########
File path: tests/python/relay/test_pass_instrument.py
##########
@@ -14,77 +14,91 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+""" Instrument test cases.
+"""
+import pytest
 import tvm
 import tvm.relay
 from tvm.relay import op
 from tvm.ir.instrument import PassTimingInstrument, pass_instrument
 
 
-def test_pass_timing_instrument():
+def get_test_model():
     x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"]
     e1 = op.add(x, y)
     e2 = op.subtract(x, z)
     e3 = op.multiply(e1, e1 / e2)
-    mod = tvm.IRModule.from_expr(e3 + e2)
+    return tvm.IRModule.from_expr(e3 + e2)
+
 
+def test_pass_timing_instrument():
     pass_timing = PassTimingInstrument()
-    with tvm.transform.PassContext(instruments=[pass_timing]):
-        mod = tvm.relay.transform.AnnotateSpans()(mod)
-        mod = tvm.relay.transform.ToANormalForm()(mod)
-        mod = tvm.relay.transform.InferType()(mod)
 
-        profiles = pass_timing.render()
-        assert "AnnotateSpans" in profiles
-        assert "ToANormalForm" in profiles
-        assert "InferType" in profiles
+    # Override current PassContext's instruments
+    tvm.transform.PassContext.current().override_instruments([pass_timing])
 
+    mod = get_test_model()
+    mod = tvm.relay.transform.AnnotateSpans()(mod)
+    mod = tvm.relay.transform.ToANormalForm()(mod)
+    mod = tvm.relay.transform.InferType()(mod)
 
-def test_custom_instrument(capsys):
-    x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"]
-    e1 = op.add(x, y)
-    e2 = op.subtract(x, z)
-    e3 = op.multiply(e1, e1 / e2)
-    mod = tvm.IRModule.from_expr(e3 + e2)
+    profiles = pass_timing.render()
+    assert "AnnotateSpans" in profiles
+    assert "ToANormalForm" in profiles
+    assert "InferType" in profiles
+
+    # Reset current PassContext's instruments to None
+    tvm.transform.PassContext.current().override_instruments(None)
 
+    mod = get_test_model()
+    mod = tvm.relay.transform.AnnotateSpans()(mod)
+    mod = tvm.relay.transform.ToANormalForm()(mod)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    profiles = pass_timing.render()
+    assert profiles == ""
+
+
+def test_custom_instrument(capsys):
     @pass_instrument
     class MyTest:
-        def set_up(self):
-            print("set up")
+        def enter_pass_ctx(self):
+            print("enter ctx")

Review comment:
       Done.
   The `print` calls and the `capsys` fixture have been removed and replaced with recorded events.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to