This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a7895a301a [Attention] Added caching for flashinfer binaries during JIT (#17730)
a7895a301a is described below
commit a7895a301a976517c1314f3a7f52d083180a7126
Author: Annanya <[email protected]>
AuthorDate: Thu Mar 13 09:25:34 2025 -0400
[Attention] Added caching for flashinfer binaries during JIT (#17730)
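This PR adds support for caching the FlashInfer binaries that are compiled
during the JIT flow in TVM. Previously compiled object files are reused when
the hash stored in the build directory matches a hash computed from the kernel
name, the target, and the source modification times.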
---
python/tvm/relax/backend/cuda/flashinfer.py | 80 +++++++++++++++++++---
.../test_runtime_builtin_kv_cache_transfer.py | 2 +-
..._builtin_paged_attention_kv_cache_flashinfer.py | 2 +-
...ltin_paged_attention_kv_cache_mla_flashinfer.py | 2 +-
4 files changed, 73 insertions(+), 13 deletions(-)
diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py
index 725fd105ad..8aa4817a30 100644
--- a/python/tvm/relax/backend/cuda/flashinfer.py
+++ b/python/tvm/relax/backend/cuda/flashinfer.py
@@ -21,6 +21,8 @@ import subprocess
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List
+import hashlib
+import json
import tvm
from tvm.target import Target
@@ -37,7 +39,57 @@ def _compile_flashinfer_kernels(
FLASHINFER_TVM_BINDING_DIR,
)
- # Todo(tvm-team): enable compilation cache
+ # ------------------------------------------------------------------------
+ # Caching Flow: create build_directory and compute cache hash.
+ # ------------------------------------------------------------------------
+ build_directory = FLASHINFER_JIT_DIR / name
+ build_directory.mkdir(parents=True, exist_ok=True)
+
+ def get_object_file_path(src: Path) -> Path:
+ obj_name = src.stem + ".o"
+ obj_path = build_directory / obj_name
+ return obj_path
+
+ # Compute latest modification time among all source files
+ latest_src_mtime = max(src.stat().st_mtime for src in source_paths)
+
+ # Get modification time for the current file (the one that contains this function)
+ current_file_mtime = Path(__file__).stat().st_mtime
+
+ # Build the hash key from metadata
+ hash_key = {
+ "name": name,
+ "target": str(target),
+ "latest_src_mtime": latest_src_mtime,
+ "current_file_mtime": current_file_mtime,
+ }
+
+ hash_value = hashlib.md5(
+ json.dumps(hash_key, sort_keys=True, indent=2).encode("utf-8")
+ ).hexdigest()
+
+ # Check if a valid hash exists in the build directory
+ hash_file = build_directory / "hash.md5"
+ if hash_file.exists():
+ with open(hash_file, "r") as f:
+ cached_hash = f.read().strip()
+ if cached_hash == hash_value:
+ # Check that all object files exist
+ object_files = []
+ all_exist = True
+ for src in source_paths:
+ obj_path = get_object_file_path(src)
+ if not obj_path.exists():
+ all_exist = False
+ break
+ object_files.append(obj_path)
+ if all_exist:
+ return object_files
+
+ # If we are here, cache is missing or outdated. Write the new hash and compile the paths
+ with open(hash_file, "w") as f:
+ f.write(hash_value)
+
# ------------------------------------------------------------------------
# 1) Common CUDA compile flags
# ------------------------------------------------------------------------
@@ -82,17 +134,12 @@ def _compile_flashinfer_kernels(
Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
] + CUTLASS_INCLUDE_DIRS
- # Where object files will be placed
- build_directory = FLASHINFER_JIT_DIR / name
- build_directory.mkdir(parents=True, exist_ok=True)
-
# ------------------------------------------------------------------------
# 3) Function to compile a single source file
# ------------------------------------------------------------------------
def compile_single_source(src: Path) -> Path:
# Derive the .o filename from the source filename
- obj_name = src.stem + ".o"
- obj_path = build_directory / obj_name
+ obj_path = get_object_file_path(src)
# Construct the command
cmd = (
@@ -202,7 +249,12 @@ def gen_flashinfer_prefill_module(
)
jit_args = {
"backend": backend,
- "uri": "batch_prefill_tvm",
+ "uri": f"batch_prefill_tvm_dtype_q_{dtype_q}_"
+ + f"dtype_kv_{dtype_kv}_"
+ + f"dtype_o_{dtype_o}_"
+ + f"qk_head_dim_{qk_head_dim}_"
+ + f"v_head_dim_{v_head_dim}_"
+ + f"enable_inline_rope_{enable_inline_rope}",
"dtype_q": torch_dtype_q,
"dtype_kv": torch_dtype_kv,
"dtype_o": torch_dtype_o,
@@ -273,7 +325,11 @@ def gen_flashinfer_decode_module(
torch_dtype_kv = getattr(torch, dtype_kv)
torch_dtype_o = getattr(torch, dtype_o)
jit_args = {
- "uri": "batch_decode_tvm",
+ "uri": f"batch_decode_tvm_dtype_q_{dtype_q}_"
+ + f"dtype_kv_{dtype_kv}_"
+ + f"dtype_o_{dtype_o}_"
+ + f"qk_head_dim_{qk_head_dim}_"
+ + f"v_head_dim_{v_head_dim}",
"dtype_q": torch_dtype_q,
"dtype_kv": torch_dtype_kv,
"dtype_o": torch_dtype_o,
@@ -343,7 +399,11 @@ def gen_flashinfer_mla_module(
torch_dtype_kv = getattr(torch, dtype_kv)
torch_dtype_o = getattr(torch, dtype_o)
jit_args = {
- "uri": "batch_mla_tvm",
+ "uri": f"batch_mla_tvm_dtype_q_{dtype_q}_"
+ + f"dtype_kv_{dtype_kv}_"
+ + f"dtype_o_{dtype_o}_"
+ + f"head_dim_ckv_{head_dim_ckv}_"
+ + f"head_dim_kpe_{head_dim_kpe}",
"dtype_q": torch_dtype_q,
"dtype_kv": torch_dtype_kv,
"dtype_o": torch_dtype_o,
diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
index 3e17f64366..81acf5ee86 100644
--- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
+++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
@@ -169,7 +169,7 @@ def set_global_func(head_dim, dtype):
mod = tvm.IRModule({"main": tir_func})
with target:
mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
- f = tvm.compile(mod["main"], target=target)
+ f = tvm.tir.build(mod["main"], target=target)
builts.append(f.entry_func)
(
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 41743efeee..ffd3452292 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -155,7 +155,7 @@ def set_global_func():
mod = tvm.IRModule({"main": tir_func})
with target:
mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
- f = tvm.compile(mod["main"], target=target)
+ f = tvm.tir.build(mod["main"], target=target)
builts.append(f.entry_func)
(
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
index 84b50125ee..2f726064a7 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
@@ -168,7 +168,7 @@ def set_global_func(dtype):
mod = tvm.IRModule({"main": tir_func})
with target:
mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
- f = tvm.compile(mod["main"], target=target)
+ f = tvm.tir.build(mod["main"], target=target)
builts.append(f.entry_func)
(