This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1808a94696 [Fix] Update FlashInfer JIT header lookup (#18244)
1808a94696 is described below
commit 1808a9469628047f9c1b90fc39c26489e0fe1671
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Aug 27 17:28:41 2025 -0400
[Fix] Update FlashInfer JIT header lookup (#18244)
This PR fixes the tvm/dlpack/dmlc header lookup in the FlashInfer
kernel JIT compilation.
Prior to this fix, the JIT compilation assumed that the environment
variable `TVM_SOURCE_DIR` was always defined, which is not always
the case. This PR fixes that behavior and handles multiple cases,
including TVM source builds and pip-installed packages.
---
python/tvm/libinfo.py | 6 ++--
python/tvm/relax/backend/cuda/flashinfer.py | 46 +++++++++++++++++++++++++----
2 files changed, 45 insertions(+), 7 deletions(-)
diff --git a/python/tvm/libinfo.py b/python/tvm/libinfo.py
index f9f28b6853..69429179fc 100644
--- a/python/tvm/libinfo.py
+++ b/python/tvm/libinfo.py
@@ -195,7 +195,9 @@ def find_include_path(name=None, search_path=None,
optional=False):
include_path : list(string)
List of all found paths to header files.
"""
- if os.environ.get("TVM_HOME", None):
+ if os.environ.get("TVM_SOURCE_DIR", None):
+ source_dir = os.environ["TVM_SOURCE_DIR"]
+ elif os.environ.get("TVM_HOME", None):
source_dir = os.environ["TVM_HOME"]
else:
ffi_dir =
os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
@@ -204,7 +206,7 @@ def find_include_path(name=None, search_path=None,
optional=False):
if os.path.isdir(os.path.join(source_dir, "include")):
break
else:
- raise AssertionError("Cannot find the source directory given
ffi_dir: {ffi_dir}")
+ raise AssertionError(f"Cannot find the source directory given
ffi_dir: {ffi_dir}")
third_party_dir = os.path.join(source_dir, "3rdparty")
header_path = []
diff --git a/python/tvm/relax/backend/cuda/flashinfer.py
b/python/tvm/relax/backend/cuda/flashinfer.py
index 0f81675a8f..1fea39e9a2 100644
--- a/python/tvm/relax/backend/cuda/flashinfer.py
+++ b/python/tvm/relax/backend/cuda/flashinfer.py
@@ -24,6 +24,8 @@ from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List
+import tvm_ffi
+
import tvm
from tvm.target import Target
@@ -124,17 +126,51 @@ def _compile_flashinfer_kernels(
# ------------------------------------------------------------------------
# 2) Include paths
# ------------------------------------------------------------------------
- tvm_home = os.environ["TVM_SOURCE_DIR"]
include_paths = [
FLASHINFER_INCLUDE_DIR,
FLASHINFER_CSRC_DIR,
FLASHINFER_TVM_BINDING_DIR,
- Path(tvm_home).resolve() / "include",
- Path(tvm_home).resolve() / "ffi" / "include",
- Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include",
- Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
] + CUTLASS_INCLUDE_DIRS
+ if os.environ.get("TVM_SOURCE_DIR", None) or os.environ.get("TVM_HOME",
None):
+ # Respect TVM_SOURCE_DIR and TVM_HOME if they are set
+ tvm_home = (
+ os.environ["TVM_SOURCE_DIR"]
+ if os.environ.get("TVM_SOURCE_DIR", None)
+ else os.environ["TVM_HOME"]
+ )
+ include_paths += [
+ Path(tvm_home).resolve() / "include",
+ Path(tvm_home).resolve() / "ffi" / "include",
+ Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" /
"include",
+ Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
+ ]
+ else:
+ # If TVM_SOURCE_DIR and TVM_HOME are not set, use the default TVM
package path
+ tvm_package_path = Path(tvm.__file__).resolve().parent
+ if (tvm_package_path / "include").exists():
+ # The package is installed from pip.
+ tvm_ffi_package_path = Path(tvm_ffi.__file__).resolve().parent
+ include_paths += [
+ tvm_package_path / "include",
+ tvm_package_path / "3rdparty" / "dmlc-core" / "include",
+ tvm_ffi_package_path / "include",
+ ]
+ elif (tvm_package_path.parent.parent / "include").exists():
+ # The package is installed from source.
+ include_paths += [
+ tvm_package_path.parent.parent / "include",
+ tvm_package_path.parent.parent / "ffi" / "include",
+ tvm_package_path.parent.parent / "ffi" / "3rdparty" / "dlpack"
/ "include",
+ tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" /
"include",
+ ]
+ else:
+ # warning: TVM is not installed in the system.
+ print(
+ "Warning: Include path for TVM cannot be found. "
+ "FlashInfer kernel compilation may fail due to missing
headers."
+ )
+
# ------------------------------------------------------------------------
# 3) Function to compile a single source file
# ------------------------------------------------------------------------