This is an automated email from the ASF dual-hosted git repository.
MasterJH5574 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a8a94184b5 [REFACTOR][PYTHON] Consolidate backend autoload infra
(#19769)
a8a94184b5 is described below
commit a8a94184b54fd96b4da681da4318933c032a4648
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Jun 15 13:02:54 2026 -0400
[REFACTOR][PYTHON] Consolidate backend autoload infra (#19769)
## Summary
Backend loading is easier to maintain when native backend library
discovery, in-tree backend Python hook loading, and out-of-tree entry
point autoload are owned by the backend namespace. This PR consolidates
those paths under `tvm.backend._autoload_backends` and the per-backend
`register_backend()` hooks, replacing the previous top-level
`tvm._autoload_backends` helper and `tvm.base.load_backend_libs`.
- Move backend runtime DSO loading out of `tvm.base` and into the
`register_backend()` hooks of the backends that ship a runtime DSO
(cuda, hexagon, metal, opencl, rocm, vulkan);
`tvm.backend._autoload_backends` loads `tvm_runtime_extra`
- Replace `backend.load_all()` with the backend autoload helper
- Remove the previous top-level `_autoload_backends` module; `tvm` now
imports the helper from `tvm.backend._autoload_backends`
---
python/tvm/__init__.py | 13 +-
python/tvm/_autoload_backends.py | 50 ------
python/tvm/backend/__init__.py | 186 +--------------------
python/tvm/backend/_autoload_backends.py | 88 ++++++++++
python/tvm/backend/cuda/__init__.py | 32 ++++
python/tvm/backend/hexagon/__init__.py | 16 ++
python/tvm/backend/{__init__.py => loader.py} | 26 +--
python/tvm/backend/metal/__init__.py | 31 ++++
python/tvm/backend/opencl/__init__.py | 33 ++++
python/tvm/backend/rocm/__init__.py | 34 ++++
python/tvm/backend/vulkan/__init__.py | 44 +++++
python/tvm/base.py | 25 ---
.../meta_schedule/space_generator/__init__.py | 2 -
python/tvm/target/detect_target.py | 100 +++--------
python/tvm/target/x86.py | 39 -----
tests/python/tirx/test_op_namespace_cleanup.py | 4 +-
16 files changed, 309 insertions(+), 414 deletions(-)
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index b38e86a0ae..b6d928edce 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -75,8 +75,6 @@ from .support import rocm as _rocm, nvcc as _nvcc
# Relax contain modules that are only available in compiler package
# Do not import them if TVM is built with runtime only
if not _RUNTIME_ONLY:
- backend.load_all()
-
# tvm.relax — registers itself via tvm.script.register_dialect in its
__init__
from . import relax
@@ -118,10 +116,11 @@ def tvm_wrap_excepthook(exception_hook):
sys.excepthook = tvm_wrap_excepthook(sys.excepthook)
-# Autoload out-of-tree backends registered under the ``tvm.backends`` entry
-# point group. Runs last, after the core runtime and the tvm namespace are
-# fully initialized, so an extension can safely register into ``tvm.*`` and
-# load extra libraries. Imported lazily here to avoid any import-cycle risk.
-from ._autoload_backends import _autoload_backends
+# Autoload loads built-in and out-of-tree backends. Out-of-tree extensions opt
+# into being loaded automatically at ``import tvm`` time by declaring an entry
+# point in the ``tvm.backends`` group:
+# [project.entry-points."tvm.backends"] tvm_foo = "tvm_foo:_autoload".
+# Autoload can be disabled via ``TVM_DEVICE_BACKEND_AUTOLOAD=0``.
+from .backend._autoload_backends import _autoload_backends
_autoload_backends()
diff --git a/python/tvm/_autoload_backends.py b/python/tvm/_autoload_backends.py
deleted file mode 100644
index b45ac59d9d..0000000000
--- a/python/tvm/_autoload_backends.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Autoload out-of-tree backends registered via ``tvm.backends`` entry points.
-
-Out-of-tree extensions opt into being loaded automatically at ``import tvm``
-time by declaring an entry point in the ``tvm.backends`` group::
-
- [project.entry-points."tvm.backends"]
- tvm_foo = "tvm_foo:_autoload"
-
-Autoload can be disabled via ``TVM_DEVICE_BACKEND_AUTOLOAD=0``.
-"""
-
-import os
-import warnings
-from importlib.metadata import entry_points
-
-# Guard so autoload runs at most once per process, even if invoked again.
-_AUTO_LOAD_DONE = False
-
-
-def _autoload_backends():
- """Discover and invoke out-of-tree backends registered via entry points."""
- global _AUTO_LOAD_DONE
- if _AUTO_LOAD_DONE:
- return
- _AUTO_LOAD_DONE = True
-
- if os.environ.get("TVM_DEVICE_BACKEND_AUTOLOAD", "1") == "0":
- return
-
- for entry_pt in entry_points(group="tvm.backends"):
- try:
- entry_pt.load()()
- except Exception as e: # pylint: disable=broad-except
- warnings.warn(f"Failed to autoload tvm backend '{entry_pt.name}':
{e}")
diff --git a/python/tvm/backend/__init__.py b/python/tvm/backend/__init__.py
index 2243f5e034..009367cf22 100644
--- a/python/tvm/backend/__init__.py
+++ b/python/tvm/backend/__init__.py
@@ -18,192 +18,10 @@
from __future__ import annotations
-import importlib
-import importlib.util
-import sys
-import types
from pkgutil import extend_path
-from typing import Any
__path__ = extend_path(__path__, __name__) # type: ignore[name-defined]
-_BUILTIN_BACKENDS = (
- "cuda",
- "metal",
- "rocm",
- "trn",
- "opencl",
- "vulkan",
- "webgpu",
- "hexagon",
- "adreno",
-)
-_LOADED_BACKENDS: dict[str, Any] = {}
+from .loader import is_loaded, load
-
-class _AliasModule(types.ModuleType):
- """Module object that exposes a backend module under a public alias."""
-
- def __init__(self, fullname: str, module):
- super().__init__(fullname, getattr(module, "__doc__", None))
- self.__dict__["__tvm_backend_module__"] = module
- self.__dict__["__package__"] = fullname.rpartition(".")[0]
- if hasattr(module, "__all__"):
- self.__dict__["__all__"] = module.__all__
- if hasattr(module, "__path__"):
- self.__dict__["__path__"] = []
-
- def __getattr__(self, name: str):
- return getattr(self.__dict__["__tvm_backend_module__"], name)
-
- def __setattr__(self, name: str, value):
- setattr(self.__dict__["__tvm_backend_module__"], name, value)
-
- def __delattr__(self, name: str):
- delattr(self.__dict__["__tvm_backend_module__"], name)
-
- def __dir__(self):
- return sorted(set(super().__dir__()) |
set(dir(self.__dict__["__tvm_backend_module__"])))
-
-
-class _AliasLoader:
- """Loader that returns an already-resolved module for an alias spec."""
-
- def __init__(self, fullname: str, module):
- self._fullname = fullname
- self._module = module
-
- def create_module(self, spec):
- return _get_alias_module(self._fullname, self._module)
-
- def exec_module(self, module):
- _set_module_alias(self._fullname, self._module)
- return None
-
- def is_package(self, fullname):
- return hasattr(self._module, "__path__")
-
-
-def _redirect_tirx_backend_alias(fullname: str) -> str | None:
- prefix = "tvm.tirx."
- if not fullname.startswith(prefix):
- return None
- rest = fullname[len(prefix) :]
- backend_name, sep, tail = rest.partition(".")
- if not sep or backend_name not in _LOADED_BACKENDS:
- return None
- return f"tvm.backend.{backend_name}.{tail}"
-
-
-class _BackendAliasFinder:
- """Redirect ``tvm.tirx.<backend>.*`` imports to
``tvm.backend.<backend>.*``."""
-
- @classmethod
- def find_spec(cls, fullname, path, target=None):
- redirected = _redirect_tirx_backend_alias(fullname)
- if redirected is None:
- return None
- module = importlib.import_module(redirected)
- _set_module_alias(fullname, module)
- loader = _AliasLoader(fullname, module)
- spec = importlib.util.spec_from_loader(
- fullname, loader, is_package=hasattr(module, "__path__")
- )
- if spec is not None and hasattr(module, "__path__"):
- spec.submodule_search_locations = []
- return spec
-
-
-if not any(isinstance(finder, _BackendAliasFinder) for finder in
sys.meta_path):
- sys.meta_path.insert(0, _BackendAliasFinder())
-
-
-def _get_alias_module(alias: str, module):
- existing = sys.modules.get(alias)
- if (
- isinstance(existing, _AliasModule)
- and existing.__dict__.get("__tvm_backend_module__") is module
- ):
- return existing
- return _AliasModule(alias, module)
-
-
-def _set_module_alias(alias: str, module, *, direct: bool = False) -> None:
- alias_module = module if direct else _get_alias_module(alias, module)
- sys.modules[alias] = alias_module
- parent_name, _, child_name = alias.rpartition(".")
- parent = sys.modules.get(parent_name)
- if parent is not None:
- setattr(parent, child_name, alias_module)
-
-
-def _alias_loaded_backend_modules(name: str) -> None:
- backend_prefix = f"tvm.backend.{name}"
- public_prefix = f"tvm.tirx.{name}"
- for module_name, module in sorted(list(sys.modules.items())):
- if module_name == backend_prefix or
module_name.startswith(f"{backend_prefix}."):
- public_name = f"{public_prefix}{module_name[len(backend_prefix)
:]}"
- _set_module_alias(public_name, module, direct=module_name ==
backend_prefix)
-
-
-def _import_backend(name: str):
- module_name = f"tvm.backend.{name}"
- try:
- return importlib.import_module(module_name)
- except ModuleNotFoundError as err:
- if err.name == module_name:
- raise ImportError(
- f"Cannot load TVM backend {name!r}: expected Python package
{module_name!r}. "
- "Install the backend package or check the backend name."
- ) from err
- raise
-
-
-def load(name: str) -> None:
- """Load a backend's Python registration hooks.
-
- Loading is idempotent. A backend package must live at
``tvm.backend.<name>``
- and expose ``register_backend()``.
- """
-
- if name in _LOADED_BACKENDS:
- return None
-
- module = _import_backend(name)
- register_backend = getattr(module, "register_backend", None)
- if register_backend is None:
- raise AttributeError(f"Backend package 'tvm.backend.{name}' has no
register_backend()")
-
- import tvm.tirx as tirx # pylint: disable=import-outside-toplevel
-
- setattr(tirx, name, module)
- sys.modules[f"tvm.tirx.{name}"] = module
- _LOADED_BACKENDS[name] = module
- try:
- register_backend()
- _alias_loaded_backend_modules(name)
- except Exception:
- _LOADED_BACKENDS.pop(name, None)
- if getattr(tirx, name, None) is module:
- delattr(tirx, name)
- if sys.modules.get(f"tvm.tirx.{name}") is module:
- sys.modules.pop(f"tvm.tirx.{name}", None)
- raise
- return None
-
-
-def load_all() -> None:
- """Load all in-tree backend Python hooks."""
-
- for name in _BUILTIN_BACKENDS:
- load(name)
- return None
-
-
-def is_loaded(name: str) -> bool:
- """Return whether a backend has been loaded."""
-
- return name in _LOADED_BACKENDS
-
-
-__all__ = ["is_loaded", "load", "load_all"]
+__all__ = ["is_loaded", "load"]
diff --git a/python/tvm/backend/_autoload_backends.py
b/python/tvm/backend/_autoload_backends.py
new file mode 100644
index 0000000000..48967efb63
--- /dev/null
+++ b/python/tvm/backend/_autoload_backends.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Autoload built-in and out-of-tree backend libraries and registration
hooks."""
+
+from __future__ import annotations
+
+import os
+import warnings
+from importlib.metadata import entry_points
+from pathlib import Path
+
+from tvm_ffi.libinfo import load_lib_ctypes
+
+from tvm.backend.loader import load
+from tvm.base import _LOADED_LIBS
+
+_BUILTIN_BACKENDS = (
+ "cuda",
+ "metal",
+ "rocm",
+ "trn",
+ "opencl",
+ "vulkan",
+ "webgpu",
+ "hexagon",
+ "adreno",
+)
+
+_AUTO_LOAD_DONE = False
+
+
+def _load_builtin_backends() -> None:
+ """Load all in-tree backend Python hooks."""
+ for name in _BUILTIN_BACKENDS:
+ load(name)
+
+ runtime_dir = Path(_LOADED_LIBS["tvm_runtime"]._name).resolve().parent
+ try:
+ # Load libtvm_runtime_extra if available for registration side effects.
+ _LOADED_LIBS["tvm_runtime_extra"] = load_lib_ctypes(
+ package="tvm",
+ target_name="tvm_runtime_extra",
+ extra_lib_paths=[runtime_dir],
+ mode="RTLD_LOCAL",
+ )
+ except (OSError, FileNotFoundError, RuntimeError):
+ pass
+ return None
+
+
+def _autoload_backends() -> None:
+ """Load built-in backends and invoke out-of-tree backend entry points."""
+ global _AUTO_LOAD_DONE
+ if _AUTO_LOAD_DONE:
+ return
+ _AUTO_LOAD_DONE = True
+
+ if os.environ.get("TVM_DEVICE_BACKEND_AUTOLOAD", "1") == "0":
+ return
+
+ from tvm import _RUNTIME_ONLY # pylint: disable=import-outside-toplevel
+
+ if not _RUNTIME_ONLY:
+ _load_builtin_backends()
+
+ # Out-of-tree extensions opt into being loaded automatically at ``import
tvm`` time
+ # by declaring an entry point in the ``tvm.backends`` group:
+ # [project.entry-points."tvm.backends"] tvm_foo = "tvm_foo:_autoload".
+ # Autoload can be disabled via ``TVM_DEVICE_BACKEND_AUTOLOAD=0``.
+ for entry_pt in entry_points(group="tvm.backends"):
+ try:
+ entry_pt.load()()
+ except Exception as e: # pylint: disable=broad-except
+ warnings.warn(f"Failed to autoload tvm backend '{entry_pt.name}':
{e}")
diff --git a/python/tvm/backend/cuda/__init__.py
b/python/tvm/backend/cuda/__init__.py
index 34592d2de3..a875f79e04 100644
--- a/python/tvm/backend/cuda/__init__.py
+++ b/python/tvm/backend/cuda/__init__.py
@@ -17,14 +17,46 @@
"""CUDA-owned TIRx modules."""
from importlib import import_module
+from pathlib import Path
+
+from tvm_ffi.libinfo import load_lib_ctypes
+
+from tvm.base import _LOADED_LIBS
_LAZY_SUBMODULES = {"lang", "op", "operator", "script", "target_tags"}
+def _detect_target_from_device(dev):
+ from tvm.target import Target # pylint: disable=import-outside-toplevel
+
+ return Target(
+ {
+ "kind": "cuda",
+ "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+ "max_threads_per_block": dev.max_threads_per_block,
+ "thread_warp_size": dev.warp_size,
+ "arch": "sm_" + dev.compute_version.replace(".", ""),
+ }
+ )
+
+
def register_backend():
"""Register CUDA-owned Python semantics."""
+ from tvm.target.detect_target import register_device_target_detector
from tvm.tirx.script.builder import ir as builder_ir # pylint:
disable=import-outside-toplevel
+ runtime_dir = Path(_LOADED_LIBS["tvm_runtime"]._name).resolve().parent
+ try:
+ # Runtime sidecars only need registration side effects; libtvm_runtime
is global.
+ _LOADED_LIBS["tvm_runtime_cuda"] = load_lib_ctypes(
+ package="tvm",
+ target_name="tvm_runtime_cuda",
+ extra_lib_paths=[runtime_dir],
+ mode="RTLD_LOCAL",
+ )
+ except (OSError, FileNotFoundError, RuntimeError):
+ pass
+ register_device_target_detector("cuda", _detect_target_from_device)
for name, namespace in script_namespaces().items():
builder_ir.register_script_namespace(name, namespace)
diff --git a/python/tvm/backend/hexagon/__init__.py
b/python/tvm/backend/hexagon/__init__.py
index 3852e36ccc..5d8aa202e5 100644
--- a/python/tvm/backend/hexagon/__init__.py
+++ b/python/tvm/backend/hexagon/__init__.py
@@ -17,12 +17,28 @@
"""Hexagon-owned backend hooks."""
from importlib import import_module
+from pathlib import Path
+
+from tvm_ffi.libinfo import load_lib_ctypes
+
+from tvm.base import _LOADED_LIBS
_LAZY_SUBMODULES = {"target_tags"}
def register_backend():
"""Register Hexagon-owned Python semantics."""
+ runtime_dir = Path(_LOADED_LIBS["tvm_runtime"]._name).resolve().parent
+ try:
+ # Runtime sidecars only need registration side effects; libtvm_runtime
is global.
+ _LOADED_LIBS["tvm_runtime_hexagon"] = load_lib_ctypes(
+ package="tvm",
+ target_name="tvm_runtime_hexagon",
+ extra_lib_paths=[runtime_dir],
+ mode="RTLD_LOCAL",
+ )
+ except (OSError, FileNotFoundError, RuntimeError):
+ pass
import_module(f"{__name__}.target_tags")
diff --git a/python/tvm/backend/__init__.py b/python/tvm/backend/loader.py
similarity index 92%
copy from python/tvm/backend/__init__.py
copy to python/tvm/backend/loader.py
index 2243f5e034..f2ffce7b60 100644
--- a/python/tvm/backend/__init__.py
+++ b/python/tvm/backend/loader.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Backend-owned Python modules and load hooks."""
+"""Backend loading and public alias support."""
from __future__ import annotations
@@ -22,22 +22,8 @@ import importlib
import importlib.util
import sys
import types
-from pkgutil import extend_path
from typing import Any
-__path__ = extend_path(__path__, __name__) # type: ignore[name-defined]
-
-_BUILTIN_BACKENDS = (
- "cuda",
- "metal",
- "rocm",
- "trn",
- "opencl",
- "vulkan",
- "webgpu",
- "hexagon",
- "adreno",
-)
_LOADED_BACKENDS: dict[str, Any] = {}
@@ -192,18 +178,10 @@ def load(name: str) -> None:
return None
-def load_all() -> None:
- """Load all in-tree backend Python hooks."""
-
- for name in _BUILTIN_BACKENDS:
- load(name)
- return None
-
-
def is_loaded(name: str) -> bool:
"""Return whether a backend has been loaded."""
return name in _LOADED_BACKENDS
-__all__ = ["is_loaded", "load", "load_all"]
+__all__ = ["is_loaded", "load"]
diff --git a/python/tvm/backend/metal/__init__.py
b/python/tvm/backend/metal/__init__.py
index d42806433f..54f99fbd8e 100644
--- a/python/tvm/backend/metal/__init__.py
+++ b/python/tvm/backend/metal/__init__.py
@@ -17,14 +17,45 @@
"""Metal-owned TIRx modules."""
from importlib import import_module
+from pathlib import Path
+
+from tvm_ffi.libinfo import load_lib_ctypes
+
+from tvm.base import _LOADED_LIBS
_LAZY_SUBMODULES = {"op", "script", "target_tags"}
+def _detect_target_from_device(dev):
+ from tvm.target import Target # pylint: disable=import-outside-toplevel
+
+ return Target(
+ {
+ "kind": "metal",
+ "max_shared_memory_per_block": 32768,
+ "max_threads_per_block": dev.max_threads_per_block,
+ "thread_warp_size": dev.warp_size,
+ }
+ )
+
+
def register_backend():
"""Register Metal-owned Python semantics."""
+ from tvm.target.detect_target import register_device_target_detector
from tvm.tirx.script.builder import ir as builder_ir # pylint:
disable=import-outside-toplevel
+ runtime_dir = Path(_LOADED_LIBS["tvm_runtime"]._name).resolve().parent
+ try:
+ # Runtime sidecars only need registration side effects; libtvm_runtime
is global.
+ _LOADED_LIBS["tvm_runtime_metal"] = load_lib_ctypes(
+ package="tvm",
+ target_name="tvm_runtime_metal",
+ extra_lib_paths=[runtime_dir],
+ mode="RTLD_LOCAL",
+ )
+ except (OSError, FileNotFoundError, RuntimeError):
+ pass
+ register_device_target_detector("metal", _detect_target_from_device)
for name, namespace in script_namespaces().items():
builder_ir.register_script_namespace(name, namespace)
import_module(f"{__name__}.target_tags")
diff --git a/python/tvm/backend/opencl/__init__.py
b/python/tvm/backend/opencl/__init__.py
index a80696e590..599d284cbd 100644
--- a/python/tvm/backend/opencl/__init__.py
+++ b/python/tvm/backend/opencl/__init__.py
@@ -16,9 +16,42 @@
# under the License.
"""OpenCL-owned backend hooks."""
+from pathlib import Path
+
+from tvm_ffi.libinfo import load_lib_ctypes
+
+from tvm.base import _LOADED_LIBS
+
+
+def _detect_target_from_device(dev):
+ from tvm.target import Target # pylint: disable=import-outside-toplevel
+
+ return Target(
+ {
+ "kind": "opencl",
+ "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+ "max_threads_per_block": dev.max_threads_per_block,
+ "thread_warp_size": dev.warp_size,
+ }
+ )
+
def register_backend():
"""Register OpenCL-owned Python semantics."""
+ from tvm.target.detect_target import register_device_target_detector
+
+ runtime_dir = Path(_LOADED_LIBS["tvm_runtime"]._name).resolve().parent
+ try:
+ # Runtime sidecars only need registration side effects; libtvm_runtime
is global.
+ _LOADED_LIBS["tvm_runtime_opencl"] = load_lib_ctypes(
+ package="tvm",
+ target_name="tvm_runtime_opencl",
+ extra_lib_paths=[runtime_dir],
+ mode="RTLD_LOCAL",
+ )
+ except (OSError, FileNotFoundError, RuntimeError):
+ pass
+ register_device_target_detector("opencl", _detect_target_from_device)
return None
diff --git a/python/tvm/backend/rocm/__init__.py
b/python/tvm/backend/rocm/__init__.py
index d7574e974a..05b8e5de09 100644
--- a/python/tvm/backend/rocm/__init__.py
+++ b/python/tvm/backend/rocm/__init__.py
@@ -16,9 +16,43 @@
# under the License.
"""ROCm-owned TIRx modules."""
+from pathlib import Path
+
+from tvm_ffi.libinfo import load_lib_ctypes
+
+from tvm.base import _LOADED_LIBS
+
+
+def _detect_target_from_device(dev):
+ from tvm.target import Target # pylint: disable=import-outside-toplevel
+
+ return Target(
+ {
+ "kind": "rocm",
+ "mtriple": "amdgcn-amd-amdhsa-hcc",
+ "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+ "max_threads_per_block": dev.max_threads_per_block,
+ "thread_warp_size": dev.warp_size,
+ }
+ )
+
def register_backend():
"""Register ROCm-owned Python semantics."""
+ from tvm.target.detect_target import register_device_target_detector
+
+ runtime_dir = Path(_LOADED_LIBS["tvm_runtime"]._name).resolve().parent
+ try:
+ # Runtime sidecars only need registration side effects; libtvm_runtime
is global.
+ _LOADED_LIBS["tvm_runtime_rocm"] = load_lib_ctypes(
+ package="tvm",
+ target_name="tvm_runtime_rocm",
+ extra_lib_paths=[runtime_dir],
+ mode="RTLD_LOCAL",
+ )
+ except (OSError, FileNotFoundError, RuntimeError):
+ pass
+ register_device_target_detector("rocm", _detect_target_from_device)
return None
diff --git a/python/tvm/backend/vulkan/__init__.py
b/python/tvm/backend/vulkan/__init__.py
index 343875aa8a..4977f7b433 100644
--- a/python/tvm/backend/vulkan/__init__.py
+++ b/python/tvm/backend/vulkan/__init__.py
@@ -16,9 +16,53 @@
# under the License.
"""Vulkan-owned backend hooks."""
+from pathlib import Path
+
+from tvm_ffi.libinfo import load_lib_ctypes
+
+from tvm.base import _LOADED_LIBS
+
+
+def _detect_target_from_device(dev):
+ from tvm import get_global_func # pylint: disable=import-outside-toplevel
+ from tvm.target import Target # pylint: disable=import-outside-toplevel
+
+ f_get_target_property =
get_global_func("device_api.vulkan.get_target_property")
+ return Target(
+ {
+ "kind": "vulkan",
+ "max_threads_per_block": dev.max_threads_per_block,
+ "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+ "thread_warp_size": dev.warp_size,
+ "supports_float16": f_get_target_property(dev, "supports_float16"),
+ "supports_int8": f_get_target_property(dev, "supports_int8"),
+ "supports_int16": f_get_target_property(dev, "supports_int16"),
+ "supports_int64": f_get_target_property(dev, "supports_int64"),
+ "supports_8bit_buffer": f_get_target_property(dev,
"supports_8bit_buffer"),
+ "supports_16bit_buffer": f_get_target_property(dev,
"supports_16bit_buffer"),
+ "supports_storage_buffer_storage_class": f_get_target_property(
+ dev, "supports_storage_buffer_storage_class"
+ ),
+ }
+ )
+
def register_backend():
"""Register Vulkan-owned Python semantics."""
+ from tvm.target.detect_target import register_device_target_detector
+
+ runtime_dir = Path(_LOADED_LIBS["tvm_runtime"]._name).resolve().parent
+ try:
+ # Runtime sidecars only need registration side effects; libtvm_runtime
is global.
+ _LOADED_LIBS["tvm_runtime_vulkan"] = load_lib_ctypes(
+ package="tvm",
+ target_name="tvm_runtime_vulkan",
+ extra_lib_paths=[runtime_dir],
+ mode="RTLD_LOCAL",
+ )
+ except (OSError, FileNotFoundError, RuntimeError):
+ pass
+ register_device_target_detector("vulkan", _detect_target_from_device)
return None
diff --git a/python/tvm/base.py b/python/tvm/base.py
index 5c1e75566e..e850f9c214 100644
--- a/python/tvm/base.py
+++ b/python/tvm/base.py
@@ -19,7 +19,6 @@
"""Base library for TVM."""
import os
-from pathlib import Path
from tvm_ffi.libinfo import load_lib_ctypes
@@ -40,36 +39,12 @@ _RUNTIME_ONLY = os.environ.get("TVM_USE_RUNTIME_LIB") == "1"
_LOADED_LIBS = {}
-def load_backend_libs(runtime_lib_path: str) -> None:
- """Load each known backend runtime DSO into ``_LOADED_LIBS``; failures are
silent."""
- # Known per-backend runtime DSOs that, when present, are loaded with
- # RTLD_GLOBAL so their static initializers register the device backend.
- backend_runtime_libs = ["cuda", "vulkan", "opencl", "metal", "rocm",
"hexagon", "extra"]
- runtime_dir = Path(runtime_lib_path).resolve().parent
- for backend in backend_runtime_libs:
- target_name = f"tvm_runtime_{backend}"
- try:
- _LOADED_LIBS[target_name] = load_lib_ctypes(
- package="tvm",
- target_name=target_name,
- mode="RTLD_GLOBAL",
- extra_lib_paths=[runtime_dir],
- )
- except (OSError, FileNotFoundError, RuntimeError):
- pass
-
-
# runtime is loaded RTLD_GLOBAL to expose its symbols to subsequent loads;
# compiler is loaded RTLD_LOCAL.
_LOADED_LIBS["tvm_runtime"] = load_lib_ctypes(
"tvm", "tvm_runtime", "RTLD_GLOBAL",
extra_lib_paths=libinfo.package_lib_paths()
)
-# After libtvm_runtime.so is in the global symbol namespace, scan the same
-# directory for per-backend DSOs (libtvm_runtime_cuda.so, etc.) and load each
-# with RTLD_GLOBAL so their static initializers register device backends.
-load_backend_libs(_LOADED_LIBS["tvm_runtime"]._name)
-
if not _RUNTIME_ONLY:
try:
_LOADED_LIBS["tvm_compiler"] = load_lib_ctypes(
diff --git a/python/tvm/s_tir/meta_schedule/space_generator/__init__.py
b/python/tvm/s_tir/meta_schedule/space_generator/__init__.py
index 26ee689999..63f84d2948 100644
--- a/python/tvm/s_tir/meta_schedule/space_generator/__init__.py
+++ b/python/tvm/s_tir/meta_schedule/space_generator/__init__.py
@@ -25,5 +25,3 @@ from .post_order_apply import PostOrderApply
from .schedule_fn import ScheduleFn
from .space_generator import PySpaceGenerator, ScheduleFnType, SpaceGenerator,
create
from .space_generator_union import SpaceGeneratorUnion
-
-from ....target import x86
diff --git a/python/tvm/target/detect_target.py
b/python/tvm/target/detect_target.py
index 81accfed12..e7c434af61 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -16,79 +16,14 @@
# under the License.
"""Detect target."""
+from collections.abc import Callable
+
from tvm_ffi import get_global_func
from ..runtime import Device, device
from . import Target
-def _detect_metal(dev: Device) -> Target:
- return Target(
- {
- "kind": "metal",
- "max_shared_memory_per_block": 32768,
- "max_threads_per_block": dev.max_threads_per_block,
- "thread_warp_size": dev.warp_size,
- }
- )
-
-
-def _detect_cuda(dev: Device) -> Target:
- return Target(
- {
- "kind": "cuda",
- "max_shared_memory_per_block": dev.max_shared_memory_per_block,
- "max_threads_per_block": dev.max_threads_per_block,
- "thread_warp_size": dev.warp_size,
- "arch": "sm_" + dev.compute_version.replace(".", ""),
- }
- )
-
-
-def _detect_rocm(dev: Device) -> Target:
- return Target(
- {
- "kind": "rocm",
- "mtriple": "amdgcn-amd-amdhsa-hcc",
- "max_shared_memory_per_block": dev.max_shared_memory_per_block,
- "max_threads_per_block": dev.max_threads_per_block,
- "thread_warp_size": dev.warp_size,
- }
- )
-
-
-def _detect_opencl(dev: Device) -> Target:
- return Target(
- {
- "kind": "opencl",
- "max_shared_memory_per_block": dev.max_shared_memory_per_block,
- "max_threads_per_block": dev.max_threads_per_block,
- "thread_warp_size": dev.warp_size,
- }
- )
-
-
-def _detect_vulkan(dev: Device) -> Target:
- f_get_target_property =
get_global_func("device_api.vulkan.get_target_property")
- return Target(
- {
- "kind": "vulkan",
- "max_threads_per_block": dev.max_threads_per_block,
- "max_shared_memory_per_block": dev.max_shared_memory_per_block,
- "thread_warp_size": dev.warp_size,
- "supports_float16": f_get_target_property(dev, "supports_float16"),
- "supports_int8": f_get_target_property(dev, "supports_int8"),
- "supports_int16": f_get_target_property(dev, "supports_int16"),
- "supports_int64": f_get_target_property(dev, "supports_int64"),
- "supports_8bit_buffer": f_get_target_property(dev,
"supports_8bit_buffer"),
- "supports_16bit_buffer": f_get_target_property(dev,
"supports_16bit_buffer"),
- "supports_storage_buffer_storage_class": f_get_target_property(
- dev, "supports_storage_buffer_storage_class"
- ),
- }
- )
-
-
def _detect_cpu(dev: Device) -> Target: # pylint: disable=unused-argument
"""Detect the host CPU architecture."""
return Target(
@@ -106,6 +41,19 @@ def _detect_cpu(dev: Device) -> Target: # pylint:
disable=unused-argument
)
+SUPPORTED_DEVICE: dict[str, Callable[[Device], Target]] = {
+ "cpu": _detect_cpu,
+}
+
+# Backward-compatible alias for the previous private module-level map.
+SUPPORT_DEVICE = SUPPORTED_DEVICE
+
+
+def register_device_target_detector(device_type: str, detector:
Callable[[Device], Target]) -> None:
+ """Register target detection for a runtime device type."""
+ SUPPORTED_DEVICE[device_type] = detector
+
+
def detect_target_from_device(dev: str | Device) -> Target:
"""Detects Target associated with the given device. If the device does not
exist,
there will be an Error.
@@ -114,7 +62,7 @@ def detect_target_from_device(dev: str | Device) -> Target:
----------
dev : Union[str, Device]
The device to detect the target for.
- Supported device types: ["cuda", "metal", "rocm", "vulkan", "opencl"]
+ Supported device types are registered by backend hooks.
Returns
-------
@@ -124,24 +72,14 @@ def detect_target_from_device(dev: str | Device) -> Target:
if isinstance(dev, str):
dev = device(dev)
device_type = Device._DEVICE_TYPE_TO_NAME[dev.dlpack_device_type()]
- if device_type not in SUPPORT_DEVICE:
+ if device_type not in SUPPORTED_DEVICE:
raise ValueError(
f"Auto detection for device `{device_type}` is not supported. "
- f"Currently only supports: {SUPPORT_DEVICE.keys()}"
+ f"Currently only supports: {SUPPORTED_DEVICE.keys()}"
)
if not dev.exist:
raise ValueError(
f"Cannot detect device `{dev}`. Please make sure the device and
its driver "
"is installed properly, and TVM is compiled with the driver"
)
- return SUPPORT_DEVICE[device_type](dev)
-
-
-SUPPORT_DEVICE = {
- "cpu": _detect_cpu,
- "cuda": _detect_cuda,
- "metal": _detect_metal,
- "vulkan": _detect_vulkan,
- "rocm": _detect_rocm,
- "opencl": _detect_opencl,
-}
+ return SUPPORTED_DEVICE[device_type](dev)
diff --git a/python/tvm/target/x86.py b/python/tvm/target/x86.py
deleted file mode 100644
index 80399c749b..0000000000
--- a/python/tvm/target/x86.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Common x86 related utilities"""
-
-from tvm_ffi import register_global_func
-
-from .codegen import target_has_features
-
-
-@register_global_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
-def get_simd_32bit_lanes():
- """X86 SIMD optimal vector length lookup.
- Parameters
- ----------
- Returns
- -------
- vec_len : int
- The optimal vector length of CPU from the global context target.
- """
- vec_len = 4
- if target_has_features(["avx512bw", "avx512f"]):
- vec_len = 16
- elif target_has_features("avx2"):
- vec_len = 8
- return vec_len
diff --git a/tests/python/tirx/test_op_namespace_cleanup.py
b/tests/python/tirx/test_op_namespace_cleanup.py
index 40965d339b..29cf3ba972 100644
--- a/tests/python/tirx/test_op_namespace_cleanup.py
+++ b/tests/python/tirx/test_op_namespace_cleanup.py
@@ -256,7 +256,7 @@ def
test_backend_load_updates_tirx_alias_and_script_facades(monkeypatch):
monkeypatch.setitem(sys.modules, op_module.__name__, op_module)
sys.modules.pop(public_module_name, None)
sys.modules.pop(public_op_module_name, None)
- tvm.backend._LOADED_BACKENDS.pop(backend_name, None)
+ tvm.backend.loader._LOADED_BACKENDS.pop(backend_name, None)
if hasattr(tvm.tirx, backend_name):
delattr(tvm.tirx, backend_name)
@@ -280,7 +280,7 @@ def
test_backend_load_updates_tirx_alias_and_script_facades(monkeypatch):
assert getattr(parser, namespace_name) is namespace
assert getattr(T, namespace_name) is namespace
finally:
- tvm.backend._LOADED_BACKENDS.pop(backend_name, None)
+ tvm.backend.loader._LOADED_BACKENDS.pop(backend_name, None)
if hasattr(tvm.tirx, backend_name):
delattr(tvm.tirx, backend_name)
sys.modules.pop(public_module_name, None)