This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a7895a301a [Attention] Added caching for flashinfer binaries during JIT (#17730)
a7895a301a is described below

commit a7895a301a976517c1314f3a7f52d083180a7126
Author: Annanya <[email protected]>
AuthorDate: Thu Mar 13 09:25:34 2025 -0400

    [Attention] Added caching for flashinfer binaries during JIT (#17730)
    
    In this PR I have added support for caching the flashinfer binaries
    during the JIT flow in TVM.
---
 python/tvm/relax/backend/cuda/flashinfer.py        | 80 +++++++++++++++++++---
 .../test_runtime_builtin_kv_cache_transfer.py      |  2 +-
 ..._builtin_paged_attention_kv_cache_flashinfer.py |  2 +-
 ...ltin_paged_attention_kv_cache_mla_flashinfer.py |  2 +-
 4 files changed, 73 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py
index 725fd105ad..8aa4817a30 100644
--- a/python/tvm/relax/backend/cuda/flashinfer.py
+++ b/python/tvm/relax/backend/cuda/flashinfer.py
@@ -21,6 +21,8 @@ import subprocess
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import List
+import hashlib
+import json
 
 import tvm
 from tvm.target import Target
@@ -37,7 +39,57 @@ def _compile_flashinfer_kernels(
         FLASHINFER_TVM_BINDING_DIR,
     )
 
-    # Todo(tvm-team): enable compilation cache
+    # ------------------------------------------------------------------------
+    # Caching Flow: create build_directory and compute cache hash.
+    # ------------------------------------------------------------------------
+    build_directory = FLASHINFER_JIT_DIR / name
+    build_directory.mkdir(parents=True, exist_ok=True)
+
+    def get_object_file_path(src: Path) -> Path:
+        obj_name = src.stem + ".o"
+        obj_path = build_directory / obj_name
+        return obj_path
+
+    # Compute latest modification time among all source files
+    latest_src_mtime = max(src.stat().st_mtime for src in source_paths)
+
+    # Get modification time for the current file (the one that contains this function)
+    current_file_mtime = Path(__file__).stat().st_mtime
+
+    # Build the hash key from metadata
+    hash_key = {
+        "name": name,
+        "target": str(target),
+        "latest_src_mtime": latest_src_mtime,
+        "current_file_mtime": current_file_mtime,
+    }
+
+    hash_value = hashlib.md5(
+        json.dumps(hash_key, sort_keys=True, indent=2).encode("utf-8")
+    ).hexdigest()
+
+    # Check if a valid hash exists in the build directory
+    hash_file = build_directory / "hash.md5"
+    if hash_file.exists():
+        with open(hash_file, "r") as f:
+            cached_hash = f.read().strip()
+        if cached_hash == hash_value:
+            # Check that all object files exist
+            object_files = []
+            all_exist = True
+            for src in source_paths:
+                obj_path = get_object_file_path(src)
+                if not obj_path.exists():
+                    all_exist = False
+                    break
+                object_files.append(obj_path)
+            if all_exist:
+                return object_files
+
+    # If we are here, cache is missing or outdated. Write the new hash and compile the paths
+    with open(hash_file, "w") as f:
+        f.write(hash_value)
+
     # ------------------------------------------------------------------------
     # 1) Common CUDA compile flags
     # ------------------------------------------------------------------------
@@ -82,17 +134,12 @@ def _compile_flashinfer_kernels(
         Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
     ] + CUTLASS_INCLUDE_DIRS
 
-    # Where object files will be placed
-    build_directory = FLASHINFER_JIT_DIR / name
-    build_directory.mkdir(parents=True, exist_ok=True)
-
     # ------------------------------------------------------------------------
     # 3) Function to compile a single source file
     # ------------------------------------------------------------------------
     def compile_single_source(src: Path) -> Path:
         # Derive the .o filename from the source filename
-        obj_name = src.stem + ".o"
-        obj_path = build_directory / obj_name
+        obj_path = get_object_file_path(src)
 
         # Construct the command
         cmd = (
@@ -202,7 +249,12 @@ def gen_flashinfer_prefill_module(
     )
     jit_args = {
         "backend": backend,
-        "uri": "batch_prefill_tvm",
+        "uri": f"batch_prefill_tvm_dtype_q_{dtype_q}_"
+        + f"dtype_kv_{dtype_kv}_"
+        + f"dtype_o_{dtype_o}_"
+        + f"qk_head_dim_{qk_head_dim}_"
+        + f"v_head_dim_{v_head_dim}_"
+        + f"enable_inline_rope_{enable_inline_rope}",
         "dtype_q": torch_dtype_q,
         "dtype_kv": torch_dtype_kv,
         "dtype_o": torch_dtype_o,
@@ -273,7 +325,11 @@ def gen_flashinfer_decode_module(
     torch_dtype_kv = getattr(torch, dtype_kv)
     torch_dtype_o = getattr(torch, dtype_o)
     jit_args = {
-        "uri": "batch_decode_tvm",
+        "uri": f"batch_decode_tvm_dtype_q_{dtype_q}_"
+        + f"dtype_kv_{dtype_kv}_"
+        + f"dtype_o_{dtype_o}_"
+        + f"qk_head_dim_{qk_head_dim}_"
+        + f"v_head_dim_{v_head_dim}",
         "dtype_q": torch_dtype_q,
         "dtype_kv": torch_dtype_kv,
         "dtype_o": torch_dtype_o,
@@ -343,7 +399,11 @@ def gen_flashinfer_mla_module(
     torch_dtype_kv = getattr(torch, dtype_kv)
     torch_dtype_o = getattr(torch, dtype_o)
     jit_args = {
-        "uri": "batch_mla_tvm",
+        "uri": f"batch_mla_tvm_dtype_q_{dtype_q}_"
+        + f"dtype_kv_{dtype_kv}_"
+        + f"dtype_o_{dtype_o}_"
+        + f"head_dim_ckv_{head_dim_ckv}_"
+        + f"head_dim_kpe_{head_dim_kpe}",
         "dtype_q": torch_dtype_q,
         "dtype_kv": torch_dtype_kv,
         "dtype_o": torch_dtype_o,
diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
index 3e17f64366..81acf5ee86 100644
--- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
+++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
@@ -169,7 +169,7 @@ def set_global_func(head_dim, dtype):
         mod = tvm.IRModule({"main": tir_func})
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
-        f = tvm.compile(mod["main"], target=target)
+        f = tvm.tir.build(mod["main"], target=target)
         builts.append(f.entry_func)
 
     (
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 41743efeee..ffd3452292 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -155,7 +155,7 @@ def set_global_func():
         mod = tvm.IRModule({"main": tir_func})
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
-        f = tvm.compile(mod["main"], target=target)
+        f = tvm.tir.build(mod["main"], target=target)
         builts.append(f.entry_func)
 
     (
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
index 84b50125ee..2f726064a7 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
@@ -168,7 +168,7 @@ def set_global_func(dtype):
         mod = tvm.IRModule({"main": tir_func})
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
-        f = tvm.compile(mod["main"], target=target)
+        f = tvm.tir.build(mod["main"], target=target)
         builts.append(f.entry_func)
 
     (

Reply via email to