This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 283fd19683 [REFACTOR][TARGET] Further cleanup target python api (#18793)
283fd19683 is described below
commit 283fd196838710dfc11bc7605c26d1305279f06e
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Feb 17 21:39:27 2026 -0500
[REFACTOR][TARGET] Further cleanup target python api (#18793)
This PR cleans up the target Python API.
- Removes the indirection of attribute exposure (attributes are now read via target.attrs)
- Moves the tag registry to Python so it is easily configurable
- Removes legacy constructors in favor of tags
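For reference, a minimal sketch of the resulting API (tag names and attribute
keys are taken from the diff below; the snippet itself is illustrative, not
part of the patch):

    import tvm

    # From a registered tag
    target = tvm.target.Target("nvidia/nvidia-a100")
    arch = str(target.attrs.get("arch", ""))  # "sm_80"

    # From a config dict
    target = tvm.target.Target({"kind": "cuda", "arch": "sm_80"})

    # From a tag with per-field overrides
    target = tvm.target.Target({"tag": "qcom/hexagon-v68", "vtcm-capacity": 70000})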
---
include/tvm/target/tag.h | 7 +
python/tvm/contrib/hexagon/pytest_plugin.py | 2 +-
python/tvm/contrib/nvcc.py | 7 +-
python/tvm/contrib/thrust.py | 9 +-
python/tvm/relax/frontend/nn/llm/kv_cache.py | 2 +-
python/tvm/relax/frontend/torch/dynamo.py | 2 +-
python/tvm/s_tir/dlight/gpu/gemv.py | 4 +-
python/tvm/s_tir/dlight/gpu/low_batch_gemv.py | 4 +-
python/tvm/s_tir/dlight/gpu/matmul.py | 2 +-
.../schedule/cuda/layout_transform.py | 4 +-
python/tvm/target/__init__.py | 68 +-
python/tvm/target/tag.py | 67 +-
.../__init__.py} | 22 +-
python/tvm/target/tag_registry/adreno.py | 63 ++
python/tvm/target/tag_registry/arm_cpu.py | 116 ++++
.../thrust.py => target/tag_registry/aws_cpu.py} | 41 +-
python/tvm/target/tag_registry/cuda.py | 373 +++++++++++
python/tvm/target/tag_registry/hexagon.py | 55 ++
.../metal.py} | 31 +-
.../target/{tag.py => tag_registry/registry.py} | 20 +-
python/tvm/target/tag_registry/riscv_cpu.py | 72 +++
python/tvm/target/target.py | 702 +--------------------
python/tvm/tir/build.py | 6 +-
python/tvm/topi/gpu/scan.py | 4 +-
python/tvm/topi/gpu/sort.py | 6 +-
python/tvm/topi/math.py | 6 +-
src/target/tag.cc | 409 +-----------
src/target/target.cc | 18 +
tests/python/codegen/test_target_codegen_cuda.py | 4 +-
.../python/codegen/test_target_codegen_hexagon.py | 16 +-
.../python/contrib/test_hexagon/infrastructure.py | 22 +-
.../test_hexagon/test_async_dma_pipeline.py | 2 +-
.../contrib/test_hexagon/test_dma_builtin.py | 2 +-
.../test_relax_2d_buffer_allocation.py | 2 +-
.../contrib/test_hexagon/test_relax_integration.py | 4 +-
tests/python/relax/backend/clml/utils.py | 4 +-
tests/python/relax/texture/adreno_utils.py | 11 +-
...ule_postproc_disallow_async_strided_mem_copy.py | 2 +-
...est_meta_schedule_postproc_verify_vtcm_limit.py | 7 +-
.../test_meta_schedule_schedule_rule_mlt.py | 4 +-
tests/python/target/test_target_target.py | 169 ++---
tests/python/target/test_x86_features.py | 4 +-
42 files changed, 945 insertions(+), 1430 deletions(-)
diff --git a/include/tvm/target/tag.h b/include/tvm/target/tag.h
index d3fae83e06..b8de3fffba 100644
--- a/include/tvm/target/tag.h
+++ b/include/tvm/target/tag.h
@@ -81,6 +81,13 @@ class TargetTag : public ObjectRef {
  * \return A dictionary that maps tag name to the concrete target it corresponds to
*/
TVM_DLL static ffi::Map<ffi::String, Target> ListTags();
+ /*!
+ * \brief Retrieve the raw config dict for a target tag
+ * \param target_tag_name Name of the target tag
+ * \return The config dict if the tag exists, nullopt otherwise
+ */
+ TVM_DLL static ffi::Optional<ffi::Map<ffi::String, Any>> GetConfig(
+ const ffi::String& target_tag_name);
/*!
* \brief Add a tag into the registry
* \param name Name of the tag
diff --git a/python/tvm/contrib/hexagon/pytest_plugin.py b/python/tvm/contrib/hexagon/pytest_plugin.py
index 3770e8d781..4cc7b94d6c 100644
--- a/python/tvm/contrib/hexagon/pytest_plugin.py
+++ b/python/tvm/contrib/hexagon/pytest_plugin.py
@@ -317,7 +317,7 @@ aot_host_target = tvm.testing.parameter(HEXAGON_AOT_LLVM_TARGET)
@tvm.testing.fixture
def aot_target(aot_host_target):
if aot_host_target == "c":
- yield tvm.target.hexagon("v68")
+        yield tvm.target.Target({"kind": "hexagon", "mtriple": "hexagon", "mcpu": "hexagonv68"})
     elif isinstance(aot_host_target, dict) and aot_host_target.get("kind") == "llvm":
         yield aot_host_target
     elif isinstance(aot_host_target, str) and aot_host_target.startswith("llvm"):
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 7706f63973..e608bc2810 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -894,10 +894,11 @@ def get_target_compute_version(target=None):
# 1. input target object
# 2. Target.current()
target = target or Target.current()
- if target and target.arch:
- arch = target.arch.split("_")[1]
+ target_arch = str(target.attrs.get("arch", "")) if target else ""
+ if target_arch:
+ arch = target_arch.split("_")[1]
if len(arch) < 2:
- raise ValueError(f"The arch is not expected {target.arch}")
+ raise ValueError(f"The arch is not expected {target_arch}")
if arch[-1].isalpha():
# This is for arch like "sm_90a"
suffix = arch[-1]
diff --git a/python/tvm/contrib/thrust.py b/python/tvm/contrib/thrust.py
index 8cf7c59fad..b635fd1045 100644
--- a/python/tvm/contrib/thrust.py
+++ b/python/tvm/contrib/thrust.py
@@ -21,7 +21,10 @@ from tvm_ffi import get_global_func
def maybe_warn(target, func_name):
-    if "thrust" in target.libs and get_global_func(func_name, allow_missing=True) is None:
+ if (
+ "thrust" in list(target.attrs.get("libs", []))
+ and get_global_func(func_name, allow_missing=True) is None
+ ):
         logging.warning("thrust is requested but TVM is not built with thrust.")
@@ -29,7 +32,7 @@ def can_use_thrust(target, func_name):
maybe_warn(target, func_name)
return (
target.kind.name in ["cuda", "nvptx"]
- and "thrust" in target.libs
+ and "thrust" in list(target.attrs.get("libs", []))
and get_global_func(func_name, allow_missing=True)
)
@@ -38,6 +41,6 @@ def can_use_rocthrust(target, func_name):
maybe_warn(target, func_name)
return (
target.kind.name == "rocm"
- and "thrust" in target.libs
+ and "thrust" in list(target.attrs.get("libs", []))
and get_global_func(func_name, allow_missing=True)
)
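The same `libs` membership test now recurs in all three helpers above; a
minimal sketch of the shared pattern (the `target_has_lib` helper name is
hypothetical, not part of the patch):

    def target_has_lib(target, lib_name):
        # "libs" may be absent from attrs, so default to an empty list.
        return lib_name in list(target.attrs.get("libs", []))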
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 4d3ce0cbc4..1b57b4a8b6 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -48,7 +48,7 @@ def get_max_num_threads_per_block(target: Target) -> int:
     max(max_num_threads, max_threads_per_block); if latter does not exist, return max_num_threads.
     We add this method since some targets have both fields and `max_threads_per_block` is larger.
"""
- max_num_threads = target.max_num_threads
+ max_num_threads = int(target.attrs["max_num_threads"])
max_threads_per_block = target.attrs.get("max_threads_per_block", None)
if max_threads_per_block is None:
return max_num_threads
diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py
index 668543fc8a..8d62190769 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -103,7 +103,7 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = None):
if device.type == "cuda":
dev = tvm.cuda(device.index)
- target = tvm.target.cuda()
+ target = tvm.target.Target("cuda")
else:
dev = tvm.cpu(0)
target = tvm.target.Target(llvm_target())
diff --git a/python/tvm/s_tir/dlight/gpu/gemv.py b/python/tvm/s_tir/dlight/gpu/gemv.py
index 78f1fd67e2..ded494471c 100644
--- a/python/tvm/s_tir/dlight/gpu/gemv.py
+++ b/python/tvm/s_tir/dlight/gpu/gemv.py
@@ -158,7 +158,7 @@ class GEMV(GPUScheduleRule):
LOAD_V_SHARED = (
LOAD_V_SHARED
and isinstance(shared_mem_usage, tir.IntImm)
-            and shared_mem_usage.value <= target.max_shared_memory_per_block
+            and shared_mem_usage.value <= int(target.attrs["max_shared_memory_per_block"])
)
# vectorize load A
@@ -390,7 +390,7 @@ class GEMV(GPUScheduleRule):
UNROLL = 64
TS, TR = 1, 64
- while TS * TR > target.max_num_threads:
+ while TS * TR > int(target.attrs["max_num_threads"]):
if TS > 1:
TS //= 2
else:
diff --git a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
index 931005e7a1..5d29714298 100644
--- a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
+++ b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
@@ -355,7 +355,7 @@ class LowBatchGEMV(GPUScheduleRule):
LOAD_V_SHARED = (
LOAD_V_SHARED
and isinstance(shared_mem_usage, tir.IntImm)
-            and shared_mem_usage.value <= target.max_shared_memory_per_block
+            and shared_mem_usage.value <= int(target.attrs["max_shared_memory_per_block"])
)
# vectorize load A
@@ -571,7 +571,7 @@ class LowBatchGEMV(GPUScheduleRule):
if not isinstance(len_s, int):
TS, TR = 1, 64
- while TS * TR > target.max_num_threads:
+ while TS * TR > int(target.attrs["max_num_threads"]):
if TS > 1:
TS //= 2
else:
diff --git a/python/tvm/s_tir/dlight/gpu/matmul.py b/python/tvm/s_tir/dlight/gpu/matmul.py
index 223b216ed6..58ae0c73e1 100644
--- a/python/tvm/s_tir/dlight/gpu/matmul.py
+++ b/python/tvm/s_tir/dlight/gpu/matmul.py
@@ -1020,7 +1020,7 @@ class Matmul(GPUScheduleRule):
# tensorization rule will not be applied.
minimal_tensorize_threshold = 64
block_stmt = sch.get(main_block)
- if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70:
+    if target.kind.name == "cuda" and check_sm_version(str(target.attrs.get("arch", ""))) >= 70:
apply_tensorization: bool = True
# the batch dimension is not taken into consideration.
for item_var in block_stmt.iter_vars[1:]:
diff --git a/python/tvm/s_tir/meta_schedule/schedule/cuda/layout_transform.py b/python/tvm/s_tir/meta_schedule/schedule/cuda/layout_transform.py
index 03879a1bd6..93e0bec754 100644
--- a/python/tvm/s_tir/meta_schedule/schedule/cuda/layout_transform.py
+++ b/python/tvm/s_tir/meta_schedule/schedule/cuda/layout_transform.py
@@ -496,8 +496,8 @@ def get_max_tile_size() -> int:
"""
max_tile_size = 32
cur_target = tvm.target.Target.current()
- if cur_target is not None and hasattr(cur_target, "thread_warp_size"):
- max_tile_size = int(cur_target.thread_warp_size)
+ if cur_target is not None and "thread_warp_size" in cur_target.attrs:
+ max_tile_size = int(cur_target.attrs["thread_warp_size"])
return max_tile_size
diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py
index 9288eb3f97..6a0e2b4f84 100644
--- a/python/tvm/target/__init__.py
+++ b/python/tvm/target/__init__.py
@@ -14,66 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Target description and codgen module.
+"""Target description and codegen module.
-TVM's target string is in format ``<target_kind> [-option=value]...``.
+TVM uses JSON-based target configuration. Targets can be constructed via:
-Note
-----
-The list of options include:
+- A config dictionary: ``Target({"kind": "cuda", "arch": "sm_80"})``
+- A tag name: ``Target("nvidia/nvidia-a100")``
+- A tag with overrides: ``Target({"tag": "nvidia/nvidia-a100", "l2_cache_size_bytes": 12345})``
+- A kind name: ``Target("cuda")``
-- **-device=<device name>**
+Use ``target.attrs["key"]`` to access target attributes such as
+``"arch"``, ``"max_num_threads"``, ``"mcpu"``, ``"libs"``, etc.
- The device name.
-
-- **-mtriple=<target triple>**
-
- Specify the target triple, which is useful for cross
- compilation.
-
-- **-mcpu=<cpuname>**
-
- Specify a specific chip in the current architecture to
- generate code for. By default this is infered from the
- target triple and autodetected to the current architecture.
-
-- **-mattr=a1,+a2,-a3,...**
-
- Override or control specific attributes of the target,
- such as whether SIMD operations are enabled or not. The
- default set of attributes is set by the current CPU.
-
-- **-mabi=<abi>**
-
- Generate code for the specified ABI, for example "lp64d".
-
-- **-system-lib**
-
- Build TVM system library module. System lib is a global module that contains
- self registered functions in program startup. User can get the module using
- `tvm.runtime.system_lib`.
-  It is useful in environments where dynamic loading api like dlopen is banned.
-  The system lib will be available as long as the result code is linked by the program.
-
-We can use :py:func:`tvm.target.Target` to create a tvm.target.Target from the target string.
-We can also use other specific function in this module to create specific targets.
+Use :py:func:`tvm.target.list_tags` to list all available target tags,
+and :py:func:`tvm.target.register_tag` to register new tags.
"""
-from .target import Target, create, TargetKind
-from .target import (
- cuda,
- rocm,
- mali,
- intel_graphics,
- arm_cpu,
- rasp,
- bifrost,
- riscv_cpu,
- hexagon,
- stm32,
- adreno,
-)
+from .target import Target, TargetKind
from .virtual_device import VirtualDevice
-from .compilation_config import make_compilation_config
-from .tag import list_tags
+from .tag import list_tags, register_tag
from . import datatype
from . import codegen
+from . import tag_registry # noqa: F401 -- registers tags on import
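Per the new docstring, tag discovery and registration go through the public
helpers; an illustrative snippet (the "my-org/my-gpu" tag name is made up for
this example):

    from tvm.target import list_tags, register_tag

    register_tag("my-org/my-gpu", config={"kind": "cuda", "arch": "sm_86"})
    tags = list_tags()  # None when TVM is built in runtime-only mode
    if tags is not None:
        assert "my-org/my-gpu" in tags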
diff --git a/python/tvm/target/tag.py b/python/tvm/target/tag.py
index 0cb2b97e15..8e277c67be 100644
--- a/python/tvm/target/tag.py
+++ b/python/tvm/target/tag.py
@@ -14,68 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Target tags"""
-from typing import Any, Dict, Optional
-from . import _ffi_api
-from .target import Target
+"""Target tags -- re-exported from tag_registry.registry."""
+from .tag_registry.registry import list_tags, register_tag
-
-def list_tags() -> Optional[Dict[str, Target]]:
-    """Returns a dict of tags, which maps each tag name to its corresponding target.
-
- Returns
- -------
- tag_dict : Optional[Dict[str, Target]]
- The dict of tags mapping each tag name to its corresponding target.
- None if TVM is built in runtime-only mode.
- """
- if hasattr(_ffi_api, "TargetTagListTags"):
- return _ffi_api.TargetTagListTags()
- return None
-
-
-def register_tag(name: str, config: Dict[str, Any], override: bool = False) -> Optional[Target]:
- """Add a user-defined tag into the target tag registry.
-
- Parameters
- ----------
- name: str
- Name of the target, e.g. "nvidia/gtx1080ti"
- config : Dict[str, Any]
- The config dict used to create the target
- override: bool
- A boolean flag indicating if overriding existing tags are allowed.
-        If False and the tag has been registered already, an exception will be thrown.
-
- Returns
- -------
- target : Optional[Target]
- The target corresponding to the tag
- None if TVM is built in runtime-only mode.
-
- Examples
- --------
- .. code-block:: python
-
- register_tag("nvidia/gtx1080ti", config={
- "kind": "cuda",
- "arch": "sm_61",
- })
- """
- if hasattr(_ffi_api, "TargetTagAddTag"):
- return _ffi_api.TargetTagAddTag(name, config, override)
- return None
-
-
-# We purposely maintain all tags in the C++ side to support pure C++ use cases,
-# and the Python API is only used for fast prototyping.
-register_tag(
- "nvidia/gtx1080ti",
- config={
- "kind": "cuda",
- "arch": "sm_61",
- },
-)
-
-# To check the correctness of all registered tags, the call is made in library loading time.
-list_tags()
+__all__ = ["list_tags", "register_tag"]
diff --git a/python/tvm/target/compilation_config.py b/python/tvm/target/tag_registry/__init__.py
similarity index 61%
copy from python/tvm/target/compilation_config.py
copy to python/tvm/target/tag_registry/__init__.py
index 116f1dd8e9..34fcc1f7de 100644
--- a/python/tvm/target/compilation_config.py
+++ b/python/tvm/target/tag_registry/__init__.py
@@ -14,14 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Python bindings for creating CompilationConfigs."""
-import tvm
-from . import _ffi_api
+"""Python-side target tag registry.
+Importing this package registers all Python-defined target tags.
+"""
+from . import registry # noqa: F401
+from . import cuda # noqa: F401
+from . import arm_cpu # noqa: F401
+from . import riscv_cpu # noqa: F401
+from . import aws_cpu # noqa: F401
+from . import metal # noqa: F401
+from . import hexagon # noqa: F401
+from . import adreno # noqa: F401
-def make_compilation_config(ctxt, target, target_host=None):
-    """Returns a CompilationConfig appropriate for target and target_host, using the same
-    representation conventions as for the standard build interfaces. Intended only for unit
-    testing."""
-    raw_targets = tvm.target.Target.canon_multi_target_and_host(target, target_host)
- return _ffi_api.MakeCompilationConfig(ctxt, raw_targets)
+# Validate all tags at import time
+registry.list_tags()
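Each submodule imported above registers its tags as an import side effect. A
hypothetical new module would follow the same shape (the vendor name and
config values below are illustrative only):

    """Example vendor target tags."""
    from .registry import register_tag

    register_tag(
        "example/vendor-gpu",
        {
            "kind": "opencl",
            "device": "example",
            "keys": ["example", "opencl", "gpu"],
        },
    )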
diff --git a/python/tvm/target/tag_registry/adreno.py b/python/tvm/target/tag_registry/adreno.py
new file mode 100644
index 0000000000..92e065e68c
--- /dev/null
+++ b/python/tvm/target/tag_registry/adreno.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Qualcomm Adreno GPU target tags."""
+from .registry import register_tag
+
+register_tag(
+ "qcom/adreno-opencl",
+ {
+ "kind": "opencl",
+ "device": "adreno",
+ "keys": ["adreno", "opencl", "gpu"],
+ },
+)
+
+register_tag(
+ "qcom/adreno-opencl-clml",
+ {
+ "kind": "opencl",
+ "device": "adreno",
+ "keys": ["adreno", "opencl", "gpu", "clml"],
+ },
+)
+
+register_tag(
+ "qcom/adreno-opencl-texture",
+ {
+ "kind": "opencl",
+ "device": "adreno",
+ "keys": ["adreno", "opencl", "gpu", "texture"],
+ },
+)
+
+register_tag(
+ "qcom/adreno-vulkan",
+ {
+ "kind": "vulkan",
+ "device": "adreno",
+ "keys": ["adreno", "vulkan", "gpu"],
+ },
+)
+
+register_tag(
+ "qcom/adreno-vulkan-texture",
+ {
+ "kind": "vulkan",
+ "device": "adreno",
+ "keys": ["adreno", "vulkan", "gpu", "texture"],
+ },
+)
diff --git a/python/tvm/target/tag_registry/arm_cpu.py b/python/tvm/target/tag_registry/arm_cpu.py
new file mode 100644
index 0000000000..770a7e56e0
--- /dev/null
+++ b/python/tvm/target/tag_registry/arm_cpu.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""ARM CPU target tags."""
+from .registry import register_tag
+
+
+# ---------- Raspberry Pi (self-hosted, with host config) ----------
+register_tag(
+ "raspberry-pi/4b-aarch64",
+ {
+ "kind": "llvm",
+ "mtriple": "aarch64-linux-gnu",
+ "mcpu": "cortex-a72",
+ "mattr": ["+neon"],
+ "num-cores": 4,
+ "host": {
+ "kind": "llvm",
+ "mtriple": "aarch64-linux-gnu",
+ "mcpu": "cortex-a72",
+ "mattr": ["+neon"],
+ "num-cores": 4,
+ },
+ },
+)
+
+
+# ---------- ARM boards ----------
+def _register_arm_tag(name, config):
+ base = {"kind": "llvm", "keys": ["arm_cpu", "cpu"], "device": "arm_cpu"}
+ base.update(config)
+ register_tag(name, base)
+
+
+_register_arm_tag(
+    "arm/pixel2",
+    {"model": "snapdragon835", "mtriple": "arm64-linux-android", "mattr": ["+neon"]},
+)
+_register_arm_tag(
+    "arm/mate10",
+    {"model": "kirin970", "mtriple": "arm64-linux-android", "mattr": ["+neon"]},
+)
+_register_arm_tag(
+    "arm/rasp3b",
+    {"model": "bcm2837", "mtriple": "armv7l-linux-gnueabihf", "mattr": ["+neon"]},
+)
+_register_arm_tag(
+ "arm/rasp4b",
+ {
+ "model": "bcm2711",
+ "mtriple": "armv8l-linux-gnueabihf",
+ "mattr": ["+neon"],
+ "mcpu": "cortex-a72",
+ },
+)
+_register_arm_tag(
+ "arm/rasp4b64",
+ {
+ "model": "bcm2711",
+ "mtriple": "aarch64-linux-gnu",
+ "mattr": ["+neon"],
+ "mcpu": "cortex-a72",
+ },
+)
+_register_arm_tag(
+ "arm/rk3399",
+ {"model": "rk3399", "mtriple": "aarch64-linux-gnu", "mattr": ["+neon"]},
+)
+_register_arm_tag(
+ "arm/pynq",
+ {"model": "pynq", "mtriple": "armv7a-linux-eabi", "mattr": ["+neon"]},
+)
+_register_arm_tag(
+ "arm/ultra96",
+ {"model": "ultra96", "mtriple": "aarch64-linux-gnu", "mattr": ["+neon"]},
+)
+_register_arm_tag(
+ "arm/beagleai",
+ {
+ "model": "beagleai",
+ "mtriple": "armv7a-linux-gnueabihf",
+ "mattr": ["+neon", "+vfp4", "+thumb2"],
+ "mcpu": "cortex-a15",
+ },
+)
+_register_arm_tag(
+ "arm/stm32mp1",
+ {
+ "model": "stm32mp1",
+ "mtriple": "armv7a-linux-gnueabihf",
+ "mattr": ["+neon", "+vfp4", "+thumb2"],
+ "mcpu": "cortex-a7",
+ },
+)
+_register_arm_tag(
+ "arm/thunderx",
+ {
+ "model": "thunderx",
+ "mtriple": "aarch64-linux-gnu",
+ "mattr": ["+neon", "+crc", "+lse"],
+ "mcpu": "thunderxt88",
+ },
+)
diff --git a/python/tvm/contrib/thrust.py b/python/tvm/target/tag_registry/aws_cpu.py
similarity index 50%
copy from python/tvm/contrib/thrust.py
copy to python/tvm/target/tag_registry/aws_cpu.py
index 8cf7c59fad..3fb62a172e 100644
--- a/python/tvm/contrib/thrust.py
+++ b/python/tvm/target/tag_registry/aws_cpu.py
@@ -14,30 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Utilities for thrust"""
-import logging
+"""AWS CPU instance target tags."""
+from .registry import register_tag
-from tvm_ffi import get_global_func
-
-def maybe_warn(target, func_name):
-    if "thrust" in target.libs and get_global_func(func_name, allow_missing=True) is None:
-        logging.warning("thrust is requested but TVM is not built with thrust.")
-
-
-def can_use_thrust(target, func_name):
- maybe_warn(target, func_name)
- return (
- target.kind.name in ["cuda", "nvptx"]
- and "thrust" in target.libs
- and get_global_func(func_name, allow_missing=True)
+def _register_aws_c5(name, cores, arch):
+ register_tag(
+ name,
+ {
+ "kind": "llvm",
+ "keys": ["x86", "cpu"],
+ "mcpu": arch,
+ "num-cores": cores,
+ },
)
-def can_use_rocthrust(target, func_name):
- maybe_warn(target, func_name)
- return (
- target.kind.name == "rocm"
- and "thrust" in target.libs
- and get_global_func(func_name, allow_missing=True)
- )
+_register_aws_c5("aws/cpu/c5.large", 1, "skylake-avx512")
+_register_aws_c5("aws/cpu/c5.xlarge", 2, "skylake-avx512")
+_register_aws_c5("aws/cpu/c5.2xlarge", 4, "skylake-avx512")
+_register_aws_c5("aws/cpu/c5.4xlarge", 8, "skylake-avx512")
+_register_aws_c5("aws/cpu/c5.9xlarge", 18, "skylake-avx512")
+_register_aws_c5("aws/cpu/c5.12xlarge", 24, "cascadelake")
+_register_aws_c5("aws/cpu/c5.18xlarge", 36, "skylake-avx512")
+_register_aws_c5("aws/cpu/c5.24xlarge", 48, "cascadelake")
diff --git a/python/tvm/target/tag_registry/cuda.py b/python/tvm/target/tag_registry/cuda.py
new file mode 100644
index 0000000000..bfa36594fa
--- /dev/null
+++ b/python/tvm/target/tag_registry/cuda.py
@@ -0,0 +1,373 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""NVIDIA CUDA target tags."""
+from .registry import register_tag
+
+
+def _register_cuda_tag(name, arch, shared_mem=49152, regs=65536, **extra):
+ config = {
+ "kind": "cuda",
+ "keys": ["cuda", "gpu"],
+ "arch": arch,
+ "max_shared_memory_per_block": shared_mem,
+ "max_threads_per_block": 1024,
+ "thread_warp_size": 32,
+ "registers_per_block": regs,
+ }
+ config.update(extra)
+ register_tag(name, config)
+
+
+def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536):
+ register_tag(
+ name,
+ {
+ "kind": "cuda",
+ "arch": arch,
+ "max_shared_memory_per_block": 49152,
+ "max_threads_per_block": 1024,
+ "thread_warp_size": 32,
+ "registers_per_block": regs,
+ "host": {
+ "kind": "llvm",
+ "mtriple": "aarch64-linux-gnu",
+ "mcpu": mcpu,
+ "num-cores": num_cores,
+ },
+ },
+ )
+
+
+# =====================================================================
+# Data center / Tesla GPUs
+# =====================================================================
+_register_cuda_tag("nvidia/nvidia-a100", "sm_80", l2_cache_size_bytes=41943040)
+_register_cuda_tag("nvidia/nvidia-h100", "sm_90a",
l2_cache_size_bytes=52428800)
+_register_cuda_tag("nvidia/nvidia-b100", "sm_100a",
l2_cache_size_bytes=52428800)
+_register_cuda_tag("nvidia/nvidia-a40", "sm_86")
+_register_cuda_tag("nvidia/nvidia-a30", "sm_80")
+_register_cuda_tag("nvidia/nvidia-a10", "sm_86")
+_register_cuda_tag("nvidia/nvidia-a10g", "sm_86")
+_register_cuda_tag("nvidia/nvidia-a16", "sm_86")
+_register_cuda_tag("nvidia/nvidia-a2", "sm_86")
+_register_cuda_tag("nvidia/nvidia-t4", "sm_75")
+_register_cuda_tag("nvidia/nvidia-v100", "sm_70")
+_register_cuda_tag("nvidia/tesla-p100", "sm_60")
+_register_cuda_tag("nvidia/tesla-p40", "sm_61")
+_register_cuda_tag("nvidia/tesla-p4", "sm_61")
+_register_cuda_tag("nvidia/tesla-m60", "sm_52")
+_register_cuda_tag("nvidia/tesla-m40", "sm_52")
+_register_cuda_tag("nvidia/tesla-k80", "sm_37")
+_register_cuda_tag("nvidia/tesla-k40", "sm_35")
+_register_cuda_tag("nvidia/tesla-k20", "sm_35")
+_register_cuda_tag("nvidia/tesla-k10", "sm_30")
+_register_cuda_tag("nvidia/tesla-c2075", "sm_20", regs=32768)
+_register_cuda_tag("nvidia/tesla-c2050", "sm_20", regs=32768)
+_register_cuda_tag("nvidia/tesla-c2070", "sm_20", regs=32768)
+
+# =====================================================================
+# Quadro / RTX professional desktop GPUs
+# =====================================================================
+_register_cuda_tag("nvidia/rtx-a6000", "sm_86")
+_register_cuda_tag("nvidia/quadro-rtx-8000", "sm_75")
+_register_cuda_tag("nvidia/quadro-rtx-6000", "sm_75")
+_register_cuda_tag("nvidia/quadro-rtx-5000", "sm_75")
+_register_cuda_tag("nvidia/quadro-rtx-4000", "sm_75")
+_register_cuda_tag("nvidia/quadro-gv100", "sm_70")
+_register_cuda_tag("nvidia/quadro-gp100", "sm_60")
+_register_cuda_tag("nvidia/quadro-p6000", "sm_61")
+_register_cuda_tag("nvidia/quadro-p5000", "sm_61")
+_register_cuda_tag("nvidia/quadro-p4000", "sm_61")
+_register_cuda_tag("nvidia/quadro-p2200", "sm_61")
+_register_cuda_tag("nvidia/quadro-p2000", "sm_61")
+_register_cuda_tag("nvidia/quadro-p1000", "sm_61")
+_register_cuda_tag("nvidia/quadro-p620", "sm_61")
+_register_cuda_tag("nvidia/quadro-p600", "sm_61")
+_register_cuda_tag("nvidia/quadro-p400", "sm_61")
+_register_cuda_tag("nvidia/quadro-m6000-24gb", "sm_52")
+_register_cuda_tag("nvidia/quadro-m6000", "sm_52")
+_register_cuda_tag("nvidia/quadro-k6000", "sm_35")
+_register_cuda_tag("nvidia/quadro-m5000", "sm_52")
+_register_cuda_tag("nvidia/quadro-k5200", "sm_35")
+_register_cuda_tag("nvidia/quadro-k5000", "sm_30")
+_register_cuda_tag("nvidia/quadro-m4000", "sm_52")
+_register_cuda_tag("nvidia/quadro-k4200", "sm_30")
+_register_cuda_tag("nvidia/quadro-k4000", "sm_30")
+_register_cuda_tag("nvidia/quadro-m2000", "sm_52")
+_register_cuda_tag("nvidia/quadro-k2200", "sm_50")
+_register_cuda_tag("nvidia/quadro-k2000", "sm_30")
+_register_cuda_tag("nvidia/quadro-k2000d", "sm_30")
+_register_cuda_tag("nvidia/quadro-k1200", "sm_50")
+_register_cuda_tag("nvidia/quadro-k620", "sm_50")
+_register_cuda_tag("nvidia/quadro-k600", "sm_30")
+_register_cuda_tag("nvidia/quadro-k420", "sm_30")
+_register_cuda_tag("nvidia/quadro-410", "sm_30")
+_register_cuda_tag("nvidia/quadro-plex-7000", "sm_20", regs=32768)
+
+# =====================================================================
+# Quadro / RTX professional mobile GPUs (Turing)
+# =====================================================================
+_register_cuda_tag("nvidia/rtx-5000", "sm_75")
+_register_cuda_tag("nvidia/rtx-4000", "sm_75")
+_register_cuda_tag("nvidia/rtx-3000", "sm_75")
+_register_cuda_tag("nvidia/t2000", "sm_75")
+_register_cuda_tag("nvidia/t1000", "sm_75")
+_register_cuda_tag("nvidia/p620", "sm_61")
+_register_cuda_tag("nvidia/p520", "sm_61")
+
+# =====================================================================
+# Quadro professional mobile GPUs (Pascal / Maxwell)
+# =====================================================================
+_register_cuda_tag("nvidia/quadro-p5200", "sm_61")
+_register_cuda_tag("nvidia/quadro-p4200", "sm_61")
+_register_cuda_tag("nvidia/quadro-p3200", "sm_61")
+_register_cuda_tag("nvidia/quadro-p3000", "sm_61")
+_register_cuda_tag("nvidia/quadro-p500", "sm_61")
+_register_cuda_tag("nvidia/quadro-m5500m", "sm_52")
+_register_cuda_tag("nvidia/quadro-m2200", "sm_52")
+_register_cuda_tag("nvidia/quadro-m1200", "sm_50")
+_register_cuda_tag("nvidia/quadro-m620", "sm_52")
+_register_cuda_tag("nvidia/quadro-m520", "sm_50")
+
+# =====================================================================
+# Quadro professional mobile GPUs (Kepler / Maxwell)
+# =====================================================================
+_register_cuda_tag("nvidia/quadro-k6000m", "sm_30")
+_register_cuda_tag("nvidia/quadro-k5200m", "sm_30")
+_register_cuda_tag("nvidia/quadro-k5100m", "sm_30")
+_register_cuda_tag("nvidia/quadro-m5000m", "sm_50")
+_register_cuda_tag("nvidia/quadro-k500m", "sm_30")
+_register_cuda_tag("nvidia/quadro-k4200m", "sm_30")
+_register_cuda_tag("nvidia/quadro-k4100m", "sm_30")
+_register_cuda_tag("nvidia/quadro-m4000m", "sm_50")
+_register_cuda_tag("nvidia/quadro-k3100m", "sm_30")
+_register_cuda_tag("nvidia/quadro-m3000m", "sm_50")
+_register_cuda_tag("nvidia/quadro-k2200m", "sm_30")
+_register_cuda_tag("nvidia/quadro-k2100m", "sm_30")
+_register_cuda_tag("nvidia/quadro-m2000m", "sm_50")
+_register_cuda_tag("nvidia/quadro-k1100m", "sm_30")
+_register_cuda_tag("nvidia/quadro-m1000m", "sm_50")
+_register_cuda_tag("nvidia/quadro-k620m", "sm_50")
+_register_cuda_tag("nvidia/quadro-k610m", "sm_35")
+_register_cuda_tag("nvidia/quadro-m600m", "sm_50")
+_register_cuda_tag("nvidia/quadro-k510m", "sm_35")
+_register_cuda_tag("nvidia/quadro-m500m", "sm_50")
+
+# =====================================================================
+# NVS cards
+# =====================================================================
+_register_cuda_tag("nvidia/nvidia-nvs-810", "sm_50")
+_register_cuda_tag("nvidia/nvidia-nvs-510", "sm_30")
+_register_cuda_tag("nvidia/nvidia-nvs-315", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/nvidia-nvs-310", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/nvs-5400m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/nvs-5200m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/nvs-4200m", "sm_21", regs=32768)
+
+# =====================================================================
+# GeForce RTX 50-series desktop
+# =====================================================================
+_register_cuda_tag("nvidia/geforce-rtx-5060-ti", "sm_120",
l2_cache_size_bytes=33554432)
+
+# =====================================================================
+# GeForce RTX 40-series desktop
+# =====================================================================
+_register_cuda_tag("nvidia/geforce-rtx-4090", "sm_89",
l2_cache_size_bytes=75497472)
+
+# =====================================================================
+# GeForce RTX 30-series desktop
+# =====================================================================
+_register_cuda_tag("nvidia/geforce-rtx-3090-ti", "sm_86")
+_register_cuda_tag("nvidia/geforce-rtx-3090", "sm_86")
+_register_cuda_tag("nvidia/geforce-rtx-3080-ti", "sm_86")
+_register_cuda_tag("nvidia/geforce-rtx-3080", "sm_86")
+_register_cuda_tag("nvidia/geforce-rtx-3070-ti", "sm_86")
+_register_cuda_tag("nvidia/geforce-rtx-3070", "sm_86")
+_register_cuda_tag("nvidia/geforce-rtx-3060", "sm_86")
+
+# =====================================================================
+# GeForce RTX 20-series / TITAN (Turing)
+# =====================================================================
+_register_cuda_tag("nvidia/nvidia-titan-rtx", "sm_75")
+_register_cuda_tag("nvidia/geforce-rtx-2080-ti", "sm_75")
+_register_cuda_tag("nvidia/geforce-rtx-2080", "sm_75")
+_register_cuda_tag("nvidia/geforce-rtx-2070", "sm_75")
+_register_cuda_tag("nvidia/geforce-rtx-2060", "sm_75")
+
+# =====================================================================
+# GeForce TITAN / GTX 10-series (Pascal)
+# =====================================================================
+_register_cuda_tag("nvidia/nvidia-titan-v", "sm_70")
+_register_cuda_tag("nvidia/nvidia-titan-xp", "sm_61")
+_register_cuda_tag("nvidia/nvidia-titan-x", "sm_61")
+_register_cuda_tag("nvidia/geforce-gtx-1080-ti", "sm_61")
+_register_cuda_tag("nvidia/geforce-gtx-1080", "sm_61")
+_register_cuda_tag("nvidia/geforce-gtx-1070-ti", "sm_61")
+_register_cuda_tag("nvidia/geforce-gtx-1070", "sm_61")
+_register_cuda_tag("nvidia/geforce-gtx-1060", "sm_61")
+_register_cuda_tag("nvidia/geforce-gtx-1050", "sm_61")
+
+# =====================================================================
+# GeForce GTX 900/700 series desktop (Maxwell / Kepler)
+# =====================================================================
+_register_cuda_tag("nvidia/geforce-gtx-titan-x", "sm_52")
+_register_cuda_tag("nvidia/geforce-gtx-titan-z", "sm_35")
+_register_cuda_tag("nvidia/geforce-gtx-titan-black", "sm_35")
+_register_cuda_tag("nvidia/geforce-gtx-titan", "sm_35")
+_register_cuda_tag("nvidia/geforce-gtx-980-ti", "sm_52")
+_register_cuda_tag("nvidia/geforce-gtx-980", "sm_52")
+_register_cuda_tag("nvidia/geforce-gtx-970", "sm_52")
+_register_cuda_tag("nvidia/geforce-gtx-960", "sm_52")
+_register_cuda_tag("nvidia/geforce-gtx-950", "sm_52")
+_register_cuda_tag("nvidia/geforce-gtx-780-ti", "sm_35")
+_register_cuda_tag("nvidia/geforce-gtx-780", "sm_35")
+_register_cuda_tag("nvidia/geforce-gtx-770", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-760", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-750-ti", "sm_50")
+_register_cuda_tag("nvidia/geforce-gtx-750", "sm_50")
+_register_cuda_tag("nvidia/geforce-gtx-690", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-680", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-670", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-660-ti", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-660", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-650-ti-boost", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-650-ti", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-650", "sm_30")
+
+# =====================================================================
+# GeForce GTX 500/400 series desktop (Fermi)
+# =====================================================================
+_register_cuda_tag("nvidia/geforce-gtx-560-ti", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-550-ti", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-460", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gts-450", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-590", "sm_20", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-580", "sm_20", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-570", "sm_20", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-480", "sm_20", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-470", "sm_20", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-465", "sm_20", regs=32768)
+
+# =====================================================================
+# GeForce GT desktop (Kepler / Fermi)
+# =====================================================================
+_register_cuda_tag("nvidia/geforce-gt-740", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-730", "sm_35")
+_register_cuda_tag("nvidia/geforce-gt-730-ddr3,128bit", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-720", "sm_35")
+_register_cuda_tag("nvidia/geforce-gt-705", "sm_35")
+_register_cuda_tag("nvidia/geforce-gt-640-gddr5", "sm_35")
+_register_cuda_tag("nvidia/geforce-gt-640-gddr3", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-630", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-620", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-610", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-520", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-440", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-430", "sm_21", regs=32768)
+
+# =====================================================================
+# GeForce notebook GPUs (Maxwell / Kepler)
+# =====================================================================
+_register_cuda_tag("nvidia/geforce-gtx-980m", "sm_52")
+_register_cuda_tag("nvidia/geforce-gtx-970m", "sm_52")
+_register_cuda_tag("nvidia/geforce-gtx-965m", "sm_52")
+_register_cuda_tag("nvidia/geforce-gtx-960m", "sm_50")
+_register_cuda_tag("nvidia/geforce-gtx-950m", "sm_50")
+_register_cuda_tag("nvidia/geforce-940m", "sm_50")
+_register_cuda_tag("nvidia/geforce-930m", "sm_50")
+_register_cuda_tag("nvidia/geforce-920m", "sm_35")
+_register_cuda_tag("nvidia/geforce-910m", "sm_52")
+_register_cuda_tag("nvidia/geforce-gtx-880m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-870m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-860m-sm-30", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-860m-sm-50", "sm_50")
+_register_cuda_tag("nvidia/geforce-gtx-850m", "sm_50")
+_register_cuda_tag("nvidia/geforce-840m", "sm_50")
+_register_cuda_tag("nvidia/geforce-830m", "sm_50")
+_register_cuda_tag("nvidia/geforce-820m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-800m", "sm_21", regs=32768)
+
+# =====================================================================
+# GeForce notebook GPUs (Kepler / Fermi, older)
+# =====================================================================
+_register_cuda_tag("nvidia/geforce-gtx-780m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-770m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-765m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-760m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-680mx", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-680m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-675mx", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-675m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-670mx", "sm_30")
+_register_cuda_tag("nvidia/geforce-gtx-670m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-660m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-755m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-750m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-650m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-745m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-645m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-740m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-730m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-640m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-640m-le", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-735m", "sm_30")
+_register_cuda_tag("nvidia/geforce-gt-635m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-630m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-625m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-720m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-620m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-710m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-705m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-610m", "sm_21", regs=32768)
+
+# =====================================================================
+# GeForce notebook GPUs (Fermi, GTX 5xx/4xxM)
+# =====================================================================
+_register_cuda_tag("nvidia/geforce-gtx-580m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-570m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-560m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-555m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-550m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-540m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-525m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-520mx", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-520m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-485m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-470m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-460m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-445m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-435m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-420m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gt-415m", "sm_21", regs=32768)
+_register_cuda_tag("nvidia/geforce-gtx-480m", "sm_20", regs=32768)
+_register_cuda_tag("nvidia/geforce-410m", "sm_21", regs=32768)
+
+# =====================================================================
+# Jetson boards (simple, no host)
+# =====================================================================
+_register_cuda_tag("nvidia/jetson-nano", "sm_53", regs=32768)
+_register_cuda_tag("nvidia/jetson-tx2", "sm_62", regs=32768)
+_register_cuda_tag("nvidia/jetson-tx1", "sm_53", regs=32768)
+_register_cuda_tag("nvidia/tegra-x1", "sm_53", regs=32768)
+
+# =====================================================================
+# Jetson boards (with LLVM host)
+# =====================================================================
+_register_jetson_tag("nvidia/jetson-agx-xavier", "sm_72", "carmel", 8)
+_register_jetson_tag("nvidia/jetson-orin-nano", "sm_87", "carmel", 6)
+_register_jetson_tag("nvidia/jetson-agx-orin-32gb", "sm_87", "cortex-a78", 8)
+_register_jetson_tag("nvidia/jetson-agx-orin-64gb", "sm_87", "cortex-a78", 12)
diff --git a/python/tvm/target/tag_registry/hexagon.py b/python/tvm/target/tag_registry/hexagon.py
new file mode 100644
index 0000000000..2f25c47b7d
--- /dev/null
+++ b/python/tvm/target/tag_registry/hexagon.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Qualcomm Hexagon target tags."""
+from .registry import register_tag
+
+_ONE_MB = 2**20
+
+_HEXAGON_VERSIONS = {
+ "v65": {"vtcm": _ONE_MB // 4, "mattr": ["+hvxv65", "+hvx-length128b"]},
+ "v66": {"vtcm": _ONE_MB // 4, "mattr": ["+hvxv66", "+hvx-length128b"]},
+ "v68": {
+ "vtcm": 4 * _ONE_MB,
+ "mattr": ["+hvxv68", "+hvx-length128b", "+hvx-qfloat", "-hvx-ieee-fp"],
+ "llvm-options": ["-force-hvx-float"],
+ },
+ "v69": {
+ "vtcm": 8 * _ONE_MB,
+ "mattr": ["+hvxv69", "+hvx-length128b", "+hvx-qfloat", "-hvx-ieee-fp"],
+ },
+ "v73": {
+ "vtcm": 8 * _ONE_MB,
+ "mattr": ["+hvxv73", "+hvx-length128b", "+hvx-qfloat", "-hvx-ieee-fp"],
+ },
+ "v75": {
+ "vtcm": 8 * _ONE_MB,
+ "mattr": ["+hvxv75", "+hvx-length128b", "+hvx-qfloat", "-hvx-ieee-fp"],
+ },
+}
+
+for _ver, _info in _HEXAGON_VERSIONS.items():
+ _config = {
+ "kind": "hexagon",
+ "mtriple": "hexagon",
+ "mcpu": "hexagon" + _ver,
+ "mattr": _info["mattr"],
+ "num-cores": 4,
+ "vtcm-capacity": _info["vtcm"],
+ }
+ if "llvm-options" in _info:
+ _config["llvm-options"] = _info["llvm-options"]
+ register_tag("qcom/hexagon-" + _ver, _config)
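The Target docstring later in this patch shows exactly this tag being used; a
sketch of overriding a single field on top of the v68 config:

    from tvm.target import Target

    # Base config from the tag: vtcm-capacity is 4 MiB for v68.
    target = Target("qcom/hexagon-v68")
    assert int(target.attrs["vtcm-capacity"]) == 4 * 2**20

    # Override one field while keeping the rest of the tag config.
    target = Target({"tag": "qcom/hexagon-v68", "vtcm-capacity": 70000})
    assert int(target.attrs["vtcm-capacity"]) == 70000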
diff --git a/python/tvm/target/compilation_config.py b/python/tvm/target/tag_registry/metal.py
similarity index 53%
rename from python/tvm/target/compilation_config.py
rename to python/tvm/target/tag_registry/metal.py
index 116f1dd8e9..248a5c8001 100644
--- a/python/tvm/target/compilation_config.py
+++ b/python/tvm/target/tag_registry/metal.py
@@ -14,14 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Python bindings for creating CompilationConfigs."""
-import tvm
-from . import _ffi_api
+"""Apple Metal GPU target tags."""
+from .registry import register_tag
-def make_compilation_config(ctxt, target, target_host=None):
-    """Returns a CompilationConfig appropriate for target and target_host, using the same
-    representation conventions as for the standard build interfaces. Intended only for unit
-    testing."""
-    raw_targets = tvm.target.Target.canon_multi_target_and_host(target, target_host)
- return _ffi_api.MakeCompilationConfig(ctxt, raw_targets)
+def _register_metal_tag(name, max_threads, shared_mem, warp_size):
+ register_tag(
+ name,
+ {
+ "kind": "metal",
+ "max_threads_per_block": max_threads,
+ "max_shared_memory_per_block": shared_mem,
+ "thread_warp_size": warp_size,
+ "host": {
+ "kind": "llvm",
+ "mtriple": "arm64-apple-macos",
+ "mcpu": "apple-m4",
+ },
+ },
+ )
+
+
+_register_metal_tag("apple/m1-gpu", 1024, 32768, 32)
+_register_metal_tag("apple/m1-gpu-restricted", 256, 32768, 32)
+_register_metal_tag("apple/m2-gpu", 1024, 32768, 32)
diff --git a/python/tvm/target/tag.py b/python/tvm/target/tag_registry/registry.py
similarity index 84%
copy from python/tvm/target/tag.py
copy to python/tvm/target/tag_registry/registry.py
index 0cb2b97e15..99fd08f1c3 100644
--- a/python/tvm/target/tag.py
+++ b/python/tvm/target/tag_registry/registry.py
@@ -14,10 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Target tags"""
+"""Target tag registry functions."""
from typing import Any, Dict, Optional
-from . import _ffi_api
-from .target import Target
+from .. import _ffi_api
+from ..target import Target
def list_tags() -> Optional[Dict[str, Target]]:
@@ -65,17 +65,3 @@ def register_tag(name: str, config: Dict[str, Any], override: bool = False) -> O
if hasattr(_ffi_api, "TargetTagAddTag"):
return _ffi_api.TargetTagAddTag(name, config, override)
return None
-
-
-# We purposely maintain all tags in the C++ side to support pure C++ use cases,
-# and the Python API is only used for fast prototyping.
-register_tag(
- "nvidia/gtx1080ti",
- config={
- "kind": "cuda",
- "arch": "sm_61",
- },
-)
-
-# To check the correctness of all registered tags, the call is made in library loading time.
-list_tags()
diff --git a/python/tvm/target/tag_registry/riscv_cpu.py b/python/tvm/target/tag_registry/riscv_cpu.py
new file mode 100644
index 0000000000..a05096aa9a
--- /dev/null
+++ b/python/tvm/target/tag_registry/riscv_cpu.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""RISC-V CPU target tags (from old riscv_cpu() translation table)."""
+from .registry import register_tag
+
+
+def _register_riscv_tag(name, config):
+ base = {"kind": "llvm", "keys": ["arm_cpu", "cpu"], "device": "arm_cpu"}
+ base.update(config)
+ register_tag(name, base)
+
+
+_register_riscv_tag(
+ "riscv/sifive-e31",
+ {
+ "model": "sifive-e31",
+ "mtriple": "riscv32-unknown-linux-gnu",
+ "mcpu": "sifive-e31",
+ "mabi": "ilp32",
+ },
+)
+_register_riscv_tag(
+ "riscv/sifive-e76",
+ {
+ "model": "sifive-e76",
+ "mtriple": "riscv32-unknown-linux-gnu",
+ "mcpu": "sifive-e76",
+ "mabi": "ilp32",
+ },
+)
+_register_riscv_tag(
+ "riscv/sifive-u54",
+ {
+ "model": "sifive-u54",
+ "mtriple": "riscv64-unknown-linux-gnu",
+ "mcpu": "sifive-u54",
+ "mabi": "lp64d",
+ },
+)
+_register_riscv_tag(
+ "riscv/sifive-u74",
+ {
+ "model": "sifive-u74",
+ "mtriple": "riscv64-unknown-linux-gnu",
+ "mcpu": "sifive-u74",
+ "mabi": "lp64d",
+ },
+)
+_register_riscv_tag(
+ "riscv/licheepi3a",
+ {
+ "num-cores": 8,
+ "mtriple": "riscv64-unknown-linux-gnu",
+ "mcpu": "spacemit-x60",
+ "mfloat-abi": "hard",
+ "mabi": "lp64d",
+ },
+)
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 6cffed9e68..b64c2de961 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -15,15 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""Target data structure."""
-import re
-import warnings
from typing import Union
import tvm_ffi
from tvm.runtime import Device
from tvm.runtime import Object, convert
from tvm.runtime.container import String
-from tvm.ir.container import Map, Array
+from tvm.ir.container import Map
from . import _ffi_api
@@ -55,22 +53,35 @@ class TargetFeatures:
class Target(Object):
"""Target device information, use through TVM API.
- Note
- ----
- You can create target using the constructor or the following functions
+ Targets can be constructed from:
- - :py:func:`tvm.target.arm_cpu` create arm_cpu target
- - :py:func:`tvm.target.cuda` create CUDA target
- - :py:func:`tvm.target.rocm` create ROCM target
- - :py:func:`tvm.target.mali` create Mali target
- - :py:func:`tvm.target.intel_graphics` create Intel Graphics target
+ - A JSON config dictionary: ``Target({"kind": "cuda", "arch": "sm_80"})``
+ - A tag name: ``Target("nvidia/nvidia-a100")``
+    - A tag with overrides: ``Target({"tag": "nvidia/nvidia-a100", "l2_cache_size_bytes": 12345})``
+ - A kind name: ``Target("cuda")``
+
+ Use ``target.attrs["key"]`` to access target attributes.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ # From a tag
+ target = Target("nvidia/nvidia-a100")
+
+ # From a tag with attribute overrides
+ target = Target({"tag": "qcom/hexagon-v68", "vtcm-capacity": 70000})
+
+ # From a config dictionary
+ target = Target({"kind": "cuda", "arch": "sm_80"})
"""
def __init__(self, target, host=None):
"""Construct a TVM target object from
1) Raw target string
2) Target config dict
- 3) Target tag
+ 3) Target tag string
+ 4) Tag with overrides dict
Parameters
----------
@@ -80,7 +91,12 @@ class Target(Object):
             When using a dictionary or json string to configure target, the possible values are:
- kind : str (required)
+ tag : str (optional)
+ A registered tag name (e.g. ``"nvidia/nvidia-a100"``).
+ When ``tag`` is present, the tag's base config is loaded and
+ any additional fields in the dict override the base values.
+ The ``kind`` field is not needed when ``tag`` is specified.
+ kind : str (required unless tag is specified)
Which codegen path to use, for example 'llvm' or 'cuda'.
keys : List of str (optional)
A set of strategies that can be dispatched to. When using
@@ -179,91 +195,10 @@ class Target(Object):
"""
return _ffi_api.TargetCurrent(allow_none)
- @property
- def arch(self):
- """Returns the cuda arch from the target if it exists."""
- return str(self.attrs.get("arch", ""))
-
- @property
- def max_num_threads(self):
- """Returns the max_num_threads from the target if it exists."""
- return int(self.attrs["max_num_threads"])
-
- @property
- def max_block_size_x(self):
-        """Returns the max block size in x-dimension from the target if it exists."""
- return int(self.attrs["max_block_size_x"])
-
- @property
- def max_block_size_y(self):
-        """Returns the max block size in y-dimension from the target if it exists."""
- return int(self.attrs["max_block_size_y"])
-
- @property
- def thread_warp_size(self):
- """Returns the thread_warp_size from the target if it exists."""
- return int(self.attrs["thread_warp_size"])
-
- @property
- def max_shared_memory_per_block(self):
- return int(self.attrs["max_shared_memory_per_block"])
-
- @property
- def max_function_args(self):
- return int(self.attrs.get("max_function_args", 0))
-
- @property
- def vtcm_capacity(self):
- return int(self.attrs.get("vtcm-capacity", 0))
-
- @property
- def device_name(self):
- return str(self.attrs.get("device", ""))
-
- @property
- def model(self):
- """Returns model from the target if it exists."""
- return str(self.attrs.get("model", "unknown"))
-
- @property
- def mcpu(self):
- """Returns the mcpu from the target if it exists."""
- return str(self.attrs.get("mcpu", ""))
-
- @property
- def mattr(self):
- """Returns the mattr from the target if it exists."""
- return list(self.attrs.get("mattr", []))
-
- @property
- def supports_integer_dot_product(self):
- if self.attrs.get("supports_integer_dot_product", []):
- return bool(self.attrs["supports_integer_dot_product"])
- if self.kind.name == "cuda":
- sm_version = int(self.arch.split("_")[1])
- if sm_version >= 61:
- return True
- return False
-
- @property
- def libs(self):
- return list(self.attrs.get("libs", []))
-
- @property
- def supports_cooperative_matrix(self):
- if self.attrs.get("supports_cooperative_matrix", []):
- return bool(self.attrs["supports_cooperative_matrix"])
- else:
- return False
-
@property
def features(self):
return TargetFeatures(self)
- @property
- def l2_cache_size_bytes(self):
- return int(self.attrs.get("l2_cache_size_bytes", 0))
-
def get_kind_attr(self, attr_name):
"""Get additional attribute about the target kind.
@@ -288,115 +223,6 @@ class Target(Object):
"""Returns the list of available target names."""
return list(_ffi_api.ListTargetKinds())
- @staticmethod
- def canon_target(target):
-        """Given a single target-like object, returns the TVM Target object representing it.
- Can convert from:
- - None (to None).
- - An existing TVM Target object.
- - A string, eg "cuda"
- - A Python dictionary, eg {"kind": "cuda", "arch": "sm_80" }
- """
- if target is None:
- return None
- if isinstance(target, Target):
- return target
- return Target(target)
-
- @staticmethod
- def canon_target_and_host(target, target_host=None):
-        """Returns a TVM Target capturing target and target_host. Also returns the host in
-        canonical form. The given target can be in any form recognized by
-        Target.canon_target. If given, target_host can be in any form recognized by
-        Target.canon_target. If target_host is given it will be set as the 'host' in the
-        result Target object (and a warning given).
-
-        Note that this method does not support heterogeneous compilation targets.
- """
- target = Target.canon_target(target)
- if target is None:
-            assert target_host is None, "Target host is not empty when target is empty."
- return target, target_host
- if target.host is None and target_host is not None:
- warnings.warn(
-                "Please pass in tvm.target.Target(target, host=target_host) instead."
instead."
- )
- target_host = Target.canon_target(target_host)
- target = target.with_host(target_host)
- if target is not None:
- # In case the target already had a host, extract it here.
- target_host = target.host
- return target, target_host
-
- @staticmethod
- def canon_multi_target(multi_targets):
- """Given a single target-like object, or a collection-like object of
target-like objects,
-    returns a TVM Array of TVM Target objects representing them. Can
convert from:
- - None (to None).
- - A single target-like object in a form recognized by canon_target.
- - A Python list or TVM Array of target-like objects in a form
recognized by
- canon_target.
- - A Python dict or TVM Map from TVM IntImm objects representing device
types to
- a target-like object in a form recognized by canon_target. (This is a
legacy
- method to represent heterogeneous targets. The keys are ignored.)
- """
- if multi_targets is None:
- return None
- if isinstance(multi_targets, (dict, Map)) and "kind" not in
multi_targets:
- # Convert legacy heterogeneous map representation to ordinary list
of targets.
- return Target.canon_multi_target(list(multi_targets.values()))
- if isinstance(multi_targets, (list, Array)):
- # Multiple Target results.
- return convert([Target.canon_target(tgt) for tgt in multi_targets])
- # Single Target result.
- return convert([Target.canon_target(multi_targets)])
-
- @staticmethod
- def canon_multi_target_and_host(target, target_host=None):
- """Returns a TVM Array<Target> capturing target and target_host. The
given target can be in
- any form recognized by Target.canon_multi_target. If given,
target_host can be in
- any form recognized by Target.canon_target. If target_host is given it
will be set
- as the 'host' in each result Target object (and a warning given).
- """
- # Convert target to Array<Target>, but not yet accounting for any host.
- raw_targets = Target.canon_multi_target(target)
- assert raw_targets is not None and len(raw_targets) > 0
- # Convert host to Target, if given.
- if raw_targets[0].host is None and target_host is not None:
- warnings.warn(
- "target_host parameter is going to be deprecated. "
- "Please pass in tvm.target.Target(target, host=target_host)
instead."
- )
- # Make sure the (canonical) host is captured in all the
(canonical) targets.
- target_host = Target.canon_target(target_host)
- raw_targets = convert([tgt.with_host(target_host) for tgt in
raw_targets])
- return raw_targets
-
- @staticmethod
- def canon_target_map_and_host(target_map, target_host=None):
- """Returns target_map as a map from TVM Target's in canonical form to
IRModules. The keys
- of the input target_map can be in any form recognized by
Target.canon_target.
- Similarly, if given, target_host can be in any form recognized by
- Target.canon_target. The final target_map keys will capture the
target_host in
- canonical form. Also returns the target_host in canonical form."""
- new_target_map = {}
- canonical_target_host = None
- for tgt, mod in target_map.items():
- tgt = Target.canon_target(tgt)
- assert tgt is not None
- if canonical_target_host is None:
- if tgt.host is not None:
- canonical_target_host = tgt.host
- elif target_host is not None:
- # No deprecation warning in this case since host may have
been manufactured
- # behind the scenes in build_module.py build.
- canonical_target_host = Target.canon_target(target_host)
- if tgt.host is None and canonical_target_host is not None:
- tgt = tgt.with_host(canonical_target_host)
- new_target_map[tgt] = mod
- return new_target_map, canonical_target_host
-
@staticmethod
def target_or_current(target):
"""Returns target, or the current target in the environment if target
is None"""
@@ -405,473 +231,3 @@ class Target(Object):
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
return target
-
-
-def cuda(model="unknown", arch=None, **kwargs):
- """Returns a cuda target.
-
- Parameters
- ----------
- model: str
- The model of cuda device (e.g. 1080ti)
- arch: str
- The cuda architecture (e.g. sm_61)
- kwargs: dict
- Additional target attributes
- """
- config = {"kind": "cuda", "model": model}
- if arch:
- config["arch"] = arch
- else:
- warnings.warn("Try specifying cuda arch by adding arch='sm_xx' to your
target.")
- config.update(kwargs)
- return Target(config)
-
-
-def rocm(model="unknown", **kwargs):
- """Returns a ROCM target.
-
- Parameters
- ----------
- model: str
- The model of this device
- kwargs: dict
- Additional target attributes
- """
- config = {"kind": "rocm", "model": model}
- config.update(kwargs)
- return Target(config)
-
-
-def mali(model="unknown", **kwargs):
- """Returns a ARM Mali GPU target.
-
- Parameters
- ----------
- model: str
- The model of this device
- kwargs: dict
- Additional target attributes
- """
- config = {"kind": "opencl", "device": "mali", "model": model}
- config.update(kwargs)
- return Target(config)
-
-
-def intel_graphics(model="unknown", **kwargs):
- """Returns an Intel Graphics target.
-
- Parameters
- ----------
- model: str
- The model of this device
- kwargs: dict
- Additional target attributes
- """
- config = {
- "kind": "opencl",
- "device": "intel_graphics",
- "model": model,
- "thread_warp_size": 16,
- }
- config.update(kwargs)
- return Target(config)
-
-
-MICRO_SUPPORTED_MODELS = {
- "host": {},
- "atsamd51": {"mcpu": "cortex-m4"},
- "cxd5602gg": {"mcpu": "cortex-m4"},
- "esp32": {},
- "imxrt10xx": {"mcpu": "cortex-m7"},
- "mps2_an521": {"mcpu": "cortex-m33"},
- "mps3_an547": {"mcpu": "cortex-m55"},
- "nrf52840": {"mcpu": "cortex-m4+nodsp"},
- "nrf5340dk": {"mcpu": "cortex-m33"},
- "rp2040": {"mcpu": "cortex-m0"},
- "sam3x8e": {"mcpu": "cortex-m3"},
- "stm32f746xx": {"mcpu": "cortex-m7", "march": "armv7e-m"},
- "stm32h7xx": {"mcpu": "cortex-m7"},
- "stm32l4r5zi": {"mcpu": "cortex-m4"},
- "stm32u5xx": {"mcpu": "cortex-m33"},
- "zynq_mp_r5": {"mcpu": "cortex-r5"},
-}
-
-
-def _parse_cli_opts_to_dict(opts):
- """Convert a list of CLI-style options (e.g. ['-mcpu=cortex-a72']) to a
dict."""
- result = {}
- for opt in opts:
- opt = opt.lstrip("-")
- if "=" in opt:
- key, val = opt.split("=", 1)
- # Handle comma-separated values as lists
- if "," in val:
- val = val.split(",")
- result[key] = val
- else:
- result[opt] = True
- return result
-
-
-def arm_cpu(model="unknown", options=None):
- """Returns a ARM CPU target.
- This function will also download pre-tuned op parameters when there is
none.
-
- Parameters
- ----------
- model: str
- SoC name or phone name of the arm board.
- options : str or list of str
- Additional options
- """
- trans_table = {
- "pixel2": {"model": "snapdragon835", "mtriple": "arm64-linux-android",
"mattr": ["+neon"]},
- "mate10": {"model": "kirin970", "mtriple": "arm64-linux-android",
"mattr": ["+neon"]},
- "mate10pro": {"model": "kirin970", "mtriple": "arm64-linux-android",
"mattr": ["+neon"]},
- "p20": {"model": "kirin970", "mtriple": "arm64-linux-android",
"mattr": ["+neon"]},
- "p20pro": {"model": "kirin970", "mtriple": "arm64-linux-android",
"mattr": ["+neon"]},
- "rasp3b": {
- "model": "bcm2837",
- "mtriple": "armv7l-linux-gnueabihf",
- "mattr": ["+neon"],
- },
- "rasp4b": {
- "model": "bcm2711",
- "mtriple": "armv8l-linux-gnueabihf",
- "mattr": ["+neon"],
- "mcpu": "cortex-a72",
- },
- "rasp4b64": {
- "model": "bcm2711",
- "mtriple": "aarch64-linux-gnu",
- "mattr": ["+neon"],
- "mcpu": "cortex-a72",
- },
- "rk3399": {"model": "rk3399", "mtriple": "aarch64-linux-gnu", "mattr":
["+neon"]},
- "pynq": {"model": "pynq", "mtriple": "armv7a-linux-eabi", "mattr":
["+neon"]},
- "ultra96": {"model": "ultra96", "mtriple": "aarch64-linux-gnu",
"mattr": ["+neon"]},
- "beagleai": {
- "model": "beagleai",
- "mtriple": "armv7a-linux-gnueabihf",
- "mattr": ["+neon", "+vfp4", "+thumb2"],
- "mcpu": "cortex-a15",
- },
- "stm32mp1": {
- "model": "stm32mp1",
- "mtriple": "armv7a-linux-gnueabihf",
- "mattr": ["+neon", "+vfp4", "+thumb2"],
- "mcpu": "cortex-a7",
- },
- "thunderx": {
- "model": "thunderx",
- "mtriple": "aarch64-linux-gnu",
- "mattr": ["+neon", "+crc", "+lse"],
- "mcpu": "thunderxt88",
- },
- }
- pre_defined = trans_table.get(model, {"model": model})
-
- config = {
- "kind": "llvm",
- "keys": ["arm_cpu", "cpu"],
- "device": "arm_cpu",
- }
- config.update(pre_defined)
- if options:
- if isinstance(options, str):
- options = options.split()
- config.update(_parse_cli_opts_to_dict(options))
- return Target(config)
-
-
-def rasp(options=None):
- """Return a Raspberry 3b target.
-
- Parameters
- ----------
- options : str or list of str
- Additional options
- """
- warnings.warn(
- "tvm.target.rasp() is going to be deprecated. " 'Please use
tvm.target.arm_cpu("rasp3b")'
- )
- return arm_cpu("rasp3b", options)
-
-
-def bifrost(model="unknown", **kwargs):
- """Return an ARM Mali GPU target (Bifrost architecture).
-
- Parameters
- ----------
- model: str
- The model of this device
- kwargs: dict
- Additional target attributes
- """
- config = {"kind": "opencl", "device": "bifrost", "model": model}
- config.update(kwargs)
- return Target(config)
-
-
-def riscv_cpu(model="sifive-u54", options=None):
- """Returns a RISC-V CPU target.
- Default: sifive-u54 rv64gc
-
- Parameters
- ----------
- model: str
- CPU name.
- options : str or list of str
- Additional options
- """
- trans_table = {
- "sifive-e31": {
- "model": "sifive-e31",
- "mtriple": "riscv32-unknown-linux-gnu",
- "mcpu": "sifive-e31",
- "mabi": "ilp32",
- },
- "sifive-e76": {
- "model": "sifive-e76",
- "mtriple": "riscv32-unknown-linux-gnu",
- "mcpu": "sifive-e76",
- "mabi": "ilp32",
- },
- "sifive-u54": {
- "model": "sifive-u54",
- "mtriple": "riscv64-unknown-linux-gnu",
- "mcpu": "sifive-u54",
- "mabi": "lp64d",
- },
- "sifive-u74": {
- "model": "sifive-u74",
- "mtriple": "riscv64-unknown-linux-gnu",
- "mcpu": "sifive-u74",
- "mabi": "lp64d",
- },
- "licheepi3a": {
- "num-cores": 8,
- "mtriple": "riscv64-unknown-linux-gnu",
- "mcpu": "spacemit-x60",
- "mfloat-abi": "hard",
- "mabi": "lp64d",
- },
- }
- pre_defined = trans_table.get(model, {"model": model})
-
- config = {
- "kind": "llvm",
- "keys": ["arm_cpu", "cpu"],
- "device": "arm_cpu",
- }
- config.update(pre_defined)
- if options:
- if isinstance(options, str):
- options = options.split()
- config.update(_parse_cli_opts_to_dict(options))
- return Target(config)
-
-
-def hexagon(cpu_ver="v68", **kwargs):
- """Returns a Hexagon target.
-
- Parameters
- ----------
- cpu_ver : str (default: "v68")
-        CPU version used for code generation. Not every allowed CPU string
-        is valid; LLVM will throw an error for unsupported versions.
-
- Recognized keyword parameters
- -----------------------------
- hvx : int (default: 128)
- Size of HVX vector in bytes. Value of 0 disables HVX codegen.
- llvm_options : str or list of str (default: None)
- User defined compiler arguments.
- use_qfloat : bool (default: True for cpu_ver >= v68, False otherwise)
- Whether to use QFloat HVX instructions.
- use_ieee_fp : bool (default: False)
-        Whether to use IEEE HVX instructions.
- num_cores : int (default: 4)
- The number of HVX threads. This attribute is required by meta
scheduler.
- vtcm_capacity: int (default: 0)
- Hexagon VTCM capacity limitation. If the value is 0, the capacity is
treated as unbounded.
-
- Note: Floating point support in HVX requires LLVM 14+.
- """
-
- def get_arch_version(cpu_ver):
- m = re.match(r"v([0-9]+).*", cpu_ver)
- assert m
- return int(m.group(1))
-
- # Check for valid codegen cpu
- valid_hex = ["v65", "v66", "v67", "v67t", "v68", "v69", "v71", "v73",
"v75"]
- try:
- cpu_ver = cpu_ver[cpu_ver.index("v") :].lower()
- assert cpu_ver in valid_hex
- except:
- msg = "{} is not a valid Hexagon version\nvalid versions include {}"
- raise ValueError(msg.format(cpu_ver, valid_hex)) from None
-
- def get_vtcm_capacity(cpu_ver):
- one_mb = 2**20
- default_vtcm_sizes = {
- "v65": one_mb // 4,
- "v66": one_mb // 4,
- "v68": 4 * one_mb,
- "v69": 8 * one_mb,
- "v73": 8 * one_mb,
- "v75": 8 * one_mb,
- }
- return default_vtcm_sizes.get(cpu_ver, 0)
-
- arch_version = get_arch_version(cpu_ver)
- local_config = {
- "hvx": 128,
- "llvm_options": None,
- "use_qfloat": arch_version >= 68,
- "use_ieee_fp": False,
- "vtcm_capacity": get_vtcm_capacity(cpu_ver),
- }
- local_config.update(kwargs)
-
- # Warn about obsolete parameter names.
- if local_config.get("sim_args") or local_config.get("sim_options"):
- msg = (
- "Setting simulator options in target is deprecated, set
environment variable "
- "HEXAGON_SIM_ARGS instead"
- )
- warnings.warn(msg, stacklevel=2)
- if local_config.get("llvm_args"):
- msg = "The keyword parameter 'llvm_args' is deprecated, use
'llvm_options' instead"
- warnings.warn(msg, stacklevel=2)
- local_config.update({"llvm_options": local_config["llvm_args"]})
-
- # Build mattr list from config
- features_map = {
- "use_qfloat": "hvx-qfloat",
- "use_ieee_fp": "hvx-ieee-fp",
- }
- mattr = []
- if local_config["hvx"] > 0:
- valid_hvx = [0, 64, 128]
- if local_config["hvx"] not in valid_hvx:
- raise ValueError("Invalid hvx value, should be one of " +
str(valid_hvx))
- mattr += ["+hvx" + cpu_ver, "+hvx-length" + str(local_config["hvx"]) +
"b"]
- else:
- mattr += ["-hvx"]
- if arch_version >= 68:
- for f, feat_name in features_map.items():
- mattr.append(("-", "+")[local_config[f]] + feat_name)
-
- # Build llvm-options list
- llvm_options_list = []
- llvm_options = local_config["llvm_options"]
- if arch_version == 68:
- if not llvm_options:
- llvm_options = ""
- llvm_options += " -force-hvx-float"
- if llvm_options and len(llvm_options.strip()) > 0:
- llvm_options_list = [s.replace("=", "@") for s in llvm_options.split()]
-
- num_cores = local_config["num_cores"] if "num_cores" in kwargs else 4
-
- target_config = {
- "kind": "hexagon",
- "mtriple": "hexagon",
- "mcpu": "hexagon" + cpu_ver,
- "mattr": mattr,
- "num-cores": num_cores,
- "vtcm-capacity": local_config["vtcm_capacity"],
- }
- if llvm_options_list:
- target_config["llvm-options"] = llvm_options_list
-
- return Target(target_config)
-
-
-STM32_SUPPORTED_SERIES = {
- # High-Performance
- "stm32H7xx": {"mcpu": "cortex-m7", "march": "armv7e-m"},
- "stm32F7xx": {"mcpu": "cortex-m7"},
- "stm32F4xx": {"mcpu": "cortex-m4"},
- "stm32F2xx": {"mcpu": "cortex-m3"},
- # Mainstream
- "stm32G0xx": {"mcpu": "cortex-m0+"},
- "stm32F0xx": {"mcpu": "cortex-m0"},
- "stm32F1xx": {"mcpu": "cortex-m3"},
- "stm32G4xx": {"mcpu": "cortex-m4"},
- "stm32F3xx": {"mcpu": "cortex-m4"},
- # Low-power
- "stm32U5xx": {"mcpu": "cortex-m33"},
- "stm32L5xx": {"mcpu": "cortex-m33"},
- "stm32L4xx": {"mcpu": "cortex-m4"},
- "stm32L1xx": {"mcpu": "cortex-m3"},
- "stm32L0xx": {"mcpu": "cortex-m0+"},
-}
-
-
-def stm32(series="unknown", options=None):
- """Returns a STM32 target.
-
- Parameters
- ----------
- series: str
-        Series name of an STM32 board series, e.g. stm32H7xx or stm32F4xx
- options : str or list of str
- Additional options
- """
- if series not in STM32_SUPPORTED_SERIES:
- raise ValueError(f"Series {series} is not supported by
tvm.target.stm32.")
- config = {
- "kind": "c",
- "keys": ["arm_cpu", "cpu"],
- "device": "arm_cpu",
- }
- config.update(STM32_SUPPORTED_SERIES[series])
- if options:
- if isinstance(options, str):
- options = options.split()
- config.update(_parse_cli_opts_to_dict(options))
- return Target(config)
-
-
-def adreno(model="unknown", options=None, cfg=None, backend="opencl"):
- """Returns a Qualcomm GPU target.
- Parameters
- ----------
- model: str
- The model of this device
- options : str or list of str
- Additional options
- cfg : str
- Additional hints for target pipeline behavior
- backend : str
- Backend API, can be "opencl" or "vulkan"
- """
-
- if backend not in ["opencl", "vulkan"]:
- raise ValueError(f"Unsupported API: {backend}. Must be 'opencl' or
'vulkan'.")
-
- keys = ["adreno", backend, "gpu"]
- if cfg:
- keys.append(cfg)
-
- config = {
- "kind": backend,
- "device": "adreno",
- "keys": keys,
- "model": model,
- }
- if options:
- if isinstance(options, str):
- options = options.split()
- config.update(_parse_cli_opts_to_dict(options))
- return Target(config)
-
-
-def create(target):
- """Deprecated. Use the constructor of :py:mod:`tvm.target.Target`
directly."""
- warnings.warn("tvm.target.create() is being deprecated. Please use
tvm.target.Target() instead")
- return Target(target)
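
The helpers removed above all reduce to plain Target construction. A minimal
sketch of the replacements, using only tag names and config keys that appear
in the updated tests later in this diff:

    import tvm

    # tvm.target.cuda(arch="sm_80")   ->  explicit config dict
    cuda_target = tvm.target.Target({"kind": "cuda", "arch": "sm_80"})

    # tvm.target.hexagon("v68")       ->  registered target tag
    hex_target = tvm.target.Target("qcom/hexagon-v68")

    # tvm.target.adreno(cfg="clml")   ->  registered target tag
    adreno_target = tvm.target.Target("qcom/adreno-opencl-clml")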
diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py
index 5df2663fc2..6b9946fb02 100644
--- a/python/tvm/tir/build.py
+++ b/python/tvm/tir/build.py
@@ -187,7 +187,7 @@ def build(
if target_to_bind is None:
target_to_bind = "llvm"
assert target_to_bind is not None
- target_to_bind = Target.canon_target(target_to_bind)
+ target_to_bind = Target(target_to_bind)
# Step 1: Determine the target to search for tir pipeline
target = Target.current() if target is None else target
@@ -198,7 +198,7 @@ def build(
target = f_target
break
if target is not None:
- target = Target.canon_target(target)
+ target = Target(target)
# Step 2: Determine the host target
target_host = "llvm" if tvm.runtime.enabled("llvm") else "c"
@@ -209,7 +209,7 @@ def build(
tvm.device(target.kind.name, 0).dlpack_device_type() ==
tvm.cpu(0).dlpack_device_type()
):
target_host = target
- target_host = Target.canon_target(target_host)
+ target_host = Target(target_host)
target_to_bind = target_to_bind.with_host(target_host)
# Step 3: Bind the target to the input module
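
The three canon_target call sites above only ever saw non-None values, so the
direct constructor is equivalent; a sketch, assuming Target keeps accepting
strings, dicts, and existing Target objects as it does elsewhere in this diff:

    from tvm.target import Target

    t1 = Target("llvm")                                            # from a string
    t2 = Target({"kind": "llvm", "mtriple": "aarch64-linux-gnu"})  # from a dict
    t3 = Target(t1, host=t1)                                       # from an existing Target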
diff --git a/python/tvm/topi/gpu/scan.py b/python/tvm/topi/gpu/scan.py
index 643d69d83c..c1319c5ba0 100644
--- a/python/tvm/topi/gpu/scan.py
+++ b/python/tvm/topi/gpu/scan.py
@@ -87,7 +87,7 @@ def exclusive_scan_ir(data, output, reduction=None,
binop=tvm.tir.generic.add, i
if reduction is not None:
reduction = T.buffer_proxy(reduction)
- max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ max_threads =
int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"])
with T.If(scan_axis_size == 0):
with T.Then():
@@ -270,7 +270,7 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output,
binop=tvm.tir.generi
batch_size = cast(prod(data_buf.shape[:-1]), "int32")
scan_axis_size = cast(data_buf.shape[-1], "int32")
- max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ max_threads =
int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"])
with IRBuilder() as ib:
data = T.buffer_proxy(data_buf)
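
With the max_num_threads convenience property gone, GPU schedules read the raw
attribute instead; a minimal sketch (CUDA targets carry a default
max_num_threads):

    import tvm

    with tvm.target.Target("cuda"):
        target = tvm.target.Target.current(allow_none=False)
        max_threads = int(target.attrs["max_num_threads"])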
diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py
index 5a25007f12..41e0c9cda7 100644
--- a/python/tvm/topi/gpu/sort.py
+++ b/python/tvm/topi/gpu/sort.py
@@ -47,7 +47,7 @@ def _sort_init(shape, axis, keys_in, keys_out,
values_out=None, value_init_func=
axis_mul_after *= value
# Set up threading
- max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ max_threads =
int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"])
nthread_tx = max_threads
nthread_bx = ceil_div(shape[axis], max_threads)
nthread_by = axis_mul_before * axis_mul_after
@@ -259,7 +259,7 @@ def _sort_common(
## the merge into more blocks
target = tvm.target.Target.current(allow_none=False)
- max_threads = int(target.max_num_threads)
+ max_threads = int(target.attrs["max_num_threads"])
is_webgpu = "webgpu" in str(target)
target_dtype = "int32" if is_webgpu else "int64"
nthread_by = axis_mul_before * axis_mul_after
@@ -1199,7 +1199,7 @@ def searchsorted(sorted_sequence, values, right=False,
out_dtype="int64"):
values_ptr = T.buffer_proxy(values_buf)
indices_ptr = T.buffer_proxy(indices_buf)
- max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ max_threads =
int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"])
nthread_tx = max_threads
nthread_bx = ceil_div(num_search, nthread_tx)
tx = te.thread_axis("threadIdx.x")
diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py
index 61b39aad91..c47c3ca9e6 100644
--- a/python/tvm/topi/math.py
+++ b/python/tvm/topi/math.py
@@ -849,7 +849,11 @@ def ceil_log2(x):
return cast(res, x.dtype)
return res
- if "adreno" in target.device_name or target.kind.name in ["metal", "rocm",
"webgpu"]:
+ if "adreno" in str(target.attrs.get("device", "")) or target.kind.name in [
+ "metal",
+ "rocm",
+ "webgpu",
+ ]:
return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float32"))), x.dtype)
return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype)
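
The device_name property is likewise replaced by an attrs lookup with a
default, so targets that never set a "device" attribute fall back to "":

    import tvm

    target = tvm.target.Target({"kind": "opencl", "device": "adreno"})
    device = str(target.attrs.get("device", ""))   # "adreno"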
diff --git a/src/target/tag.cc b/src/target/tag.cc
index dfe179f7ac..d8ba94e6c6 100644
--- a/src/target/tag.cc
+++ b/src/target/tag.cc
@@ -57,6 +57,15 @@ ffi::Optional<Target> TargetTag::Get(const ffi::String&
target_tag_name) {
return Target(reg->tag_->config);
}
+ffi::Optional<ffi::Map<ffi::String, ffi::Any>> TargetTag::GetConfig(
+ const ffi::String& target_tag_name) {
+ const TargetTagRegEntry* reg =
TargetTagRegistry::Global()->Get(target_tag_name);
+ if (reg == nullptr) {
+ return std::nullopt;
+ }
+ return reg->tag_->config;
+}
+
ffi::Map<ffi::String, Target> TargetTag::ListTags() {
ffi::Map<ffi::String, Target> result;
for (const ffi::String& tag : TargetTagRegistry::Global()->ListAllNames()) {
@@ -73,404 +82,4 @@ Target TargetTag::AddTag(ffi::String name,
ffi::Map<ffi::String, ffi::Any> confi
return Target(config);
}
-/********** Register Target tags **********/
-
-#if TVM_LLVM_HAS_AARCH64_TARGET
-TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64")
- .set_config({{"kind", ffi::String("llvm")},
- {"mtriple", ffi::String("aarch64-linux-gnu")},
- {"mcpu", ffi::String("cortex-a72")},
- {"mattr", ffi::Array<ffi::String>{"+neon"}},
- {"num-cores", 4},
- {"host",
- ffi::Map<ffi::String, ffi::Any>{{"kind",
ffi::String("llvm")},
- {"mtriple",
ffi::String("aarch64-linux-gnu")},
- {"mcpu",
ffi::String("cortex-a72")},
- {"mattr",
ffi::Array<ffi::String>{"+neon"}},
- {"num-cores", 4}}}});
-
-#if TVM_LLVM_VERSION >= 110
-TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier")
- .set_config({{"kind", ffi::String("cuda")},
- {"arch", ffi::String("sm_72")},
- {"max_shared_memory_per_block", 49152},
- {"max_threads_per_block", 1024},
- {"thread_warp_size", 32},
- {"registers_per_block", 65536},
- {"host",
- ffi::Map<ffi::String, ffi::Any>{{"kind",
ffi::String("llvm")},
- {"mtriple",
ffi::String("aarch64-linux-gnu")},
- {"mcpu",
ffi::String("carmel")},
- {"num-cores", 8}}}});
-
-TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano")
- .set_config({{"kind", ffi::String("cuda")},
- {"arch", ffi::String("sm_87")},
- {"max_shared_memory_per_block", 49152},
- {"max_threads_per_block", 1024},
- {"thread_warp_size", 32},
- {"registers_per_block", 65536},
- {"host",
- ffi::Map<ffi::String, ffi::Any>{{"kind",
ffi::String("llvm")},
- {"mtriple",
ffi::String("aarch64-linux-gnu")},
- {"mcpu",
ffi::String("carmel")},
- {"num-cores", 6}}}});
-
-TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-32gb")
- .set_config({{"kind", ffi::String("cuda")},
- {"arch", ffi::String("sm_87")},
- {"max_shared_memory_per_block", 49152},
- {"max_threads_per_block", 1024},
- {"thread_warp_size", 32},
- {"registers_per_block", 65536},
- {"host",
- ffi::Map<ffi::String, ffi::Any>{{"kind",
ffi::String("llvm")},
- {"mtriple",
ffi::String("aarch64-linux-gnu")},
- {"mcpu",
ffi::String("cortex-a78")},
- {"num-cores", 8}}}});
-
-TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb")
- .set_config({{"kind", ffi::String("cuda")},
- {"arch", ffi::String("sm_87")},
- {"max_shared_memory_per_block", 49152},
- {"max_threads_per_block", 1024},
- {"thread_warp_size", 32},
- {"registers_per_block", 65536},
- {"host",
- ffi::Map<ffi::String, ffi::Any>{{"kind",
ffi::String("llvm")},
- {"mtriple",
ffi::String("aarch64-linux-gnu")},
- {"mcpu",
ffi::String("cortex-a78")},
- {"num-cores", 12}}}});
-#endif // TVM_LLVM_VERSION >= 110
-#endif // TVM_LLVM_HAS_AARCH64_TARGET
-
-#define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \
- TVM_REGISTER_TARGET_TAG(Name).set_config({ \
- {"kind", ffi::String("cuda")}, \
- {"keys", ffi::Array<ffi::String>{"cuda", "gpu"}}, \
- {"arch", ffi::String(Arch)}, \
- {"max_shared_memory_per_block", SharedMem}, \
- {"max_threads_per_block", 1024}, \
- {"thread_warp_size", 32}, \
- {"registers_per_block", RegPerBlock}, \
- })
-
-// Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus
-// Parameters see Table 15. Technical Specifications per Compute Capability
-// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
-// Check `Maximum y- or z-dimension of a grid of thread blocks` for max
threads per block
-// Check `Maximum amount of shared memory per thread block` for max shared
memory per block
-// Note that above 48 KB requires dynamic shared memory
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-k80", "sm_37", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-k40", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-k20", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2075", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536)
- .with_config("l2_cache_size_bytes", 41943040);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536)
- .with_config("l2_cache_size_bytes", 52428800);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-b100", "sm_100a", 49152, 65536)
- .with_config("l2_cache_size_bytes", 52428800);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10g", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a16", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a2", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-t4", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-v100", "sm_70", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-p100", "sm_60", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-p40", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-p4", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-m60", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-m40", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-k80", "sm_37", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-k40", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-k20", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/tesla-k10", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/rtx-a6000", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-rtx-8000", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-rtx-6000", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-rtx-5000", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-rtx-4000", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-gv100", "sm_70", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-gp100", "sm_60", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p6000", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p5000", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p4000", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p2200", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p2000", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p1000", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p620", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p600", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p400", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m6000-24gb", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m6000", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k6000", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m5000", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k5200", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k5000", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m4000", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k4200", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k4000", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m2000", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k2200", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k2000", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k2000d", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k1200", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k620", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k600", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k420", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-410", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-plex-7000", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/rtx-5000", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/rtx-4000", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/rtx-3000", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/t2000", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/t1000", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/p620", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/p520", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p5200", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p4200", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p3200", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p5000", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p4000", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p3000", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p2000", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p1000", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p600", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-p500", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m5500m", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m2200", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m1200", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m620", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m520", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k6000m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k5200m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k5100m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m5000m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k500m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k4200m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k4100m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m4000m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k3100m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m3000m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k2200m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k2100m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m2000m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k1100m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m1000m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k620m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k610m", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m600m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-k510m", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/quadro-m500m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-nvs-810", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-nvs-510", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-nvs-315", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-nvs-310", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536)
- .with_config("l2_cache_size_bytes", 75497472);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-5060-ti", "sm_120", 49152, 65536)
- .with_config("l2_cache_size_bytes", 33554432);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3070-ti", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3070", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3060", "sm_86", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-titan-rtx", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-2080-ti", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-2080", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-2070", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-2060", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-titan-v", "sm_70", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-titan-xp", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/nvidia-titan-x", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-1080-ti", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-1080", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-1070-ti", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-1070", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-1060", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-1050", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-titan-x", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-titan-z", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-titan-black", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-titan", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-980-ti", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-980", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-970", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-960", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-950", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-780-ti", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-780", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-770", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-760", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-750-ti", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-750", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-690", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-680", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-670", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-660-ti", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-660", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-650-ti-boost", "sm_30", 49152,
65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-650-ti", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-650", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-560-ti", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-550-ti", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-460", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gts-450", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-590", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-580", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-570", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-480", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-470", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-465", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-740", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-730", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-730-ddr3,128bit", "sm_21", 49152,
32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-720", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-705", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-640-gddr5", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-640-gddr3", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-630", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-620", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-610", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-520", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-440", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-430", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-2080", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-2070", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-2060", "sm_75", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-1080", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-1070", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-1060", "sm_61", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-980", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-980m", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-970m", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-965m", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-960m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-950m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-940m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-930m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-920m", "sm_35", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-910m", "sm_52", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-880m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-870m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-860m-sm-30", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-860m-sm-50", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-850m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-840m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-830m", "sm_50", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-820m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-800m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-780m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-770m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-765m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-760m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-680mx", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-680m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-675mx", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-675m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-670mx", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-670m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-660m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-755m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-750m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-650m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-745m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-645m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-740m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-730m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-640m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-640m-le", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-735m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-635m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-730m", "sm_30", 49152, 65536);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-630m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-625m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-720m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-620m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-710m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-705m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-610m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-580m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-570m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-560m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-555m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-550m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-540m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-525m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-520mx", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-520m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-485m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-470m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-460m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-445m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-435m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-420m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-415m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-480m", "sm_20", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-710m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/geforce-410m", "sm_21", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/jetson-nano", "sm_53", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx2", "sm_62", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx1", "sm_53", 49152, 32768);
-TVM_REGISTER_CUDA_TAG("nvidia/tegra-x1", "sm_53", 49152, 32768);
-
-#undef TVM_REGISTER_CUDA_TAG
-
-#define TVM_REGISTER_TAG_AWS_C5(Name, Cores, Arch)
\
- TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", ffi::String("llvm")},
\
- {"keys",
ffi::Array<ffi::String>{"x86", "cpu"}}, \
- {"mcpu", ffi::String(Arch)},
\
- {"num-cores", Cores}});
-
-TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.large", 1, "skylake-avx512");
-TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.xlarge", 2, "skylake-avx512");
-TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.2xlarge", 4, "skylake-avx512");
-TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.4xlarge", 8, "skylake-avx512");
-TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.9xlarge", 18, "skylake-avx512");
-TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.12xlarge", 24, "cascadelake");
-TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.18xlarge", 36, "skylake-avx512");
-TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.24xlarge", 48, "cascadelake");
-
-#undef TVM_REGISTER_TAG_AWS_C5
-
-#if TVM_LLVM_VERSION >= 190
-#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize)
\
- TVM_REGISTER_TARGET_TAG(Name).set_config(
\
- {{"kind", ffi::String("metal")},
\
- {"max_threads_per_block", ThreadsPerBlock},
\
- {"max_shared_memory_per_block", SharedMem},
\
- {"thread_warp_size", WarpSize},
\
- {"host", ffi::Map<ffi::String, ffi::Any>{{"kind", ffi::String("llvm")},
\
- {"mtriple",
ffi::String("arm64-apple-macos")}, \
- {"mcpu",
ffi::String("apple-m4")}}}});
-#else
-#define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize)
\
- TVM_REGISTER_TARGET_TAG(Name).set_config(
\
- {{"kind", ffi::String("metal")},
\
- {"max_threads_per_block", ThreadsPerBlock},
\
- {"max_shared_memory_per_block", SharedMem},
\
- {"thread_warp_size", WarpSize},
\
- {"host", ffi::Map<ffi::String, ffi::Any>{{"kind", ffi::String("llvm")},
\
- {"mtriple",
ffi::String("arm64-apple-macos")}, \
- {"mcpu",
ffi::String("apple-latest")}}}});
-#endif
-
-#if TVM_LLVM_HAS_AARCH64_TARGET
-TVM_REGISTER_METAL_GPU_TAG("apple/m1-gpu", 1024, 32768, 32);
-TVM_REGISTER_METAL_GPU_TAG("apple/m1-gpu-restricted", 256, 32768, 32);
-TVM_REGISTER_METAL_GPU_TAG("apple/m2-gpu", 1024, 32768, 32);
-#endif // TVM_LLVM_HAS_AARCH64_TARGET
-
-#undef TVM_REGISTER_METAL_TAG
-
} // namespace tvm
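
The several hundred TVM_REGISTER_TARGET_TAG entries deleted above move into
python/tvm/target/tag_registry/. Consumers should be unaffected; a sketch,
assuming list_tags() keeps returning the name -> Target mapping it had before
the refactor:

    import tvm

    tags = tvm.target.list_tags()
    a100 = tags["nvidia/nvidia-a100"]
    assert a100.attrs["arch"] == "sm_80"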
diff --git a/src/target/target.cc b/src/target/target.cc
index 89769ecaa6..277ae36bb6 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -291,6 +291,24 @@ ObjectPtr<TargetNode>
TargetInternal::FromConfig(ffi::Map<ffi::String, ffi::Any>
const ffi::String kFromDevice = "from_device";
ObjectPtr<TargetNode> target = ffi::make_object<TargetNode>();
+ // Step 0: If "tag" is present without "kind", look up the tag config and
merge overrides on top
+ if (!config.count(kKind) && config.count(kTag)) {
+ auto tag_name = config[kTag].try_cast<ffi::String>();
+ ICHECK(tag_name.has_value()) << "Expect type of field \"tag\" is String,
but get type: "
+ << config[kTag].GetTypeKey();
+ auto tag_config = TargetTag::GetConfig(tag_name.value());
+ ICHECK(tag_config.has_value()) << "Unknown target tag: " <<
tag_name.value();
+ // Start from the tag's base config, then apply user overrides
+ ffi::Map<ffi::String, ffi::Any> merged = tag_config.value();
+ for (const auto& kv : config) {
+ if (kv.first != kTag) {
+ merged.Set(kv.first, kv.second);
+ }
+ }
+ merged.Set(kTag, ffi::String(tag_name.value()));
+ config = std::move(merged);
+ }
+
// Step 1: Parse 'kind' (needed to look up the schema, but kept in config
for canonicalizer)
if (config.count(kKind)) {
if (auto kind = config[kKind].try_cast<ffi::String>()) {
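
The new Step 0 lets a config dict start from a registered tag and override
individual attributes, which test_target_tag_override below exercises:

    import tvm

    tgt = tvm.target.Target({"tag": "nvidia/nvidia-a100",
                             "l2_cache_size_bytes": 12345})
    assert tgt.kind.name == "cuda"                         # from the tag
    assert int(tgt.attrs["l2_cache_size_bytes"]) == 12345  # override wins
    assert tgt.tag == "nvidia/nvidia-a100"                 # tag name recorded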
diff --git a/tests/python/codegen/test_target_codegen_cuda.py
b/tests/python/codegen/test_target_codegen_cuda.py
index e689eabf15..6298349877 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -488,7 +488,7 @@ def test_cuda_const_float_to_half():
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_floordiv_with_vectorization():
- with tvm.target.cuda():
+ with tvm.target.Target("cuda"):
# B[i] = A[floordiv(i, k)]
n = 256
k = 37
@@ -521,7 +521,7 @@ def test_cuda_floordiv_with_vectorization():
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_floormod_with_vectorization():
- with tvm.target.cuda():
+ with tvm.target.Target("cuda"):
# B[i] = A[floormod(i, k)]
n = 256
k = 37
diff --git a/tests/python/codegen/test_target_codegen_hexagon.py
b/tests/python/codegen/test_target_codegen_hexagon.py
index a297e89bfb..4737e46730 100644
--- a/tests/python/codegen/test_target_codegen_hexagon.py
+++ b/tests/python/codegen/test_target_codegen_hexagon.py
@@ -35,7 +35,7 @@ def register_linker():
@tvm.testing.requires_hexagon
def test_basic():
- target = tvm.target.hexagon("v66", hvx=128)
+ target = tvm.target.Target("qcom/hexagon-v66")
@I.ir_module
class Module:
@@ -61,7 +61,7 @@ def test_basic():
@tvm.testing.requires_hexagon
def test_llvm_target_features():
- target = tvm.target.hexagon("v66", hvx=128)
+ target = tvm.target.Target("qcom/hexagon-v66")
@I.ir_module
class Module:
@@ -84,7 +84,17 @@ def test_llvm_target_features():
@tvm.testing.requires_hexagon
def test_llvm_options():
- target = tvm.target.hexagon("v66", llvm_options="-hexagon-noopt")
+ target = tvm.target.Target(
+ {
+ "kind": "hexagon",
+ "mtriple": "hexagon",
+ "mcpu": "hexagonv66",
+ "mattr": ["+hvxv66", "+hvx-length128b"],
+ "num-cores": 4,
+ "vtcm-capacity": 262144,
+ "llvm-options": ["-hexagon-noopt"],
+ }
+ )
@I.ir_module
class Module:
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py
b/tests/python/contrib/test_hexagon/infrastructure.py
index d035b5d6fe..6603ada801 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -352,6 +352,24 @@ def quantize_np(arr_np: numpy.ndarray, dtype: str):
def get_hexagon_target(cpu_ver: str, **kwargs) -> tvm.target.Target:
- """Creates a Hexagon target"""
- target = tvm.target.hexagon(cpu_ver, **kwargs)
+ """Creates a Hexagon target from a registered tag.
+
+ Parameters
+ ----------
+ cpu_ver : str
+ Hexagon CPU version, e.g. "v68", "v69".
+ **kwargs :
+ Optional target attribute overrides (e.g. vtcm_capacity=1024).
+ """
+ tag = "qcom/hexagon-" + cpu_ver
+ if kwargs:
+ config = {"tag": tag}
+ if "vtcm_capacity" in kwargs:
+ config["vtcm-capacity"] = kwargs.pop("vtcm_capacity")
+ if "num_cores" in kwargs:
+ config["num-cores"] = kwargs.pop("num_cores")
+ config.update(kwargs)
+ target = tvm.target.Target(config)
+ else:
+ target = tvm.target.Target(tag)
return tvm.target.Target(target, host=target)
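
Callers keep the old snake_case spelling; the helper translates it to the
hyphenated attribute names before constructing the target:

    target = get_hexagon_target("v68")                    # plain tag lookup
    target = get_hexagon_target("v68", vtcm_capacity=1024, num_cores=2)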
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index 5caa790754..f78dd67486 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -269,7 +269,7 @@ def evaluate(
use_async_copy=0,
):
"""Evaluate function."""
- target_hexagon = tvm.target.hexagon("v68", link_params=True)
+ target_hexagon = tvm.target.Target("qcom/hexagon-v68")
with tvm.transform.PassContext(
config={
"tir.use_async_copy": use_async_copy,
diff --git a/tests/python/contrib/test_hexagon/test_dma_builtin.py
b/tests/python/contrib/test_hexagon/test_dma_builtin.py
index e0ded63c1f..8a6b1eee2a 100644
--- a/tests/python/contrib/test_hexagon/test_dma_builtin.py
+++ b/tests/python/contrib/test_hexagon/test_dma_builtin.py
@@ -151,7 +151,7 @@ class TestDMACopyWait:
@tvm.testing.requires_hexagon
def test_vtcm_alloc_compute(self, hexagon_launcher, mode, module):
- target_hexagon = tvm.target.hexagon("v69")
+ target_hexagon = tvm.target.Target("qcom/hexagon-v69")
target = tvm.target.Target(target_hexagon, host=target_hexagon)
with tvm.transform.PassContext(opt_level=3, config=[]):
ex = tvm.compile(mod=module, target=target, exec_mode=mode)
diff --git
a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
index 277d9ed75d..a950302c64 100644
--- a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
+++ b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
@@ -76,7 +76,7 @@ def test_alloc_storage_with_scope_global(hexagon_launcher):
mod = Module
- target_hexagon = tvm.target.hexagon("v69", vtcm_capacity=4 * 2**20)
+ target_hexagon = tvm.target.Target({"tag": "qcom/hexagon-v69",
"vtcm-capacity": 4 * 2**20})
target = tvm.target.Target(target_hexagon, host=target_hexagon)
with tvm.transform.PassContext(opt_level=3):
lib = tvm.compile(mod, target, exec_mode="compiled")
diff --git a/tests/python/contrib/test_hexagon/test_relax_integration.py
b/tests/python/contrib/test_hexagon/test_relax_integration.py
index 4a3d122ce0..456576b3b0 100644
--- a/tests/python/contrib/test_hexagon/test_relax_integration.py
+++ b/tests/python/contrib/test_hexagon/test_relax_integration.py
@@ -46,7 +46,7 @@ def test_mobilenet_onnx(hexagon_session: Session):
shape_dict = {"input": data_np.shape}
relay_mod, _ = relay.frontend.from_onnx(onnx_model, shape_dict,
freeze_params=True)
- target_hexagon = tvm.target.hexagon("v68")
+ target_hexagon = tvm.target.Target("qcom/hexagon-v68")
target = tvm.target.Target(target_hexagon, host=target_hexagon)
relax_mod = onnx.from_onnx(onnx_model, shape_dict, freeze_params=True)
relax_mod = relay_translator.from_relay(relay_mod["main"], target_hexagon)
@@ -79,7 +79,7 @@ def test_mobilenet(hexagon_session: Session):
relay_mod, params = testing.mobilenet.get_workload(batch_size=1,
dtype="float32")
data_np = np.random.rand(1, 3, 224, 224).astype("float32")
- target_hexagon = tvm.target.hexagon("v68")
+ target_hexagon = tvm.target.Target("qcom/hexagon-v68")
target = tvm.target.Target(target_hexagon, host=target_hexagon)
# translate the relay mobilenet and bind params
diff --git a/tests/python/relax/backend/clml/utils.py
b/tests/python/relax/backend/clml/utils.py
index 2c167a2473..fd1d6fa9ac 100644
--- a/tests/python/relax/backend/clml/utils.py
+++ b/tests/python/relax/backend/clml/utils.py
@@ -92,7 +92,7 @@ def run_compare(mod, inputs, params_np):
ref = build_and_run(
mod,
inputs,
- tvm.target.adreno(),
+ tvm.target.Target("qcom/adreno-opencl"),
rpc=rpc,
load_path="vm_library_opencl.so",
)
@@ -101,7 +101,7 @@ def run_compare(mod, inputs, params_np):
out = build_and_run(
clml_mod,
inputs,
- tvm.target.adreno(cfg="clml"),
+ tvm.target.Target("qcom/adreno-opencl-clml"),
rpc=rpc,
load_path="vm_library_clml.so",
)
diff --git a/tests/python/relax/texture/adreno_utils.py
b/tests/python/relax/texture/adreno_utils.py
index 61f8c41ff1..46e95e9999 100644
--- a/tests/python/relax/texture/adreno_utils.py
+++ b/tests/python/relax/texture/adreno_utils.py
@@ -36,10 +36,13 @@ def get_target(backend, is_adreno=False):
tvm.target.Target
The target for the Adreno GPU.
"""
- target = tvm.target.adreno(backend=backend)
- if is_adreno:
- target = tvm.target.adreno(cfg="texture", backend=backend)
- return target
+ _TAG_MAP = {
+ ("opencl", False): "qcom/adreno-opencl",
+ ("opencl", True): "qcom/adreno-opencl-texture",
+ ("vulkan", False): "qcom/adreno-vulkan",
+ ("vulkan", True): "qcom/adreno-vulkan-texture",
+ }
+ return tvm.target.Target(_TAG_MAP[(backend, is_adreno)])
def get_rpc():
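
The texture-enabled variants simply append "-texture" to the base Adreno tag,
so the mapping is mechanical:

    opencl_tgt = get_target("opencl")                  # qcom/adreno-opencl
    vulkan_tex = get_target("vulkan", is_adreno=True)  # qcom/adreno-vulkan-texture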
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py
index 70deb2c1ce..f21d5ae6b8 100644
---
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py
+++
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py
@@ -24,7 +24,7 @@ from tvm.target import Target
def _target() -> Target:
- return Target("hexagon", host="llvm")
+ return Target("qcom/hexagon-v68", host="llvm")
def _create_context(mod, target) -> ms.TuneContext:
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py
index ff4514f430..807f0a5fdb 100644
---
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py
+++
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py
@@ -111,7 +111,12 @@ class Conv2dNCHWcVTCM:
def test_conv2d_vtcm():
def get_target(vtcm_cap):
- target = tvm.target.hexagon("v68", vtcm_capacity=vtcm_cap)
+ target = tvm.target.Target({
+ "kind": "hexagon", "mtriple": "hexagon", "mcpu": "hexagonv68",
+ "mattr": ["+hvxv68", "+hvx-length128b", "+hvx-qfloat",
"-hvx-ieee-fp"],
+ "num-cores": 4, "vtcm-capacity": vtcm_cap,
+ "llvm-options": ["-force-hvx-float"],
+ })
return tvm.target.Target(target, host=target)
sch = tvm.s_tir.Schedule(Conv2dNCHWcVTCM, debug_mask="all")
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt.py
index c2c3093d24..7500fae740 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt.py
@@ -17,7 +17,7 @@
# pylint:
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm.testing
from tvm.s_tir import meta_schedule as ms
-from tvm import target, te
+from tvm import te
from tvm.s_tir.meta_schedule.testing import te_workload
from tvm.s_tir.meta_schedule.testing.space_generation import (
check_sketches,
@@ -581,7 +581,7 @@ def test_multi_level_tiling_hexagon():
* weight[v_rh, v_rw, v_rc, v_co]
)
- target_hexagon = target.hexagon("v69", num_cores=4)
+ target_hexagon = Target("qcom/hexagon-v69")
I = 64
O = 64
diff --git a/tests/python/target/test_target_target.py
b/tests/python/target/test_target_target.py
index d9c31c6ae8..ce2f896b3c 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -19,7 +19,7 @@ import json
import pytest
import tvm
import tvm.testing
-from tvm.target import Target, arm_cpu, bifrost, cuda, intel_graphics, mali, rocm
+from tvm.target import Target
def test_all_targets_device_type_verify():
@@ -43,14 +43,16 @@ def test_target_string_parse():
    target = tvm.target.Target({"kind": "cuda", "model": "unknown", "libs": ["cublas", "cudnn"]})
assert target.kind.name == "cuda"
- assert target.model == "unknown"
+ assert target.attrs["model"] == "unknown"
assert set(target.keys) == set(["cuda", "gpu"])
- assert set(target.libs) == set(["cublas", "cudnn"])
- assert str(target) == str(tvm.target.cuda(libs=["cublas", "cudnn"]))
+ assert set(target.attrs["libs"]) == set(["cublas", "cudnn"])
- assert tvm.target.intel_graphics().device_name == "intel_graphics"
- assert tvm.target.mali().device_name == "mali"
- assert tvm.target.arm_cpu().device_name == "arm_cpu"
+ assert (
+ Target({"kind": "opencl", "device": "intel_graphics"}).attrs.get("device", "")
+ == "intel_graphics"
+ )
+ assert Target({"kind": "opencl", "device": "mali"}).attrs.get("device",
"") == "mali"
+ assert Target({"kind": "llvm", "device": "arm_cpu"}).attrs.get("device",
"") == "arm_cpu"
def test_target_string_with_spaces():
@@ -89,12 +91,6 @@ def test_target_llvm_vector_width():
assert target.attrs["vector-width"] == 1024
-def test_target_create():
- targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), bifrost()]
- for tgt in targets:
- assert tgt is not None
-
-
def test_target_config():
"""
Test that constructing a target from a dictionary works.
@@ -114,8 +110,8 @@ def test_target_config():
target = tvm.target.Target(config)
assert target.kind.name == "llvm"
assert all([key in target.keys for key in ["arm_cpu", "cpu"]])
- assert target.device_name == "arm_cpu"
- assert target.libs == ["cblas"]
+ assert target.attrs.get("device", "") == "arm_cpu"
+ assert list(target.attrs.get("libs", [])) == ["cblas"]
assert target.attrs["mfloat-abi"] == "hard"
    assert all([attr in target.attrs["mattr"] for attr in ["+neon", "-avx512f"]])
@@ -162,6 +158,20 @@ def test_target_tag_1():
assert tgt.attrs["registers_per_block"] == 32768
+def test_target_tag_override():
+ """Test creating a target from a tag with attribute overrides."""
+ tgt = tvm.target.Target({"tag": "nvidia/nvidia-a100",
"l2_cache_size_bytes": 12345})
+ assert tgt.kind.name == "cuda"
+ assert tgt.attrs["arch"] == "sm_80"
+ # Override should take effect
+ assert int(tgt.attrs["l2_cache_size_bytes"]) == 12345
+ # Base tag fields should be preserved
+ assert tgt.attrs["max_shared_memory_per_block"] == 49152
+ assert tgt.attrs["thread_warp_size"] == 32
+ # Tag name should be recorded
+ assert tgt.tag == "nvidia/nvidia-a100"
+
+
def test_list_kinds():
targets = tvm.target.Target.list_kinds()
assert len(targets) != 0
@@ -285,111 +295,6 @@ def test_target_with_host():
assert tgt.host.attrs["registers_per_block"] == 32768
-def test_canon_target_and_host_0():
- target = None
- host = None
- target, host = Target.canon_target_and_host(target, host)
- assert target is None
- assert host is None
-
-
-def test_canon_target_and_host_1():
- target = None
- host = "llvm"
- with pytest.raises(AssertionError, match=r"Target host is not empty when target is empty."):
- target, host = Target.canon_target_and_host(target, host)
-
-
-def test_canon_target_and_host_2():
- target = Target("cuda")
- host = Target("llvm")
- target, host = Target.canon_target_and_host(target, host)
- assert target.kind.name == "cuda"
- assert target.host.kind.name == "llvm"
-
-
-def test_canon_target_and_host_3():
- target = Target(target="cuda", host="llvm")
- host = None
- target, host = Target.canon_target_and_host(target, host)
- assert target.kind.name == "cuda"
- assert target.host.kind.name == "llvm"
- assert host.kind.name == "llvm"
- assert target.host == host
-
-
-def test_canon_multi_target_and_host_0():
- with pytest.raises(AssertionError):
- Target.canon_multi_target_and_host(None)
-
-
-def test_canon_multi_target_and_host_1():
- raw_targets = Target.canon_multi_target_and_host({"kind": "llvm"})
- assert len(raw_targets) == 1
- assert raw_targets[0].kind.name == "llvm"
-
-
-def test_canon_multi_target_and_host_2():
- raw_targets = Target.canon_multi_target_and_host({1: "llvm", 2: "cuda"})
- assert len(raw_targets) == 2
- assert raw_targets[0].kind.name == "llvm"
- assert raw_targets[1].kind.name == "cuda"
-
-
-def test_canon_multi_target_and_host_3():
- raw_targets = Target.canon_multi_target_and_host(["llvm", "cuda"])
- assert len(raw_targets) == 2
- assert raw_targets[0].kind.name == "llvm"
- assert raw_targets[1].kind.name == "cuda"
-
-
-def test_canon_multi_target_and_host_4():
- raw_targets = Target.canon_multi_target_and_host("llvm")
- assert len(raw_targets) == 1
- assert raw_targets[0].kind.name == "llvm"
-
-
-def test_canon_multi_target_and_host_5():
- raw_targets = Target.canon_multi_target_and_host("cuda", "llvm")
- assert len(raw_targets) == 1
- assert raw_targets[0].kind.name == "cuda"
- assert raw_targets[0].host.kind.name == "llvm"
-
-
-def test_canon_multi_target_and_host_6():
- """Test `canon_target_and_host` by using TVM Objects"""
- cuda_device_type = tvm.device("cuda").dlpack_device_type()
- target = {cuda_device_type: Target(target="cuda", host="llvm")}
- host = None
- raw_targets_1 = Target.canon_multi_target_and_host(target, host)
- assert len(raw_targets_1) == 1
- assert raw_targets_1[0].kind.name == "cuda"
- assert raw_targets_1[0].host.kind.name == "llvm"
-
- target = {cuda_device_type: Target(tvm.runtime.container.String("cuda"))}
- host = Target(tvm.runtime.container.String("llvm"))
- target = tvm.runtime.convert(target)
- assert isinstance(target, tvm.ir.container.Map)
- raw_targets_2 = Target.canon_multi_target_and_host(target, host)
- assert len(raw_targets_2) == 1
- assert raw_targets_2[0].kind.name == "cuda"
- assert raw_targets_2[0].host.kind.name == "llvm"
-
-
-def test_canon_target_map_and_host():
- target_map = {"cuda": "cuda_module", "llvm": "cpu_module"}
- target_map, host = Target.canon_target_map_and_host(target_map, "llvm")
- assert host.kind.name == "llvm"
- for t, v in target_map.items():
- assert t.host.kind.name == "llvm"
- if t.kind.name == "cuda":
- assert v == "cuda_module"
- elif t.kind.name == "llvm":
- assert v == "cpu_module"
- else:
- assert False
-
-
def test_target_attr_bool_value():
target0 = Target({"kind": "vulkan", "supports_float16": True})
assert target0.attrs["supports_float16"] == 1
@@ -403,9 +308,9 @@ def test_target_attr_bool_value():
def test_target_attr_l2_cache_size_bytes():
target0 = Target("nvidia/nvidia-a100")
- assert target0.l2_cache_size_bytes == 41943040
+ assert int(target0.attrs.get("l2_cache_size_bytes", 0)) == 41943040
target1 = Target("nvidia/geforce-rtx-4090")
- assert target1.l2_cache_size_bytes == 75497472
+ assert int(target1.attrs.get("l2_cache_size_bytes", 0)) == 75497472
def test_target_features():
@@ -426,9 +331,9 @@ def test_target_from_device_cuda(input_device):
dev = tvm.cuda()
assert target.kind.name == "cuda"
assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
- assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
- assert target.thread_warp_size == dev.warp_size
- assert target.arch == "sm_" + dev.compute_version.replace(".", "")
+ assert int(target.attrs["max_shared_memory_per_block"]) ==
dev.max_shared_memory_per_block
+ assert int(target.attrs["thread_warp_size"]) == dev.warp_size
+ assert str(target.attrs.get("arch", "")) == "sm_" +
dev.compute_version.replace(".", "")
@tvm.testing.requires_rocm
@@ -440,21 +345,21 @@ def test_target_from_device_rocm(input_device):
assert target.kind.name == "rocm"
assert target.attrs["mtriple"] == "amdgcn-and-amdhsa-hcc"
assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
- assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
- assert target.thread_warp_size == dev.warp_size
+ assert int(target.attrs["max_shared_memory_per_block"]) == dev.max_shared_memory_per_block
+ assert int(target.attrs["thread_warp_size"]) == dev.warp_size
@tvm.testing.requires_vulkan
@pytest.mark.parametrize("input_device", ["vulkan", tvm.vulkan()])
-def test_target_from_device_rocm(input_device):
+def test_target_from_device_vulkan(input_device):
target = Target.from_device(input_device)
    f_get_target_property = tvm.get_global_func("device_api.vulkan.get_target_property")
dev = tvm.vulkan()
assert target.kind.name == "vulkan"
assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
- assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
- assert target.thread_warp_size == dev.warp_size
+ assert int(target.attrs["max_shared_memory_per_block"]) == dev.max_shared_memory_per_block
+ assert int(target.attrs["thread_warp_size"]) == dev.warp_size
assert target.attrs["supports_float16"] == f_get_target_property(dev,
"supports_float16")
assert target.attrs["supports_int16"] == f_get_target_property(dev,
"supports_int16")
assert target.attrs["supports_int8"] == f_get_target_property(dev,
"supports_int8")
@@ -471,8 +376,8 @@ def test_target_from_device_opencl(input_device):
dev = tvm.opencl()
assert target.kind.name == "opencl"
assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
- assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
- assert target.thread_warp_size == dev.warp_size
+ assert int(target.attrs["max_shared_memory_per_block"]) == dev.max_shared_memory_per_block
+ assert int(target.attrs["thread_warp_size"]) == dev.warp_size
def test_module_dict_from_deserialized_targets():
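Note: the recurring migration in this file is that convenience properties such as .model, .libs, .arch, and .thread_warp_size are gone, and everything is read from the attrs map instead. A minimal sketch using values asserted above:

    import tvm

    tgt = tvm.target.Target("nvidia/nvidia-a100")
    # attrs.get() with a default keeps lookups safe when an attribute is
    # absent for a given target kind.
    assert str(tgt.attrs.get("arch", "")) == "sm_80"
    assert int(tgt.attrs.get("thread_warp_size", 0)) == 32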
diff --git a/tests/python/target/test_x86_features.py b/tests/python/target/test_x86_features.py
index cb810f7cab..709d216b56 100644
--- a/tests/python/target/test_x86_features.py
+++ b/tests/python/target/test_x86_features.py
@@ -183,7 +183,7 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo
##
with Target(tvm_target):
- mcpu = Target.current(False).mcpu
+ mcpu = str(Target.current(False).attrs.get("mcpu", ""))
# check for feature via the python api (current context target)
assert target_has_features(x86_feature) == is_supported
# check for feature via the python api (with explicit target)
@@ -191,5 +191,5 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo
# check for feature via the ffi llvm api (current context target)
    (sum(_ffi_api.target_has_feature(feat, None) for feat in x86_feature) > 0) == is_supported
# check for feature in target's llvm full x86 CPU feature list
- if (not Target(tvm_target).mattr) and isinstance(x86_feature, str):
+ if (not list(Target(tvm_target).attrs.get("mattr", []))) and isinstance(x86_feature, str):
    assert (x86_feature in codegen.llvm_get_cpu_features()) == is_supported
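Note: the same attrs-based lookup applies to LLVM CPU settings; a minimal sketch (the -mcpu value is illustrative):

    from tvm.target import Target

    with Target("llvm -mcpu=skylake-avx512"):
        # Target.current(False) is expected to raise when no target is in
        # scope; the removed .mcpu property becomes an attrs lookup.
        mcpu = str(Target.current(False).attrs.get("mcpu", ""))
        assert mcpu == "skylake-avx512"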