areusch commented on code in PR #11632:
URL: https://github.com/apache/tvm/pull/11632#discussion_r896020908


##########
tests/python/integration/test_tuning.py:
##########
@@ -180,6 +184,114 @@ def runner(target, dev):
     run_test_with_all_multiprocessing(runner, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda", "opencl")
+def test_tuning_gpu_inherits_pass_context(target, dev):
+    """Autotvm tuner inherits PassContexts but also adds a gpu verification pass by default.
+
+    Test that using PassContext inherits passes properly but also runs gpu verification pass.
+    """
+    from tvm.tir.analysis import _ffi_api as _analysis_ffi_api
+
+    @pass_instrument
+    class PassInstrumentChecker:
+        """Pass Instrument that simply sees if it's been run."""
+
+        def __init__(self):
+            self.has_been_run = False
+
+        def run_after_pass(self, mod, info):
+            self.has_been_run = True
+
+    class GPUVerifyPassMocked:
+        """Context manager that mocks tir.analysis.verify_gpu_code meant
+        to verify the pass has been run. This is done by patching the ffi func handles."""
+
+        FFI_FUNC_HANDLE = "tir.analysis.verify_gpu_code"
+        FUNC_NAME = "verify_gpu_code"
+
+        def __init__(self) -> None:
+            self.old_impl = tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE)
+            self.has_been_run = False
+
+        def gpu_verify_pass_mocked(self):
+            """Get the replacement for the gpu verification pass."""
+
+            def _gpu_verify_pass_mocked(*args, **kwargs):
+                self.has_been_run = True
+                return self.old_impl(*args, **kwargs)
+
+            return _gpu_verify_pass_mocked
+
+        def __enter__(self):
+            tvm._ffi.register_func(
+                self.FFI_FUNC_HANDLE, self.gpu_verify_pass_mocked(), override=True
+            )
+
+            # Also overwrite the python bindings
+            setattr(
+                _analysis_ffi_api, self.FUNC_NAME, tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE)
+            )
+
+        def __exit__(self, *args, **kwargs):
+            # Restore FFI status back to normal
+            tvm._ffi.register_func(self.FFI_FUNC_HANDLE, self.old_impl, override=True)
+            setattr(_analysis_ffi_api, self.FUNC_NAME, self.old_impl)
+
+    class OverwrittenBuildFunc(measure_methods._WrappedBuildFunc):
+        """BuildFunc that mocks and patches as necessary to test proper passes are run."""
+
+        def __call__(self, measure_input, tmp_dir, **kwargs):
+            instrument = PassInstrumentChecker()
+            mocked_pass_checker = GPUVerifyPassMocked()
+            with mocked_pass_checker:
+                with PassContext(instruments=[instrument]):
+                    regular_result = super().__call__(measure_input, tmp_dir, **kwargs)
+
+                    # Check instrument has been run, meaning context was inherited by builder
+                    assert instrument.has_been_run
+
+                    # But also check the gpu verification pass has been run
+                    # (which was not in the inherited ctx)
+                    assert mocked_pass_checker.has_been_run
+
+                    return regular_result
+
+    class MockedLocalBuilder(measure_methods.LocalBuilder):
+        """As measure_methods.LocalBuilder but overwrites the PassContext for testing."""
+
+        def __init__(
+            self,
+            timeout=10,
+            n_parallel=None,
+            build_kwargs=None,
+            build_func="default",
+            do_fork=False,
+            runtime=None,
+        ):
+            super().__init__(timeout, n_parallel, build_kwargs, build_func, do_fork, runtime)
+            self.build_func = OverwrittenBuildFunc(tar.tar, runtime)
+
+    def runner(target, dev):
+        task, target = get_sample_task(target, None)
+        logging.info("task config space: %s", task.config_space)
+
+        # Note: we use the MockedLocalBuilder here instead of autotvm.LocalBuilder()
+        measure_option = autotvm.measure_option(MockedLocalBuilder(), autotvm.LocalRunner())
+
+        results = []
+
+        tuner = RandomTuner(task)
+        tuner.tune(
+            n_trial=1,
+            measure_option=measure_option,
+            callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
+        )
+
+        assert len(results) == 1

Review Comment:
   Do we also want to check that the pass succeeded? I think if one of those asserts fails, we just get a measure error here, so `assert len(results) == 1` would still pass.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to