This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0c43ab2be0 [BugFix][MetaSchedule] Fix `compile_relax` to apply
`MetaScheduleApplyDatabase` after `FuseOps` (#19385)
0c43ab2be0 is described below
commit 0c43ab2be0a33828edd02bb0e5cc846851685cec
Author: Soowon Jeong <[email protected]>
AuthorDate: Sun Apr 12 03:27:26 2026 +0900
[BugFix][MetaSchedule] Fix `compile_relax` to apply
`MetaScheduleApplyDatabase` after `FuseOps` (#19385)
## Problem
`compile_relax` applied `MetaScheduleApplyDatabase` to the *unfused*
IRModule, then called `relax_build`, which re-ran the full pipeline
including `FuseOps` + `FuseTIR` + DLight. This caused two compounding
issues:
1. **Granularity mismatch**: `extract_tasks` (called inside
`tune_relax`) runs *after* `FuseOps`+`FuseTIR`, so database keys
correspond to *fused* TIR functions. Applying the database to unfused
functions produces no matches.
2. **DLight failure on MetaSchedule-scheduled TIR**: After
`FuseOps`+`FuseTIR` re-fused the module, the DLight Matmul rule
attempted to schedule fused TIR that now contained MetaSchedule
cache-write stages, and failed:
```
AssertionError: There are some consumers of the cache-write stage that are
not properly inlined.
```
## Fix
Build a custom pipeline inside `compile_relax` with the correct
ordering:
1. Library dispatch + `LegalizeOps` + `FuseOps` + `FuseTIR` — same
preparation as `extract_tasks`, ensuring database keys match
2. `MetaScheduleApplyDatabase` — applied to fused TIR
3. DLight `ApplyDefaultSchedule` — fallback for functions not covered by
the database (GPU targets only)
4. VM lowering passes (`dataflow_lower_passes` + `finalize_passes`)
Two correctness fixes addressed during review:
- `backend_specific` is now gated on `target.kind.name == "cuda"` before
attempting the import, preventing CUDA-specific passes (e.g.
`RewriteCUDAGraph`) from being applied to non-CUDA targets.
- DLight rules are selected based on `is_gpu_target` (`cuda`, `opencl`,
`metal`, `vulkan`, `rocm`); non-GPU targets skip the DLight step since
they do not require explicit thread-binding scheduling.
---
.../tvm/s_tir/meta_schedule/relax_integration.py | 72 +++++++++++++++++++++-
.../relax/test_meta_schedule_relax_integration.py | 50 +++++++++++++++
2 files changed, 120 insertions(+), 2 deletions(-)
diff --git a/python/tvm/s_tir/meta_schedule/relax_integration.py
b/python/tvm/s_tir/meta_schedule/relax_integration.py
index 08fe1a434d..0cd19b0aad 100644
--- a/python/tvm/s_tir/meta_schedule/relax_integration.py
+++ b/python/tvm/s_tir/meta_schedule/relax_integration.py
@@ -407,8 +407,12 @@ def compile_relax(
The built runtime module or vm VMExecutable for the given relax
workload.
"""
# pylint: disable=import-outside-toplevel
+ import tvm
+ from tvm import relax
from tvm.relax import build as relax_build
+ from tvm.relax import pipeline as relax_pipeline_mod
from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase
+ from tvm.s_tir import dlight as dl
# pylint: enable=import-outside-toplevel
if not isinstance(target, Target):
@@ -416,7 +420,71 @@ def compile_relax(
if params:
mod = BindParams("main", params)(mod)
+ # Build a pipeline with the correct ordering:
+ # 1. library_dispatch + LegalizeOps + FuseOps + FuseTIR
+ # (same preparation as extract_tasks, so database keys match)
+ # 2. MetaScheduleApplyDatabase — replaces tuned fused-TIR functions
+ # 3. DLight fallback — schedules remaining untuned functions
+ # 4. dataflow_lower + finalize passes
+ #
+ # Applying MetaScheduleApplyDatabase BEFORE FuseOps (the original bug)
+ # caused DLight.Matmul to fail on cache-write stages embedded in fused TIR.
+ #
+ # All pass lists are obtained from relax.pipeline.*_passes(target) so that
+ # target-specific helpers (dispatch, finalize) are shared with the default
+ # pipeline rather than duplicated here.
+ try:
+ dispatch_passes = relax_pipeline_mod.library_dispatch_passes(target)
+ except (ValueError, AttributeError):
+ dispatch_passes = []
+
+ try:
+ lower_passes = relax_pipeline_mod.dataflow_lower_passes(target)
+ finalize_passes = relax_pipeline_mod.finalize_passes(target)
+ except (ValueError, AttributeError):
+ # Fallback for targets not yet registered in the pipeline dispatcher
+ lower_passes = [
+ relax.transform.RewriteDataflowReshape(),
+ relax.transform.ToNonDataflow(),
+ relax.transform.RemovePurityChecking(),
+ relax.transform.CallTIRRewrite(),
+ ]
+ finalize_passes = [
+ relax.transform.StaticPlanBlockMemory(),
+ relax.transform.LowerAllocTensor(),
+ relax.transform.KillAfterLastUse(),
+ relax.transform.LowerRuntimeBuiltin(),
+ relax.transform.ComputePrimValue(),
+ relax.transform.VMShapeLower(),
+ relax.transform.AttachGlobalSymbol(),
+ ]
+
+ is_gpu_target = relax_pipeline_mod.BackendDispatcher.is_gpu_target(target)
+
+ @tvm.transform.module_pass(opt_level=3)
+ def _ms_pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) ->
tvm.ir.IRModule:
+ fuse_seq = dispatch_passes + [
+ relax.transform.LegalizeOps(enable_warning=enable_warning),
+ relax.transform.AnnotateTIROpPattern(),
+ relax.transform.FoldConstant(),
+ relax.transform.FuseOps(),
+ relax.transform.FuseTIR(),
+ ]
+ mod = tvm.transform.Sequential(fuse_seq)(mod)
+ mod = MetaScheduleApplyDatabase(enable_warning=enable_warning)(mod)
+ # DLight handles functions not covered by the database.
+ # GPU rules apply only for GPU targets.
+ if is_gpu_target:
+ mod = dl.ApplyDefaultSchedule(
+ dl.gpu.Matmul(),
+ dl.gpu.GEMV(),
+ dl.gpu.Reduction(),
+ dl.gpu.GeneralReduction(),
+ dl.gpu.Fallback(),
+ )(mod)
+ mod = tvm.transform.Sequential(lower_passes + finalize_passes)(mod)
+ return mod
+
with target, database, PassContext(opt_level=3):
- relax_mod =
MetaScheduleApplyDatabase(enable_warning=enable_warning)(mod)
- relax_ex = relax_build(relax_mod, target=target)
+ relax_ex = relax_build(mod, target=target, relax_pipeline=_ms_pipeline)
return relax_ex
diff --git a/tests/python/relax/test_meta_schedule_relax_integration.py
b/tests/python/relax/test_meta_schedule_relax_integration.py
index 72699a2c6f..c28b8c444b 100644
--- a/tests/python/relax/test_meta_schedule_relax_integration.py
+++ b/tests/python/relax/test_meta_schedule_relax_integration.py
@@ -17,9 +17,17 @@
# ruff: noqa: E501, F401, F841
"""Integration test for MetaSchedule"""
+import tempfile
+
+import numpy as np
+import pytest
+
import tvm
import tvm.testing
from tvm import relax
+from tvm.runtime import tensor as tvm_tensor
+from tvm.runtime import cpu as tvm_cpu
+from tvm.runtime.vm import VirtualMachine
from tvm.s_tir import meta_schedule as ms
from tvm.script import ir as I
from tvm.script import relax as R
@@ -73,5 +81,47 @@ def test_extracting_tasks():
assert len(extracted_tasks) == count
+def test_compile_relax_with_database():
+ """End-to-end test: tune with MetaSchedule then compile_relax with the
database.
+
+ Verifies that the pipeline ordering in compile_relax is correct: tasks are
+ extracted and tuned against fused-TIR keys, and compile_relax produces
those
+ same keys (by running LegalizeOps + FuseOps + FuseTIR before applying the
+ database), so the scheduled kernels are actually picked up.
+ """
+ target = tvm.target.Target({"kind": "llvm", "num-cores": 1})
+
+ # Prepare the fused module whose TIR keys will populate the database.
+ fused_mod = Module0
+ fused_mod = relax.transform.LegalizeOps()(fused_mod)
+ fused_mod = relax.transform.AnnotateTIROpPattern()(fused_mod)
+ fused_mod = relax.transform.FuseOps()(fused_mod)
+ fused_mod = relax.transform.FoldConstant()(fused_mod)
+ fused_mod = relax.transform.FuseTIR()(fused_mod)
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ database = ms.relax_integration.tune_relax(
+ fused_mod,
+ params={},
+ target=target,
+ work_dir=work_dir,
+ max_trials_global=4,
+ )
+ # compile_relax takes the raw module and builds the fused-TIR pipeline
+ # internally; the database keys must therefore match the ones above.
+ exe = ms.relax_integration.compile_relax(
+ database=database,
+ mod=Module0,
+ target=target,
+ params=None,
+ )
+
+ dev = tvm_cpu()
+ vm = VirtualMachine(exe.jit(), dev)
+ data = tvm_tensor(np.zeros((1, 8, 8, 4), dtype="int32"), device=dev)
+ result = vm["main"](data)
+ assert result.numpy().shape == (1, 8, 8, 4)
+
+
if __name__ == "__main__":
tvm.testing.main()