zackcquic commented on a change in pull request #7952:
URL: https://github.com/apache/tvm/pull/7952#discussion_r639027780
##########
File path: tests/python/relay/test_pass_instrument.py
##########
@@ -14,77 +14,91 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+""" Instrument test cases.
+"""
+import pytest
import tvm
import tvm.relay
from tvm.relay import op
from tvm.ir.instrument import PassTimingInstrument, pass_instrument
-def test_pass_timing_instrument():
+def get_test_model():
x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"]
e1 = op.add(x, y)
e2 = op.subtract(x, z)
e3 = op.multiply(e1, e1 / e2)
- mod = tvm.IRModule.from_expr(e3 + e2)
+ return tvm.IRModule.from_expr(e3 + e2)
+
+def test_pass_timing_instrument():
pass_timing = PassTimingInstrument()
- with tvm.transform.PassContext(instruments=[pass_timing]):
- mod = tvm.relay.transform.AnnotateSpans()(mod)
- mod = tvm.relay.transform.ToANormalForm()(mod)
- mod = tvm.relay.transform.InferType()(mod)
- profiles = pass_timing.render()
- assert "AnnotateSpans" in profiles
- assert "ToANormalForm" in profiles
- assert "InferType" in profiles
+ # Override current PassContext's instruments
+ tvm.transform.PassContext.current().override_instruments([pass_timing])
+ mod = get_test_model()
+ mod = tvm.relay.transform.AnnotateSpans()(mod)
+ mod = tvm.relay.transform.ToANormalForm()(mod)
+ mod = tvm.relay.transform.InferType()(mod)
-def test_custom_instrument(capsys):
- x, y, z = [tvm.relay.var(c, shape=(3, 4), dtype="float32") for c in "xyz"]
- e1 = op.add(x, y)
- e2 = op.subtract(x, z)
- e3 = op.multiply(e1, e1 / e2)
- mod = tvm.IRModule.from_expr(e3 + e2)
+ profiles = pass_timing.render()
+ assert "AnnotateSpans" in profiles
+ assert "ToANormalForm" in profiles
+ assert "InferType" in profiles
+
+ # Reset current PassContext's instruments to None
+ tvm.transform.PassContext.current().override_instruments(None)
+ mod = get_test_model()
+ mod = tvm.relay.transform.AnnotateSpans()(mod)
+ mod = tvm.relay.transform.ToANormalForm()(mod)
+ mod = tvm.relay.transform.InferType()(mod)
+
+ profiles = pass_timing.render()
+ assert profiles == ""
+
+
+def test_custom_instrument(capsys):
@pass_instrument
class MyTest:
- def set_up(self):
- print("set up")
+ def enter_pass_ctx(self):
+ print("enter ctx")
Review comment:
Done.
Removed `print` and `capsys`; the test now uses events instead.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org