areusch commented on code in PR #11632:
URL: https://github.com/apache/tvm/pull/11632#discussion_r896020908
##########
tests/python/integration/test_tuning.py:
##########
@@ -180,6 +184,114 @@ def runner(target, dev):
run_test_with_all_multiprocessing(runner, target, dev)
+@tvm.testing.parametrize_targets("cuda", "opencl")
+def test_tuning_gpu_inherits_pass_context(target, dev):
+ """Autotvm tuner inherits PassContexts but also adds a gpu verification pass by default.
+
+ Test that using PassContext inherits passes properly but also runs gpu verification pass.
+ """
+ from tvm.tir.analysis import _ffi_api as _analysis_ffi_api
+
+ @pass_instrument
+ class PassInstrumentChecker:
+ """Pass Instrument that simply sees if it's been run."""
+
+ def __init__(self):
+ self.has_been_run = False
+
+ def run_after_pass(self, mod, info):
+ self.has_been_run = True
+
+ class GPUVerifyPassMocked:
+ """Context manager that mocks tir.analysis.verify_gpu_code meant
+ to verify the pass has been run. This is done by patching the ffi func handles."""
+
+ FFI_FUNC_HANDLE = "tir.analysis.verify_gpu_code"
+ FUNC_NAME = "verify_gpu_code"
+
+ def __init__(self) -> None:
+ self.old_impl = tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE)
+ self.has_been_run = False
+
+ def gpu_verify_pass_mocked(self):
+ """Get the replacement for the gpu verification pass."""
+
+ def _gpu_verify_pass_mocked(*args, **kwargs):
+ self.has_been_run = True
+ return self.old_impl(*args, **kwargs)
+
+ return _gpu_verify_pass_mocked
+
+ def __enter__(self):
+ tvm._ffi.register_func(
+ self.FFI_FUNC_HANDLE, self.gpu_verify_pass_mocked(),
override=True
+ )
+
+ # Also overwrite the python bindings
+ setattr(
+ _analysis_ffi_api, self.FUNC_NAME,
tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE)
+ )
+
+ def __exit__(self, *args, **kwargs):
+ # Restore FFI status back to normal
+ tvm._ffi.register_func(self.FFI_FUNC_HANDLE, self.old_impl,
override=True)
+ setattr(_analysis_ffi_api, self.FUNC_NAME, self.old_impl)
+
+ class OverwrittenBuildFunc(measure_methods._WrappedBuildFunc):
+ """BuildFunc that mocks and patches as necessary to test proper passes are run."""
+
+ def __call__(self, measure_input, tmp_dir, **kwargs):
+ instrument = PassInstrumentChecker()
+ mocked_pass_checker = GPUVerifyPassMocked()
+ with mocked_pass_checker:
+ with PassContext(instruments=[instrument]):
+ regular_result = super().__call__(measure_input, tmp_dir,
**kwargs)
+
+ # Check instrument has been run, meaning context was inherited by builder
+ assert instrument.has_been_run
+
+ # But also check the gpu verification pass has been run
+ # (which was not in the inherited ctx)
+ assert mocked_pass_checker.has_been_run
+
+ return regular_result
+
+ class MockedLocalBuilder(measure_methods.LocalBuilder):
+ """As measure_methods.LocalBuilder but overwrites the PassContext for testing."""
+
+ def __init__(
+ self,
+ timeout=10,
+ n_parallel=None,
+ build_kwargs=None,
+ build_func="default",
+ do_fork=False,
+ runtime=None,
+ ):
+ super().__init__(timeout, n_parallel, build_kwargs, build_func,
do_fork, runtime)
+ self.build_func = OverwrittenBuildFunc(tar.tar, runtime)
+
+ def runner(target, dev):
+ task, target = get_sample_task(target, None)
+ logging.info("task config space: %s", task.config_space)
+
+ # Note: we use the MockedLocalBuilder here instead of autotvm.LocalBuilder()
+ measure_option = autotvm.measure_option(MockedLocalBuilder(),
autotvm.LocalRunner())
+
+ results = []
+
+ tuner = RandomTuner(task)
+ tuner.tune(
+ n_trial=1,
+ measure_option=measure_option,
+ callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
+ )
+
+ assert len(results) == 1
Review Comment:
Do we want to check that the pass also succeeded? I think that if one of those
asserts fails, we just get a measure error here.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]