This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0c43ab2be0 [BugFix][MetaSchedule] Fix `compile_relax` to apply `MetaScheduleApplyDatabase` after `FuseOps` (#19385)
0c43ab2be0 is described below

commit 0c43ab2be0a33828edd02bb0e5cc846851685cec
Author: Soowon Jeong <[email protected]>
AuthorDate: Sun Apr 12 03:27:26 2026 +0900

    [BugFix][MetaSchedule] Fix `compile_relax` to apply `MetaScheduleApplyDatabase` after `FuseOps` (#19385)
    
    ## Problem
    
    `compile_relax` applied `MetaScheduleApplyDatabase` to the *unfused*
    IRModule, then called `relax_build` which re-ran the full pipeline
    including `FuseOps` + `FuseTIR` + DLight. This caused two compounding
    issues:
    
    1. **Granularity mismatch**: `extract_tasks` (called inside
    `tune_relax`) extracts tasks *after* running `FuseOps` + `FuseTIR`, so
    database keys correspond to *fused* TIR functions. Applying the database
    to unfused functions therefore produces no matches.
    
    2. **DLight failure on MetaSchedule-scheduled TIR**: After
    `FuseOps`+`FuseTIR` re-fused the module, the DLight Matmul rule
    attempted to schedule fused TIR that now contained MetaSchedule
    cache-write stages, and failed:
    
    ```
    AssertionError: There are some consumers of the cache-write stage that are not properly inlined.
    ```
    
    ## Fix
    
    Build a custom pipeline inside `compile_relax` with the correct
    ordering:
    
    1. Library dispatch + `LegalizeOps` + `AnnotateTIROpPattern` +
    `FoldConstant` + `FuseOps` + `FuseTIR` — the same preparation as
    `extract_tasks`, so that database keys match
    2. `MetaScheduleApplyDatabase` — applied to fused TIR
    3. DLight `ApplyDefaultSchedule` — fallback for functions not covered by
    the database (GPU targets only)
    4. VM lowering passes (`dataflow_lower_passes` + `finalize_passes`)
    
    Two correctness fixes were made during review:
    - `backend_specific` is now gated on `target.kind.name == "cuda"` before
    attempting the import, preventing CUDA-specific passes (e.g.
    `RewriteCUDAGraph`) from being applied to non-CUDA targets.
    - DLight rules are selected based on `is_gpu_target` (`cuda`, `opencl`,
    `metal`, `vulkan`, `rocm`); non-GPU targets skip the DLight step since
    they do not require explicit thread-binding scheduling.
---
 .../tvm/s_tir/meta_schedule/relax_integration.py   | 72 +++++++++++++++++++++-
 .../relax/test_meta_schedule_relax_integration.py  | 50 +++++++++++++++
 2 files changed, 120 insertions(+), 2 deletions(-)

diff --git a/python/tvm/s_tir/meta_schedule/relax_integration.py b/python/tvm/s_tir/meta_schedule/relax_integration.py
index 08fe1a434d..0cd19b0aad 100644
--- a/python/tvm/s_tir/meta_schedule/relax_integration.py
+++ b/python/tvm/s_tir/meta_schedule/relax_integration.py
@@ -407,8 +407,12 @@ def compile_relax(
         The built runtime module or vm VMExecutable for the given relax 
workload.
     """
     # pylint: disable=import-outside-toplevel
+    import tvm
+    from tvm import relax
     from tvm.relax import build as relax_build
+    from tvm.relax import pipeline as relax_pipeline_mod
     from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase
+    from tvm.s_tir import dlight as dl
 
     # pylint: enable=import-outside-toplevel
     if not isinstance(target, Target):
@@ -416,7 +420,71 @@ def compile_relax(
     if params:
         mod = BindParams("main", params)(mod)
 
+    # Build a pipeline with the correct ordering:
+    #   1. library_dispatch + LegalizeOps + FuseOps + FuseTIR
+    #      (same preparation as extract_tasks, so database keys match)
+    #   2. MetaScheduleApplyDatabase — replaces tuned fused-TIR functions
+    #   3. DLight fallback — schedules remaining untuned functions
+    #   4. dataflow_lower + finalize passes
+    #
+    # Applying MetaScheduleApplyDatabase BEFORE FuseOps (the original bug)
+    # caused DLight.Matmul to fail on cache-write stages embedded in fused TIR.
+    #
+    # All pass lists are obtained from relax.pipeline.*_passes(target) so that
+    # target-specific helpers (dispatch, finalize) are shared with the default
+    # pipeline rather than duplicated here.
+    try:
+        dispatch_passes = relax_pipeline_mod.library_dispatch_passes(target)
+    except (ValueError, AttributeError):
+        dispatch_passes = []
+
+    try:
+        lower_passes = relax_pipeline_mod.dataflow_lower_passes(target)
+        finalize_passes = relax_pipeline_mod.finalize_passes(target)
+    except (ValueError, AttributeError):
+        # Fallback for targets not yet registered in the pipeline dispatcher
+        lower_passes = [
+            relax.transform.RewriteDataflowReshape(),
+            relax.transform.ToNonDataflow(),
+            relax.transform.RemovePurityChecking(),
+            relax.transform.CallTIRRewrite(),
+        ]
+        finalize_passes = [
+            relax.transform.StaticPlanBlockMemory(),
+            relax.transform.LowerAllocTensor(),
+            relax.transform.KillAfterLastUse(),
+            relax.transform.LowerRuntimeBuiltin(),
+            relax.transform.ComputePrimValue(),
+            relax.transform.VMShapeLower(),
+            relax.transform.AttachGlobalSymbol(),
+        ]
+
+    is_gpu_target = relax_pipeline_mod.BackendDispatcher.is_gpu_target(target)
+
+    @tvm.transform.module_pass(opt_level=3)
+    def _ms_pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> 
tvm.ir.IRModule:
+        fuse_seq = dispatch_passes + [
+            relax.transform.LegalizeOps(enable_warning=enable_warning),
+            relax.transform.AnnotateTIROpPattern(),
+            relax.transform.FoldConstant(),
+            relax.transform.FuseOps(),
+            relax.transform.FuseTIR(),
+        ]
+        mod = tvm.transform.Sequential(fuse_seq)(mod)
+        mod = MetaScheduleApplyDatabase(enable_warning=enable_warning)(mod)
+        # DLight handles functions not covered by the database.
+        # GPU rules apply only for GPU targets.
+        if is_gpu_target:
+            mod = dl.ApplyDefaultSchedule(
+                dl.gpu.Matmul(),
+                dl.gpu.GEMV(),
+                dl.gpu.Reduction(),
+                dl.gpu.GeneralReduction(),
+                dl.gpu.Fallback(),
+            )(mod)
+        mod = tvm.transform.Sequential(lower_passes + finalize_passes)(mod)
+        return mod
+
     with target, database, PassContext(opt_level=3):
-        relax_mod = 
MetaScheduleApplyDatabase(enable_warning=enable_warning)(mod)
-        relax_ex = relax_build(relax_mod, target=target)
+        relax_ex = relax_build(mod, target=target, relax_pipeline=_ms_pipeline)
     return relax_ex
diff --git a/tests/python/relax/test_meta_schedule_relax_integration.py b/tests/python/relax/test_meta_schedule_relax_integration.py
index 72699a2c6f..c28b8c444b 100644
--- a/tests/python/relax/test_meta_schedule_relax_integration.py
+++ b/tests/python/relax/test_meta_schedule_relax_integration.py
@@ -17,9 +17,17 @@
 # ruff: noqa: E501, F401, F841
 """Integration test for MetaSchedule"""
 
+import tempfile
+
+import numpy as np
+import pytest
+
 import tvm
 import tvm.testing
 from tvm import relax
+from tvm.runtime import tensor as tvm_tensor
+from tvm.runtime import cpu as tvm_cpu
+from tvm.runtime.vm import VirtualMachine
 from tvm.s_tir import meta_schedule as ms
 from tvm.script import ir as I
 from tvm.script import relax as R
@@ -73,5 +81,47 @@ def test_extracting_tasks():
         assert len(extracted_tasks) == count
 
 
+def test_compile_relax_with_database():
+    """End-to-end test: tune with MetaSchedule then compile_relax with the 
database.
+
+    Verifies that the pipeline ordering in compile_relax is correct: tasks are
+    extracted and tuned against fused-TIR keys, and compile_relax produces 
those
+    same keys (by running LegalizeOps + FuseOps + FuseTIR before applying the
+    database), so the scheduled kernels are actually picked up.
+    """
+    target = tvm.target.Target({"kind": "llvm", "num-cores": 1})
+
+    # Prepare the fused module whose TIR keys will populate the database.
+    fused_mod = Module0
+    fused_mod = relax.transform.LegalizeOps()(fused_mod)
+    fused_mod = relax.transform.AnnotateTIROpPattern()(fused_mod)
+    fused_mod = relax.transform.FuseOps()(fused_mod)
+    fused_mod = relax.transform.FoldConstant()(fused_mod)
+    fused_mod = relax.transform.FuseTIR()(fused_mod)
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        database = ms.relax_integration.tune_relax(
+            fused_mod,
+            params={},
+            target=target,
+            work_dir=work_dir,
+            max_trials_global=4,
+        )
+        # compile_relax takes the raw module and builds the fused-TIR pipeline
+        # internally; the database keys must therefore match the ones above.
+        exe = ms.relax_integration.compile_relax(
+            database=database,
+            mod=Module0,
+            target=target,
+            params=None,
+        )
+
+    dev = tvm_cpu()
+    vm = VirtualMachine(exe.jit(), dev)
+    data = tvm_tensor(np.zeros((1, 8, 8, 4), dtype="int32"), device=dev)
+    result = vm["main"](data)
+    assert result.numpy().shape == (1, 8, 8, 4)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to