This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/build-funcs-inherit-passcontext
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 4801c02afa46eb0fe837a9e5556cc1d80cf3d1b2
Author: Andrew Zhao Luo <[email protected]>
AuthorDate: Tue Jun 7 16:00:01 2022 -0700

    initial commit
---
 python/tvm/auto_scheduler/measure.py          |  2 +-
 python/tvm/autotvm/measure/measure_methods.py | 27 ++++++++++++++++++++++-----
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 2a4a03bbe8..75f1116864 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -630,7 +630,7 @@ def _local_build_worker(inp_serialized, build_func, verbose):
         filename = os.path.join(dirname, "tmp_func." + build_func.output_format)
 
         try:
-            with transform.PassContext():
+            with transform.PassContext().current():
                 func = build_module.build(sch, args, target=task.target)
             func.export_library(filename, build_func)
         # pylint: disable=broad-except
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index f582bd1974..7a398eb27d 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -505,10 +505,6 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
         if not config.valid():
             raise InstantiationError(config.errors)
 
-        opts = build_option or {}
-        if check_gpu:  # Add verify pass to filter out invalid configs in advance.
-            opts["tir.add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
-
         # if target is vta, we need to use vta build
         if (
             hasattr(measure_input.target, "device_name")
@@ -519,7 +515,28 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
 
             func = vta.build(s, args, target_host=task.target_host)
         else:
-            with tvm.ir.transform.PassContext(config=opts):
+            current_pass_context: tvm.ir.transform.PassContext = (
+                tvm.ir.transform.PassContext.current()
+            )
+            current_config = dict(current_pass_context.config)
+            if build_option is not None:
+                current_config.update(build_option)
+
+            if "tir.add_lower_pass" in current_config:
+                current_add_lower_pass = list(current_config["tir.add_lower_pass"])
+            else:
+                current_add_lower_pass = []
+            if check_gpu:
+                current_add_lower_pass.append((2, gpu_verify_pass(**check_gpu)))
+            current_config["tir.add_lower_pass"] = current_add_lower_pass
+
+            with tvm.ir.transform.PassContext(
+                opt_level=current_pass_context.opt_level,
+                required_pass=current_pass_context.required_pass,
+                disabled_pass=current_pass_context.disabled_pass,
+                instruments=current_pass_context.instruments,
+                config=current_config,
+            ):
                 func = build(s, args, target_host=task.target_host, runtime=runtime)
     return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
 
