This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 5b1fa29838 [Unity][VM] Allow `pipeline=None` in `relax.build` (#16246)
5b1fa29838 is described below

commit 5b1fa298388cca09235479fba57871e9a056f910
Author: Junru Shao <[email protected]>
AuthorDate: Fri Dec 15 04:25:07 2023 -0800

    [Unity][VM] Allow `pipeline=None` in `relax.build` (#16246)
    
    In sophisticated use cases, developers may want full control over the relax
    compilation pipeline. A typical workflow looks like:
    - Step 1. Apply a custom relax.pipeline to IRModule
    - Step 2. Invoke `relax.build` to export a relax Executable without applying any additional passes;
    - Step 3. Manipulate the relax Executable.
    
    Of the three steps above, Steps 1 and 3 are already supported. This PR
    enables Step 2 by allowing `relax.build` to accept a `None` pipeline,
    which applies no passes.
    
    Note: as an advanced feature, this is enabled only when `pipeline=None`
    is set explicitly, so backward compatibility is fully
    preserved.
---
 python/tvm/relax/pipeline.py |  2 +-
 python/tvm/relax/vm_build.py | 43 ++++++++++++++++++++++++++-----------------
 2 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index a4ba3315b8..ebcbd2d609 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -95,7 +95,7 @@ def default_build_pipeline():
                 transform.AttachGlobalSymbol(),
             ],
         )
-        mod = seq(mod._move())  # pylint: disable=protected-access
+        mod = seq(mod)
         return mod
 
     return _pipeline
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 7a7649c449..e4d1fefbe7 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -176,7 +176,7 @@ def _vmcodegen(
         return _ffi_api.VMCodeGen(builder, mod)  # type:ignore
     if exec_mode == "compiled":
         return _ffi_api.VMTIRCodeGen(builder, mod)  # type: ignore
-    raise ValueError("Unknown exec_mode %s" % exec_mode)
+    raise ValueError(f"Unknown exec_mode {exec_mode}")
 
 
 def _autodetect_system_lib_req(target: tvm.target.Target, system_lib):
@@ -248,7 +248,7 @@ def build(
     mod: tvm.IRModule,
     target: Union[str, tvm.target.Target],
     params: Optional[Dict[str, list]] = None,
-    pipeline: str = "default_build",
+    pipeline: Union[None, str, tvm.transform.Pass] = "default_build",
     exec_mode: str = "bytecode",
     *,
     system_lib: Optional[bool] = None,
@@ -305,26 +305,35 @@ def build(
         target = tvm.target.Target("llvm", host="llvm")
         ex = relax.build(mod, target)
     """
-    if isinstance(target, str):
-        target = tvm.target.Target(target)
 
-    new_mod = relax.get_pipeline(pipeline)(mod)
-    # Extract external runtime modules if exist.
-    attrs = dict(mod.attrs) if mod.attrs else {}
+    def _extract_attrs(mod: tvm.IRModule):
+        attrs = dict(mod.attrs) if mod.attrs else {}
+        ext_libs = attrs.get("external_mods", [])
+        constants = attrs.get("const_name_to_constant", {})
+        return ext_libs, constants
 
-    ext_libs = attrs.get("external_mods", [])
-    constants = attrs.get("const_name_to_constant", {})
+    if isinstance(target, str):
+        target = tvm.target.Target(target)
+    if not params:
+        params = {}
 
-    if params is not None:
-        params.update(dict(constants))
-    else:
-        params = constants
+    ext_libs, constants = _extract_attrs(mod)
+    params.update(dict(constants))
 
-    # builder collects the executable
+    if pipeline is not None:
+        if isinstance(pipeline, str):
+            pipeline = relax.get_pipeline(pipeline)
+        mod = pipeline(mod)
     builder = relax.ExecBuilder()
-    leftover_mod = _vmcodegen(builder, new_mod, exec_mode=exec_mode)
-    tir_mod = _filter_tir(leftover_mod)
-    return _vmlink(builder, target, tir_mod, ext_libs, params, 
system_lib=system_lib)
+    mod = _vmcodegen(builder, mod, exec_mode)
+    return _vmlink(
+        builder=builder,
+        target=target,
+        tir_mod=_filter_tir(mod),
+        ext_libs=ext_libs,
+        params=params,
+        system_lib=system_lib,
+    )
 
 
 def _filter_tir(mod: tvm.IRModule) -> tvm.IRModule:

Reply via email to