This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d5daa9806d [Unity] Allow Customized Pipeline in `relax.build` (#16121)
d5daa9806d is described below
commit d5daa9806dc6dd1d2d8cdedcd142cfa89f22eaef
Author: Junru Shao <[email protected]>
AuthorDate: Tue Nov 14 00:20:05 2023 -0800
[Unity] Allow Customized Pipeline in `relax.build` (#16121)
The existing `relax.build` method assumes that compilation always follows a
fixed set of passes. With the introduction of the Relax pipeline
system, one can effectively manage which passes to use during
lowering. This PR generalizes this approach by further allowing callers to
choose which pipeline to use during compilation.
---
python/tvm/relax/pipeline.py | 32 +++++++++++++++++++++++++++++++-
python/tvm/relax/vm_build.py | 30 +++++++-----------------------
2 files changed, 38 insertions(+), 24 deletions(-)
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index 367c1ede0e..a4ba3315b8 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -74,8 +74,38 @@ def zero_pipeline(*, enable_warning: bool = False):
return f_zero_pipeline
+def default_build_pipeline():
+ """The default compilation pipeline used in relax.build"""
+
+ @tvm.transform.module_pass(opt_level=0)
+ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
+ seq = tvm.transform.Sequential(
+ [
+ transform.LegalizeOps(),
+ transform.RewriteDataflowReshape(),
+ transform.ToNonDataflow(),
+ transform.RemovePurityChecking(),
+ transform.CallTIRRewrite(),
+ transform.StaticPlanBlockMemory(),
+ transform.RewriteCUDAGraph(),
+ transform.LowerAllocTensor(),
+ transform.KillAfterLastUse(),
+ transform.VMBuiltinLower(),
+ transform.VMShapeLower(),
+ transform.AttachGlobalSymbol(),
+ ],
+ )
+ mod = seq(mod._move()) # pylint: disable=protected-access
+ return mod
+
+ return _pipeline
+
+
# global map of pre-built pipelines
-PIPELINE_MAP = {"zero": zero_pipeline}
+PIPELINE_MAP = {
+ "zero": zero_pipeline,
+ "default_build": default_build_pipeline,
+}
def get_pipeline(name: str = "zero", **kwargs) -> tvm.transform.Pass:
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index a54c0154fc..7a7649c449 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -16,13 +16,11 @@
# under the License.
# pylint: disable=invalid-name, no-member
"""VM build logics"""
-from typing import List, Optional, Union, Dict, Any
+from typing import Any, Dict, List, Optional, Union
import tvm
from tvm import relax
-
from tvm.contrib import utils as _utils
-
from tvm.ir.module import IRModule
from tvm.tir.function import PrimFunc
@@ -80,6 +78,7 @@ class Executable:
rt_mod = ex.jit()
vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda())
"""
+
# TODO(tvm-team): Update runtime.Module interfac
# to query these properties as bitmask.
def _not_runnable(x):
@@ -249,6 +248,7 @@ def build(
mod: tvm.IRModule,
target: Union[str, tvm.target.Target],
params: Optional[Dict[str, list]] = None,
+ pipeline: str = "default_build",
exec_mode: str = "bytecode",
*,
system_lib: Optional[bool] = None,
@@ -274,6 +274,9 @@ def build(
params: Optional[Dict[str, list]]
Parameters for the input IRModule that will be bound.
+ pipeline : str = "default_build"
+ The compilation pipeline to use.
+
exec_mode: {"bytecode", "compiled"}
The execution mode.
@@ -305,26 +308,7 @@ def build(
if isinstance(target, str):
target = tvm.target.Target(target)
- lowering_passes = tvm.transform.Sequential(
- [
- relax.transform.LegalizeOps(),
- relax.transform.RewriteDataflowReshape(),
- relax.transform.ToNonDataflow(),
- relax.transform.RemovePurityChecking(),
- relax.transform.CallTIRRewrite(),
- relax.transform.StaticPlanBlockMemory(),
- relax.transform.RewriteCUDAGraph(),
- relax.transform.LowerAllocTensor(),
- relax.transform.KillAfterLastUse(),
- relax.transform.VMBuiltinLower(),
- relax.transform.VMShapeLower(),
- relax.transform.AttachGlobalSymbol(),
- ],
- name="relax.lower",
- )
-
- new_mod = lowering_passes(mod)
-
+ new_mod = relax.get_pipeline(pipeline)(mod)
# Extract external runtime modules if exist.
attrs = dict(mod.attrs) if mod.attrs else {}