This is an automated email from the ASF dual-hosted git repository.
mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d4bf9ecf55 [Target] Add target_device_type attribute to override
default device_type (#12509)
d4bf9ecf55 is described below
commit d4bf9ecf5524d265916ac7b860b0027f5eee5c49
Author: Krzysztof Parzyszek <[email protected]>
AuthorDate: Fri Sep 30 06:17:42 2022 -0500
[Target] Add target_device_type attribute to override default device_type
(#12509)
Implement Target::GetTargetDeviceType (C++) or get_target_device_type
(python) to get the device type (kDL...) for a given target.
The attribute "target_device_type" can be used to override the default
device type associated with the target kind.
---
include/tvm/target/compilation_config.h | 2 +-
include/tvm/target/target.h | 6 ++--
include/tvm/target/target_kind.h | 15 +++++-----
include/tvm/target/virtual_device.h | 10 +++----
python/tvm/micro/model_library_format.py | 32 +++++++++++-----------
python/tvm/relay/build_module.py | 4 +--
python/tvm/relay/collage/collage.py | 2 +-
python/tvm/target/target.py | 4 +++
src/auto_scheduler/search_policy/utils.h | 15 ++++------
src/auto_scheduler/search_task.cc | 4 +--
src/driver/driver_api.cc | 8 ++++--
src/relay/backend/build_module.cc | 4 +--
src/relay/backend/contrib/uma/targets.cc | 2 +-
src/relay/backend/interpreter.cc | 6 ++--
src/relay/backend/vm/compiler.cc | 6 ++--
src/runtime/vulkan/vulkan_device.h | 2 +-
src/target/compilation_config.cc | 23 ++++++++--------
src/target/spirv/spirv_support.cc | 2 +-
src/target/target.cc | 14 ++++++++--
src/target/virtual_device.cc | 8 +++---
src/tir/analysis/verify_memory.cc | 2 +-
src/tir/transforms/make_packed_api.cc | 2 +-
src/tir/transforms/make_unpacked_api.cc | 2 +-
.../relay/collage/demo_collage_partitioner.py | 10 +++----
tests/python/unittest/test_target_target.py | 4 ++-
tests/scripts/release/PRERELEASE_NOTES.md | 24 ++++++++++++++++
26 files changed, 128 insertions(+), 85 deletions(-)
diff --git a/include/tvm/target/compilation_config.h
b/include/tvm/target/compilation_config.h
index 53b7df88b8..eab34de1fb 100644
--- a/include/tvm/target/compilation_config.h
+++ b/include/tvm/target/compilation_config.h
@@ -78,7 +78,7 @@ class CompilationConfigNode : public Object {
* It is possible to have multiple primitive targets for the same device
type. However given
* primitive targets left and right where:
* - left appears before right in the array
- * - left->kind->device_type == right->kind->device_type
+ * - left->GetTargetDeviceType() == right->GetTargetDeviceType()
* then:
* - right.IsExternalCodegenFor(left) must be true
* In this way the \p FindPrimitiveTargetForDeviceOrFail method will find
the 'most general'
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 6ad213f126..df6951685a 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -68,6 +68,8 @@ class TargetNode : public Object {
TVM_DLL Map<String, ObjectRef> Export() const;
/*! \return The Optional<Target> typed target host of the TargetNode */
TVM_DLL Optional<Target> GetHost() const;
+ /*! \return The device type for this target */
+ TVM_DLL int GetTargetDeviceType() const;
/*!
* \brief Returns a human readable representation of \p Target which
includes all fields,
@@ -230,11 +232,11 @@ class Target : public ObjectRef {
* with \p that target. In particular:
* - \p this has a true ::tvm::attr::kIsExternalCodegen attribute
* - \p that does not have a true ::tvm::attr::kIsExternalCodegen attribute
- * - \p this and \p that have the same kind->device_type
+ * - \p this and \p that have the same GetTargetDeviceType()
*
* After partitioning, the external codegen compilation path may use \p that
to guide its
* compilation to a \p runtime::Module. Given \p this, an appropriate \p
that can be
- * found using \p
CompilationConfig::FindPrimitiveTargetOrFail(this->kind->device_type).
+ * found using \p
CompilationConfig::FindPrimitiveTargetOrFail(this->GetTargetDeviceType()).
*
* The \p CollagePartition pass uses this method to guide its search over
candidate partitions
* using external codegen.
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index 63c92fedbd..19bcce3116 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -92,7 +92,7 @@ class TargetKindNode : public Object {
/*! \brief Name of the target kind */
String name;
/*! \brief Device type of target kind */
- int device_type;
+ int default_device_type;
/*! \brief Default keys of the target */
Array<String> default_keys;
/*! \brief Function used to preprocess on target creation */
@@ -102,7 +102,7 @@ class TargetKindNode : public Object {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
- v->Visit("device_type", &device_type);
+ v->Visit("default_device_type", &default_device_type);
v->Visit("default_keys", &default_keys);
}
@@ -211,7 +211,7 @@ class TargetKindRegEntry {
* \brief Set DLPack's device_type of the target
* \param device_type Device type
*/
- inline TargetKindRegEntry& set_device_type(int device_type);
+ inline TargetKindRegEntry& set_default_device_type(int device_type);
/*!
* \brief Set the default keys of the target
* \param keys The default keys
@@ -363,8 +363,8 @@ inline TargetKindRegEntry&
TargetKindRegEntry::set_attr(const String& attr_name,
return *this;
}
-inline TargetKindRegEntry& TargetKindRegEntry::set_device_type(int
device_type) {
- kind_->device_type = device_type;
+inline TargetKindRegEntry& TargetKindRegEntry::set_default_device_type(int
device_type) {
+ kind_->default_device_type = device_type;
return *this;
}
@@ -463,14 +463,15 @@ constexpr const char* kRelayToTIR = "RelayToTIR";
TVM_STR_CONCAT(TVM_TARGET_KIND_REGISTER_VAR_DEF, __COUNTER__) = \
::tvm::TargetKindRegEntry::RegisterOrGet(TargetKindName) \
.set_name() \
- .set_device_type(DeviceType) \
+ .set_default_device_type(DeviceType) \
.add_attr_option<Array<String>>("keys") \
.add_attr_option<String>("tag") \
.add_attr_option<String>("device") \
.add_attr_option<String>("model") \
.add_attr_option<Array<String>>("libs") \
.add_attr_option<Target>("host") \
- .add_attr_option<Integer>("from_device")
+ .add_attr_option<Integer>("from_device") \
+ .add_attr_option<Integer>("target_device_type")
} // namespace tvm
diff --git a/include/tvm/target/virtual_device.h
b/include/tvm/target/virtual_device.h
index 37f4b23b12..c26ae5befe 100644
--- a/include/tvm/target/virtual_device.h
+++ b/include/tvm/target/virtual_device.h
@@ -63,7 +63,7 @@ using MemoryScope = String;
*
* Some or all of these fields may be unconstrained, signaling that device
planning is free to
* choose a value consistent with the whole program. However if a \p target is
given then the \p
- * device_type must equal \p target->kind->device_type.
+ * device_type must equal \p target->GetTargetDeviceType().
*
* Note that currently we assume if a function returns its result on a
particular (virtual) device
* then the function body is also executed on that device. See the overview
comment in
@@ -167,8 +167,8 @@ class VirtualDeviceNode : public
AttrsNode<VirtualDeviceNode> {
private:
/*!
* \brief The \p DLDeviceType (represented as an int) of the virtual device.
If \p target is
- * known then this will be equal to \p target->kind->device_type. If \p
target is null then the
- * target is to be determined later.
+ * known then this will be equal to \p target->GetTargetDeviceType(). If \p
target is null then
+ * the target is to be determined later.
*
* This is needed to support the legacy "on_device" and "device_copy" calls
which only allow
* a \p DLDeviceTypes (as an integer) to be given.
@@ -263,7 +263,7 @@ class VirtualDevice : public ObjectRef {
/*!
* \brief Construct a virtual device.
* \param device_type The device type for the virtual device, or \p
kInvalidDeviceType if
- * unconstrained. If \p target is defined then must match its \p
target->kind->device_type.
+ * unconstrained. If \p target is defined then must match its \p
target->GetTargetDeviceType().
* \param virtual_device_id The device id for the virtual device, or -1 if
unconstrained.
* \param target The target describing how to compile for the virtual
device, or null if
* unconstrained.
@@ -304,7 +304,7 @@ class VirtualDevice : public ObjectRef {
/*! \brief Returns the \p VirtualDevice for \p target. */
static VirtualDevice ForTarget(Target target) {
- DLDeviceType device_type =
static_cast<DLDeviceType>(target->kind->device_type);
+ DLDeviceType device_type =
static_cast<DLDeviceType>(target->GetTargetDeviceType());
return VirtualDevice(device_type, /*virtual_device_id=*/0,
std::move(target));
}
diff --git a/python/tvm/micro/model_library_format.py
b/python/tvm/micro/model_library_format.py
index e220fa1ca5..1ba9f5e733 100644
--- a/python/tvm/micro/model_library_format.py
+++ b/python/tvm/micro/model_library_format.py
@@ -226,12 +226,12 @@ def _build_function_memory_map(function_metadata):
for target in dict(finfo.workspace_sizes).keys():
workspace_size = finfo.workspace_sizes[target]
target_entry = {
- "device": int(target.kind.device_type),
+ "device": int(target.get_target_device_type()),
"workspace_size_bytes": int(workspace_size),
}
target_local_entries[func_name].append(target_entry)
- if workspace_size >=
device_max_workspace.get(int(target.kind.device_type), 0):
- device_max_workspace[int(target.kind.device_type)] =
workspace_size
+ if workspace_size >=
device_max_workspace.get(int(target.get_target_device_type()), 0):
+ device_max_workspace[int(target.get_target_device_type())] =
workspace_size
for func_name, target_entries_ in target_local_entries.items():
func_entry = {
@@ -252,28 +252,28 @@ def _build_function_memory_map(function_metadata):
for target in dict(main_func_metadata.workspace_sizes).keys():
main_func_local_workspace = main_func_metadata.workspace_sizes[target]
- target_main_entries[int(target.kind.device_type)] =
_create_empty_entry(
- int(target.kind.device_type)
+ target_main_entries[int(target.get_target_device_type())] =
_create_empty_entry(
+ int(target.get_target_device_type())
)
-
target_main_entries[int(target.kind.device_type)]["workspace_size_bytes"] = int(
- device_max_workspace.get(int(target.kind.device_type), 0)
+
target_main_entries[int(target.get_target_device_type())]["workspace_size_bytes"]
= int(
+ device_max_workspace.get(int(target.get_target_device_type()), 0)
) + int(main_func_local_workspace)
for target in dict(main_func_metadata.constant_sizes).keys():
- if int(target.kind.device_type) not in target_main_entries.keys():
- target_main_entries[int(target.kind.device_type)] =
_create_empty_entry(
- int(target.kind.device_type)
+ if int(target.get_target_device_type()) not in
target_main_entries.keys():
+ target_main_entries[int(target.get_target_device_type())] =
_create_empty_entry(
+ int(target.get_target_device_type())
)
-
target_main_entries[int(target.kind.device_type)]["constants_size_bytes"] = int(
+
target_main_entries[int(target.get_target_device_type())]["constants_size_bytes"]
= int(
main_func_metadata.constant_sizes[target]
)
for target in dict(main_func_metadata.io_sizes).keys():
- if int(target.kind.device_type) not in target_main_entries.keys():
- target_main_entries[int(target.kind.device_type)] =
_create_empty_entry(
- int(target.kind.device_type)
+ if int(target.get_target_device_type()) not in
target_main_entries.keys():
+ target_main_entries[int(target.get_target_device_type())] =
_create_empty_entry(
+ int(target.get_target_device_type())
)
- target_main_entries[int(target.kind.device_type)]["io_size_bytes"] =
int(
+
target_main_entries[int(target.get_target_device_type())]["io_size_bytes"] =
int(
main_func_metadata.io_sizes[target]
)
@@ -483,7 +483,7 @@ def _write_tir_and_build_operator_memory_map(src_dir,
targets, ir_module_by_targ
memory_map = {}
for target in targets:
# TODO(mbs): The device type is not unique, better would be to use
target.kind.name
- target_device_type = target.kind.device_type
+ target_device_type = target.get_target_device_type()
ir_mod = ir_module_by_target[target]
printer = get_global_func("tir.ModelLibraryFormatPrinter")(False,
None, False)
with open(src_dir / f"tir-{target_device_type}.txt", "w") as f:
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 6cdc79ceb5..112e5558fe 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -664,10 +664,10 @@ def create_executor(kind="debug", mod=None, device=None,
target="llvm", params=N
if mod is None:
mod = IRModule()
if device is not None:
- assert device.device_type == raw_targets[0].kind.device_type
+ assert device.device_type == raw_targets[0].get_target_device_type()
else:
# Derive the default device from the first target.
- device = _nd.device(raw_targets[0].kind.device_type, 0)
+ device = _nd.device(raw_targets[0].get_target_device_type(), 0)
if params is not None:
mod = IRModule.from_expr(bind_params_by_name(mod["main"], params))
diff --git a/python/tvm/relay/collage/collage.py
b/python/tvm/relay/collage/collage.py
index 4dd59d56b4..632ab1746f 100644
--- a/python/tvm/relay/collage/collage.py
+++ b/python/tvm/relay/collage/collage.py
@@ -82,7 +82,7 @@ def vm_estimate_seconds(device, the_vm, func_name, args):
def estimate_seconds(mod, target):
"""Returns the mean execution time of "main" in mod on target with params.
The module
may contain "Primitive" functions, possibly with "Compiler" attributes."""
- device = tvm.device(target.kind.device_type)
+ device = tvm.device(target.get_target_device_type())
try:
# Build the module.
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 1e9e2e698c..7081f992af 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -234,6 +234,10 @@ class Target(Object):
"""
return _ffi_api.TargetKindGetAttr(self.kind, attr_name)
+ def get_target_device_type(self):
+ """Returns the device_type for this target."""
+ return _ffi_api.TargetGetDeviceType(self)
+
@staticmethod
def list_kinds():
"""Returns the list of available target names."""
diff --git a/src/auto_scheduler/search_policy/utils.h
b/src/auto_scheduler/search_policy/utils.h
index ffd4bf4f48..44b60de1d7 100644
--- a/src/auto_scheduler/search_policy/utils.h
+++ b/src/auto_scheduler/search_policy/utils.h
@@ -48,27 +48,24 @@ namespace auto_scheduler {
/*! \brief Return whether the search task is targeting a CPU. */
inline bool IsCPUTask(const SearchTask& task) {
- return (task)->target->kind->device_type == kDLCPU;
+ return (task)->target->GetTargetDeviceType() == kDLCPU;
}
/*! \brief Return whether the search task is targeting a GPU. */
inline bool IsGPUTask(const SearchTask& task) {
- return (task)->target->kind->device_type == kDLCUDA ||
- (task)->target->kind->device_type == kDLOpenCL ||
- (task)->target->kind->device_type == kDLVulkan ||
- (task)->target->kind->device_type == kDLMetal ||
- (task)->target->kind->device_type == kDLROCM ||
- (task)->target->kind->device_type == kOpenGL;
+ int device_type = (task)->target->GetTargetDeviceType();
+ return device_type == kDLCUDA || device_type == kDLOpenCL || device_type ==
kDLVulkan ||
+ device_type == kDLMetal || device_type == kDLROCM || device_type ==
kOpenGL;
}
/*! \brief Return whether the search task is targeting a CUDA GPU. */
inline bool IsCUDATask(const SearchTask& task) {
- return (task)->target->kind->device_type == kDLCUDA;
+ return (task)->target->GetTargetDeviceType() == kDLCUDA;
}
/*! \brief Return whether the search task is targeting a OpenCL GPU. */
inline bool IsOpenCLTask(const SearchTask& task) {
- return (task)->target->kind->device_type == kDLOpenCL;
+ return (task)->target->GetTargetDeviceType() == kDLOpenCL;
}
/*! \brief Argsort. Order: largest to smallest */
diff --git a/src/auto_scheduler/search_task.cc
b/src/auto_scheduler/search_task.cc
index 262340099c..5c8c678e8c 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -54,7 +54,7 @@ HardwareParams::HardwareParams(int num_cores, int
vector_unit_bytes, int cache_l
HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target&
target,
const Target&
target_host) {
// There is no use of target_host so no updates here in the function.
- const auto device_type = target->kind->device_type;
+ const auto device_type = target->GetTargetDeviceType();
if (device_type == kDLCPU) {
return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64,
0, 0, 0, 0, 0);
} else if (device_type == kDLCUDA || device_type == kDLROCM) {
@@ -91,7 +91,7 @@ HardwareParams
HardwareParamsNode::GetDefaultHardwareParams(const Target& target
int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block,
max_local_memory_per_block,
max_threads_per_block, max_vthread_extent,
warp_size);
- } else if (target->kind->device_type == kDLOpenCL) {
+ } else if (target->GetTargetDeviceType() == kDLOpenCL) {
if (target->GetAttr<String>("device", "") == "mali") {
// We cannot use device API to get hardware attributes like CUDA,
// because like Mali target is normally on the remote machine.
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index b460557da0..b0af0fb65e 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -72,7 +72,7 @@ bool ShouldAnnotateEntryFunc(const IRModule mod) {
/*! \return The default host target for a given device target */
Target DefaultTargetHost(Target target) {
- if (target.defined() && target->kind->device_type == kDLCPU) {
+ if (target.defined() && target->GetTargetDeviceType() == kDLCPU) {
return target;
} else {
if (LLVMEnabled()) {
@@ -423,7 +423,8 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>&
inputs_arg,
if (!target_host.defined()) {
for (const auto& it : inputs) {
- if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type
== kDLMicroDev) {
+ if (it.first->GetTargetDeviceType() == kDLCPU ||
+ it.first->GetTargetDeviceType() == kDLMicroDev) {
target_host = it.first;
break;
}
@@ -460,7 +461,8 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>&
inputs_arg,
// unless they're supposed to. Here if we overrode the target host
// to allow lowering previously we check that it's meant to be placed
// back into the host Module.
- bool overrides_host_target = target->kind->device_type ==
target_host->kind->device_type;
+ bool overrides_host_target =
+ target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
bool non_host_target_kind = target->kind != target_host->kind;
if (overrides_host_target && non_host_target_kind) {
device_modules.push_back(codegen::Build(host_mod, it.first));
diff --git a/src/relay/backend/build_module.cc
b/src/relay/backend/build_module.cc
index 7b39cb4443..bca524794a 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -359,7 +359,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (backend::IsAutoSchedulerEnabled() &&
config_->optional_homogeneous_target.defined()) {
Pass major_pass = transform::AutoSchedulerLayoutRewrite();
bool enable_layout_rewrite_targets =
- config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
+ config_->optional_homogeneous_target->GetTargetDeviceType() ==
kDLCPU ||
config_->optional_homogeneous_target->GetAttr<String>("device", "")
== "mali";
if (enable_layout_rewrite_targets &&
pass_ctx.PassEnabled(major_pass->Info())) {
With<Target> tctx(config_->optional_homogeneous_target);
@@ -373,7 +373,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (backend::IsMetaScheduleEnabled() &&
config_->optional_homogeneous_target.defined()) {
Pass major_pass = transform::MetaScheduleLayoutRewrite();
bool enable_layout_rewrite_targets =
- config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
+ config_->optional_homogeneous_target->GetTargetDeviceType() ==
kDLCPU ||
config_->optional_homogeneous_target->GetAttr<String>("device", "")
== "mali";
if (enable_layout_rewrite_targets &&
pass_ctx.PassEnabled(major_pass->Info())) {
With<Target> tctx(config_->optional_homogeneous_target);
diff --git a/src/relay/backend/contrib/uma/targets.cc
b/src/relay/backend/contrib/uma/targets.cc
index a17f6694f7..ed2cc047cf 100644
--- a/src/relay/backend/contrib/uma/targets.cc
+++ b/src/relay/backend/contrib/uma/targets.cc
@@ -50,7 +50,7 @@
TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
auto target_kind =
::tvm::TargetKindRegEntry::RegisterOrGet(target_name)
.set_name()
- .set_device_type(kDLCPU)
+ .set_default_device_type(kDLCPU)
.add_attr_option<Array<String>>("keys")
.add_attr_option<String>("tag")
.add_attr_option<String>("device")
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 65a0fdc948..1019ecf358 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -477,7 +477,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const
Expr& n)>,
// TODO(mbs): Take this from the host_virtual_device.
Device shape_device;
- shape_device.device_type =
static_cast<DLDeviceType>(prim_shape_target->kind->device_type);
+ shape_device.device_type =
static_cast<DLDeviceType>(prim_shape_target->GetTargetDeviceType());
shape_device.device_id = 0;
// 'Compile' the TIR shape function to appropriate callable form.
@@ -1017,7 +1017,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)>
EvalFunction(IRModule mod, Expr expr, De
<< PrettyPrint(mod) << "and expression:" << std::endl
<< PrettyPrint(expr);
- ICHECK_EQ(device.device_type, target->kind->device_type);
+ ICHECK_EQ(device.device_type, target->GetTargetDeviceType());
Array<Target> raw_targets = {target};
CompilationConfig config(transform::PassContext::Current(), raw_targets);
@@ -1106,7 +1106,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)>
EvalFunction(IRModule mod, Expr expr, De
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target
target,
Map<String, ObjectRef> attrs) {
- ICHECK_EQ(device.device_type, target->kind->device_type);
+ ICHECK_EQ(device.device_type, target->GetTargetDeviceType());
Array<Target> raw_targets = {target};
CompilationConfig config(transform::PassContext::Current(), raw_targets);
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index a8bd3df32a..b807f41959 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1067,7 +1067,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
if (backend::IsAutoSchedulerEnabled() &&
config_->optional_homogeneous_target.defined()) {
Pass major_pass = transform::AutoSchedulerLayoutRewrite();
bool enable_layout_rewrite_targets =
- config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
+ config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU
||
config_->optional_homogeneous_target->GetAttr<String>("device", "") ==
"mali";
if (enable_layout_rewrite_targets &&
pass_ctx.PassEnabled(major_pass->Info())) {
With<Target> tctx(config_->optional_homogeneous_target);
@@ -1081,7 +1081,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
if (backend::IsMetaScheduleEnabled() &&
config_->optional_homogeneous_target.defined()) {
Pass major_pass = transform::MetaScheduleLayoutRewrite();
bool enable_layout_rewrite_targets =
- config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
+ config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU
||
config_->optional_homogeneous_target->GetAttr<String>("device", "") ==
"mali";
if (enable_layout_rewrite_targets &&
pass_ctx.PassEnabled(major_pass->Info())) {
With<Target> tctx(config_->optional_homogeneous_target);
@@ -1164,7 +1164,7 @@ void VMCompiler::Codegen() {
// Only the PrimFuncs will appear in per_target_modules, and there may
legitimately be none.
Map<Target, IRModule> per_tvm_target_modules =
tec::GetPerTargetModules(context_.module);
for (const auto& kv : per_tvm_target_modules) {
- ICHECK(kv.first->kind->device_type != kDLExtDev);
+ ICHECK(kv.first->GetTargetDeviceType() != kDLExtDev);
}
// Retrieve all external runtime modules accumulated by external codegen
(both function-at-a-time
diff --git a/src/runtime/vulkan/vulkan_device.h
b/src/runtime/vulkan/vulkan_device.h
index a1257a732a..59ebf430e6 100644
--- a/src/runtime/vulkan/vulkan_device.h
+++ b/src/runtime/vulkan/vulkan_device.h
@@ -67,7 +67,7 @@ struct VulkanQueueInsertDebugUtilsLabelFunctions {
* \brief Stores the capabilities/limits queried from the physical device.
*
* The member variables here have a 1-1 mapping to Target parameters,
- * if target->kind->device_type==kDLVulkan. A separate struct is used
+ * if target->GetTargetDeviceType()==kDLVulkan. A separate struct is used
* to maintain the boundary between the Vulkan runtime in
* libtvm_runtime.so, and the Target object in libtvm.so.
*/
diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc
index 5e001921b0..a7f708f12a 100644
--- a/src/target/compilation_config.cc
+++ b/src/target/compilation_config.cc
@@ -42,13 +42,13 @@ Target
CompilationConfigNode::FindPrimitiveTargetForDeviceOrFail(DLDeviceType de
ICHECK_GT(device_type, 0) << "Invalid device type";
auto itr = std::find_if(
primitive_targets.begin(), primitive_targets.end(),
- [device_type](const Target& target) { return target->kind->device_type
== device_type; });
+ [device_type](const Target& target) { return
target->GetTargetDeviceType() == device_type; });
if (itr == primitive_targets.end()) {
std::stringstream msg;
msg << "No target is specified for device type " << device_type
<< ". The available device types and targets are:" << std::endl;
for (const auto& target : primitive_targets) {
- msg << " " << target->kind->device_type << "-> " <<
target->ToDebugString() << std::endl;
+ msg << " " << target->GetTargetDeviceType() << "-> " <<
target->ToDebugString() << std::endl;
}
LOG(FATAL) << msg.str();
}
@@ -137,7 +137,7 @@ void CompilationConfigNode::Init(const
transform::PassContext& pass_ctx,
auto hosting_itr = std::find_if(raw_targets.begin(), raw_targets.end(),
[](const Target& target) {
// TODO(tvm-team): The kDLHexagon device can act as a host. We can remove
kDLHexagon
// here once we refactored kDLHexagon to kDLCPU.
- return target->kind->device_type == kDLCPU || target->kind->device_type ==
kDLHexagon;
+ return target->GetTargetDeviceType() == kDLCPU ||
target->GetTargetDeviceType() == kDLHexagon;
});
// Any targets with their host field set?
@@ -149,23 +149,24 @@ void CompilationConfigNode::Init(const
transform::PassContext& pass_ctx,
// targets.
host_target = Target((*has_host_itr)->GetHost().value(),
/*host=*/Target());
VLOG(1) << "The target " << (*has_host_itr)->ToDebugString() << " supplies
a host target "
- << host_target->ToDebugString() << " of device type " <<
host_target->kind->device_type;
+ << host_target->ToDebugString() << " of device type "
+ << host_target->GetTargetDeviceType();
} else if (hosting_itr != raw_targets.end()) {
// RULE B: If any raw target is for a device which could be a host then
use the first such as
// the host.
host_target = Target(*hosting_itr, /*host=*/Target());
VLOG(1) << "Using target " << host_target->ToDebugString() << " of
CPU-like device type "
- << host_target->kind->device_type << " as the host target";
+ << host_target->GetTargetDeviceType() << " as the host target";
} else {
// RULE C: Otherwise, create a default CPU host target.
host_target = MakeDefaultCPUTarget();
VLOG(1) << "Created a default target " << host_target->ToDebugString() <<
" of device type "
- << host_target->kind->device_type << " for the host target";
+ << host_target->GetTargetDeviceType() << " for the host target";
}
ICHECK(host_target.defined());
ICHECK(!host_target->host.defined());
- if (host_target->kind->device_type != kDLCPU) {
+ if (host_target->GetTargetDeviceType() != kDLCPU) {
// I think we're on thin ice here until we've audited the code base for
assumed CPU hosts.
VLOG(1) << "The host target is not a CPU. This is probably not going to
work.";
}
@@ -174,7 +175,7 @@ void CompilationConfigNode::Init(const
transform::PassContext& pass_ctx,
// Establish the host VirtualDevice.
//
host_virtual_device = virtual_device_cache_.Unique(
- VirtualDevice(static_cast<DLDeviceType>(host_target->kind->device_type),
+
VirtualDevice(static_cast<DLDeviceType>(host_target->GetTargetDeviceType()),
/*virtual_device_id=*/0, host_target));
ICHECK(host_virtual_device.defined());
ICHECK(host_virtual_device->target.defined());
@@ -205,7 +206,7 @@ void CompilationConfigNode::Init(const
transform::PassContext& pass_ctx,
std::unordered_set<DLDeviceType> primitive_target_device_types;
std::unordered_set<std::string> kind_names;
for (const auto& target : primitive_targets) {
-
primitive_target_device_types.emplace(static_cast<DLDeviceType>(target->kind->device_type));
+
primitive_target_device_types.emplace(static_cast<DLDeviceType>(target->GetTargetDeviceType()));
CHECK(kind_names.emplace(target->kind->name).second) << "Multiple targets
have been given"
"for the same
device kind '"
<< target->kind->name
<< "'";
@@ -213,7 +214,7 @@ void CompilationConfigNode::Init(const
transform::PassContext& pass_ctx,
for (DLDeviceType device_type : primitive_target_device_types) {
Target first_primitive_target;
for (const auto& current_primitive_target : primitive_targets) {
- if (current_primitive_target->kind->device_type != device_type) {
+ if (current_primitive_target->GetTargetDeviceType() != device_type) {
continue;
}
if (!first_primitive_target.defined()) {
@@ -290,7 +291,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "Primitive targets:";
for (const auto& target : node->primitive_targets) {
p->stream << std::endl
- << " " << target->kind->device_type << " |-> " <<
target->ToDebugString();
+ << " " << target->GetTargetDeviceType() << " |-> " <<
target->ToDebugString();
}
p->stream << std::endl
<< "Default primitive virtual device: " <<
node->default_primitive_virtual_device;
diff --git a/src/target/spirv/spirv_support.cc
b/src/target/spirv/spirv_support.cc
index a91a2a3384..81b5cd8b8a 100644
--- a/src/target/spirv/spirv_support.cc
+++ b/src/target/spirv/spirv_support.cc
@@ -32,7 +32,7 @@ namespace tvm {
namespace codegen {
SPIRVSupport::SPIRVSupport(tvm::Target target) {
- ICHECK_EQ(target->kind->device_type, kDLVulkan)
+ ICHECK_EQ(target->GetTargetDeviceType(), kDLVulkan)
<< "SPIRVSupport can only be checked for vulkan device type";
if (target->GetAttr<Integer>("vulkan_api_version")) {
diff --git a/src/target/target.cc b/src/target/target.cc
index e3e9354a61..cbebd0e10c 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -622,7 +622,7 @@ bool Target::IsExternalCodegen() const {
}
bool Target::IsExternalCodegenFor(const Target& that) const {
- return get()->kind->device_type == that->kind->device_type &&
IsExternalCodegen() &&
+ return get()->GetTargetDeviceType() == that->GetTargetDeviceType() &&
IsExternalCodegen() &&
!that.IsExternalCodegen();
}
@@ -665,6 +665,13 @@ Optional<Target> TargetNode::GetHost() const {
return GetRef<Optional<Target>>(this->host.as<TargetNode>());
}
+int TargetNode::GetTargetDeviceType() const {
+ if (Optional<Integer> device_type = GetAttr<Integer>("target_device_type")) {
+ return Downcast<Integer>(device_type)->value;
+ }
+ return kind->default_device_type;
+}
+
String TargetNode::ToDebugString() const {
std::ostringstream os;
os << "Target(";
@@ -974,7 +981,7 @@ std::unordered_map<String, ObjectRef>
TargetInternal::QueryDevice(int device_id,
const
TargetNode* target) {
std::unordered_map<String, ObjectRef> output;
- Device device{static_cast<DLDeviceType>(target->kind->device_type),
device_id};
+ Device device{static_cast<DLDeviceType>(target->GetTargetDeviceType()),
device_id};
auto api = runtime::DeviceAPI::Get(device, true);
if (!api) {
@@ -1042,6 +1049,9 @@
TVM_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::Exi
TVM_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current);
TVM_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export);
TVM_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost);
+TVM_REGISTER_GLOBAL("target.TargetGetDeviceType").set_body_typed([](const
Target& target) {
+ return target->GetTargetDeviceType();
+});
TVM_REGISTER_GLOBAL("target.TargetGetFeature")
.set_body_typed([](const Target& target, const String& feature_key) {
return target->GetFeature<ObjectRef>(feature_key);
diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc
index ef01a2afda..39bb11ff15 100644
--- a/src/target/virtual_device.cc
+++ b/src/target/virtual_device.cc
@@ -68,9 +68,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
VirtualDevice::VirtualDevice(DLDeviceType device_type, int virtual_device_id,
Target target,
MemoryScope memory_scope) {
- ICHECK(!target.defined() || device_type == target->kind->device_type)
- << "target " << target->ToDebugString() << " has device type " <<
target->kind->device_type
- << " but virtual device has device type " << device_type;
+ ICHECK(!target.defined() || device_type == target->GetTargetDeviceType())
+ << "target " << target->ToDebugString() << " has device type "
+ << target->GetTargetDeviceType() << " but virtual device has device type
" << device_type;
auto node = make_object<VirtualDeviceNode>();
node->device_type_int = device_type;
node->virtual_device_id = virtual_device_id;
@@ -151,7 +151,7 @@ VirtualDevice VirtualDevice::Default(const VirtualDevice&
lhs, const VirtualDevi
defaulted_target = lhs->target;
} else {
// We can only default to the rhs's target if it is consistent with the
device type
- if (rhs->target.defined() && rhs->target->kind->device_type ==
defaulted_device_type) {
+ if (rhs->target.defined() && rhs->target->GetTargetDeviceType() ==
defaulted_device_type) {
defaulted_target = rhs->target;
}
// else: leave as null
diff --git a/src/tir/analysis/verify_memory.cc
b/src/tir/analysis/verify_memory.cc
index 6ee30e0470..80d6897011 100644
--- a/src/tir/analysis/verify_memory.cc
+++ b/src/tir/analysis/verify_memory.cc
@@ -186,7 +186,7 @@ std::vector<String> VerifyMemory_(const PrimFunc& func) {
if (func->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) ==
CallingConv::kDefault) {
- MemoryAccessVerifier v(func, target.value()->kind->device_type);
+ MemoryAccessVerifier v(func, target.value()->GetTargetDeviceType());
v.Run();
return v.Errors();
} else {
diff --git a/src/tir/transforms/make_packed_api.cc
b/src/tir/transforms/make_packed_api.cc
index bf7ff09c86..5b9bac03ab 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -145,7 +145,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
- int target_device_type = target.value()->kind->device_type;
+ int target_device_type = target.value()->GetTargetDeviceType();
std::string name_hint = global_symbol.value();
diff --git a/src/tir/transforms/make_unpacked_api.cc
b/src/tir/transforms/make_unpacked_api.cc
index 87e8f38895..e44eb34068 100644
--- a/src/tir/transforms/make_unpacked_api.cc
+++ b/src/tir/transforms/make_unpacked_api.cc
@@ -50,7 +50,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) {
auto* func_ptr = func.CopyOnWrite();
// Setup device context
- int target_device_type = target.value()->kind->device_type;
+ int target_device_type = target.value()->GetTargetDeviceType();
Integer device_type(target_device_type);
Integer device_id(0);
PrimExpr node = StringImm("default");
diff --git a/tests/python/relay/collage/demo_collage_partitioner.py
b/tests/python/relay/collage/demo_collage_partitioner.py
index 76db459d4c..c5a18c3832 100644
--- a/tests/python/relay/collage/demo_collage_partitioner.py
+++ b/tests/python/relay/collage/demo_collage_partitioner.py
@@ -280,7 +280,7 @@ def collage(model):
logging.info("-------------- BEGIN PARTITIONED --------------")
logging.info(partitioned_model["mod"])
logging.info("-------------- END PARTITIONED ----------------")
- dev = tvm.device(CUDA.kind.device_type)
+ dev = tvm.device(CUDA.get_target_device_type())
compile_and_benchmark("collage", partitioned_model, targets, dev,
tmp_dir)
@@ -309,7 +309,7 @@ def just_tensorrt(model):
targets = []
targets.append(CUDA)
targets.append(trt_target)
- dev = tvm.device(CUDA.kind.device_type)
+ dev = tvm.device(CUDA.get_target_device_type())
compile_and_benchmark("just_tensorrt", partitioned_model, targets,
dev, tmp_dir)
@@ -333,7 +333,7 @@ def just_cutlass(model):
targets = []
targets.append(CUDA)
targets.append(tvm.target.Target(f"cutlass -tmp_dir={tmp_dir}",
HOST))
- dev = tvm.device(CUDA.kind.device_type)
+ dev = tvm.device(CUDA.get_target_device_type())
compile_and_benchmark("just_cutlass", partitioned_model, targets,
dev, tmp_dir)
@@ -346,7 +346,7 @@ def just_tvm(model):
tmp_dir = tempfile.mkdtemp()
autotvm_tune_module(model["mod"], CUDA, TUNING_LOG)
with optional_tuning_records(TUNING_LOG):
- dev = tvm.device(CUDA.kind.device_type)
+ dev = tvm.device(CUDA.get_target_device_type())
compile_and_benchmark("just_tvm", model, CUDA, dev, tmp_dir)
@@ -360,7 +360,7 @@ def tvm_with_libs(model):
cuda_target = tvm.target.Target("cuda -libs=cudnn,cublas", HOST)
autotvm_tune_module(model["mod"], cuda_target, TUNING_LOG)
with optional_tuning_records(TUNING_LOG):
- dev = tvm.device(cuda_target.kind.device_type)
+ dev = tvm.device(cuda_target.get_target_device_type())
compile_and_benchmark("tvm_with_libs", model, cuda_target, dev,
tmp_dir)
diff --git a/tests/python/unittest/test_target_target.py
b/tests/python/unittest/test_target_target.py
index d0dfa3942f..2b0f1b2dd7 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -58,7 +58,9 @@ def test_all_targets_device_type_verify():
if tgt.kind.name not in tvm._ffi.runtime_ctypes.Device.STR2MASK:
raise KeyError("Cannot find target kind: %s in Device.STR2MASK" %
tgt.kind.name)
- assert tgt.kind.device_type ==
tvm._ffi.runtime_ctypes.Device.STR2MASK[tgt.kind.name]
+ assert (
+ tgt.get_target_device_type() ==
tvm._ffi.runtime_ctypes.Device.STR2MASK[tgt.kind.name]
+ )
def test_target_dispatch():
diff --git a/tests/scripts/release/PRERELEASE_NOTES.md
b/tests/scripts/release/PRERELEASE_NOTES.md
new file mode 100644
index 0000000000..933d8d2720
--- /dev/null
+++ b/tests/scripts/release/PRERELEASE_NOTES.md
@@ -0,0 +1,24 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+Notable changes since last release
+----------------------------------
+
+* PR12509:
+ - Changed `TargetKind::device_type` to `TargetKind::default_device_type`.
+  - Introduced "target_device_type" attribute that overrides the default
device type.
+ - Added `Target::GetTargetDeviceType` to return the effective device type
for the target.