This is an automated email from the ASF dual-hosted git repository.

mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d4bf9ecf55 [Target] Add target_device_type attribute to override 
default device_type (#12509)
d4bf9ecf55 is described below

commit d4bf9ecf5524d265916ac7b860b0027f5eee5c49
Author: Krzysztof Parzyszek <[email protected]>
AuthorDate: Fri Sep 30 06:17:42 2022 -0500

    [Target] Add target_device_type attribute to override default device_type 
(#12509)
    
    Implement Target::GetTargetDeviceType (C++) or get_target_device_type
    (python) to get the device type (kDL...) for a given target.
    
    The attribute "target_device_type" can be used to override the default
    device type associated with the target kind.
---
 include/tvm/target/compilation_config.h            |  2 +-
 include/tvm/target/target.h                        |  6 ++--
 include/tvm/target/target_kind.h                   | 15 +++++-----
 include/tvm/target/virtual_device.h                | 10 +++----
 python/tvm/micro/model_library_format.py           | 32 +++++++++++-----------
 python/tvm/relay/build_module.py                   |  4 +--
 python/tvm/relay/collage/collage.py                |  2 +-
 python/tvm/target/target.py                        |  4 +++
 src/auto_scheduler/search_policy/utils.h           | 15 ++++------
 src/auto_scheduler/search_task.cc                  |  4 +--
 src/driver/driver_api.cc                           |  8 ++++--
 src/relay/backend/build_module.cc                  |  4 +--
 src/relay/backend/contrib/uma/targets.cc           |  2 +-
 src/relay/backend/interpreter.cc                   |  6 ++--
 src/relay/backend/vm/compiler.cc                   |  6 ++--
 src/runtime/vulkan/vulkan_device.h                 |  2 +-
 src/target/compilation_config.cc                   | 23 ++++++++--------
 src/target/spirv/spirv_support.cc                  |  2 +-
 src/target/target.cc                               | 14 ++++++++--
 src/target/virtual_device.cc                       |  8 +++---
 src/tir/analysis/verify_memory.cc                  |  2 +-
 src/tir/transforms/make_packed_api.cc              |  2 +-
 src/tir/transforms/make_unpacked_api.cc            |  2 +-
 .../relay/collage/demo_collage_partitioner.py      | 10 +++----
 tests/python/unittest/test_target_target.py        |  4 ++-
 tests/scripts/release/PRERELEASE_NOTES.md          | 24 ++++++++++++++++
 26 files changed, 128 insertions(+), 85 deletions(-)

diff --git a/include/tvm/target/compilation_config.h 
b/include/tvm/target/compilation_config.h
index 53b7df88b8..eab34de1fb 100644
--- a/include/tvm/target/compilation_config.h
+++ b/include/tvm/target/compilation_config.h
@@ -78,7 +78,7 @@ class CompilationConfigNode : public Object {
    * It is possible to have multiple primitive targets for the same device 
type. However given
    * primitive targets left and right where:
    *  - left appears before right in the array
-   *  - left->kind->device_type == right->kind->device_type
+   *  - left->GetTargetDeviceType() == right->GetTargetDeviceType()
    * then:
    *  - right.IsExternalCodegenFor(left) must be true
    * In this way the \p FindPrimitiveTargetForDeviceOrFail method will find 
the 'most general'
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 6ad213f126..df6951685a 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -68,6 +68,8 @@ class TargetNode : public Object {
   TVM_DLL Map<String, ObjectRef> Export() const;
   /*! \return The Optional<Target> typed target host of the TargetNode */
   TVM_DLL Optional<Target> GetHost() const;
+  /*! \return The device type for this target */
+  TVM_DLL int GetTargetDeviceType() const;
 
   /*!
    * \brief Returns a human readable representation of \p Target which 
includes all fields,
@@ -230,11 +232,11 @@ class Target : public ObjectRef {
    * with \p that target. In particular:
    *  - \p this has a true ::tvm::attr::kIsExternalCodegen attribute
    *  - \p that does not have a true ::tvm::attr::kIsExternalCodegen attribute
-   *  - \p this and \p that have the same kind->device_type
+   *  - \p this and \p that have the same GetTargetDeviceType()
    *
    * After partitioning, the external codegen compilation path may use \p that 
to guide it's
    * compilation to a \p runtime::Module. Given \p this, an appropriate \p 
that can be
-   * found using \p 
CompilationConfig::FindPrimitiveTargetOrFail(this->kind->device_type).
+   * found using \p 
CompilationConfig::FindPrimitiveTargetOrFail(this->GetTargetDeviceType()).
    *
    * The \p CollagePartition pass uses this method to guide it's search over 
candidate partitions
    * using external codegen.
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index 63c92fedbd..19bcce3116 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -92,7 +92,7 @@ class TargetKindNode : public Object {
   /*! \brief Name of the target kind */
   String name;
   /*! \brief Device type of target kind */
-  int device_type;
+  int default_device_type;
   /*! \brief Default keys of the target */
   Array<String> default_keys;
   /*! \brief Function used to preprocess on target creation */
@@ -102,7 +102,7 @@ class TargetKindNode : public Object {
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
-    v->Visit("device_type", &device_type);
+    v->Visit("default_device_type", &default_device_type);
     v->Visit("default_keys", &default_keys);
   }
 
@@ -211,7 +211,7 @@ class TargetKindRegEntry {
    * \brief Set DLPack's device_type the target
    * \param device_type Device type
    */
-  inline TargetKindRegEntry& set_device_type(int device_type);
+  inline TargetKindRegEntry& set_default_device_type(int device_type);
   /*!
    * \brief Set DLPack's device_type the target
    * \param keys The default keys
@@ -363,8 +363,8 @@ inline TargetKindRegEntry& 
TargetKindRegEntry::set_attr(const String& attr_name,
   return *this;
 }
 
-inline TargetKindRegEntry& TargetKindRegEntry::set_device_type(int 
device_type) {
-  kind_->device_type = device_type;
+inline TargetKindRegEntry& TargetKindRegEntry::set_default_device_type(int 
device_type) {
+  kind_->default_device_type = device_type;
   return *this;
 }
 
@@ -463,14 +463,15 @@ constexpr const char* kRelayToTIR = "RelayToTIR";
   TVM_STR_CONCAT(TVM_TARGET_KIND_REGISTER_VAR_DEF, __COUNTER__) = \
       ::tvm::TargetKindRegEntry::RegisterOrGet(TargetKindName)    \
           .set_name()                                             \
-          .set_device_type(DeviceType)                            \
+          .set_default_device_type(DeviceType)                    \
           .add_attr_option<Array<String>>("keys")                 \
           .add_attr_option<String>("tag")                         \
           .add_attr_option<String>("device")                      \
           .add_attr_option<String>("model")                       \
           .add_attr_option<Array<String>>("libs")                 \
           .add_attr_option<Target>("host")                        \
-          .add_attr_option<Integer>("from_device")
+          .add_attr_option<Integer>("from_device")                \
+          .add_attr_option<Integer>("target_device_type")
 
 }  // namespace tvm
 
diff --git a/include/tvm/target/virtual_device.h 
b/include/tvm/target/virtual_device.h
index 37f4b23b12..c26ae5befe 100644
--- a/include/tvm/target/virtual_device.h
+++ b/include/tvm/target/virtual_device.h
@@ -63,7 +63,7 @@ using MemoryScope = String;
  *
  * Some or all of these fields may be unconstrained, signaling that device 
planning is free to
  * choose a value consistent with the whole program. However if a \p target is 
given then the \p
- * device_type must equal \p target->kind->device_type.
+ * device_type must equal \p target->GetTargetDeviceType().
  *
  * Note that currently we assume if a function returns its result on a 
particular (virtual) device
  * then the function body is also executed on that device. See the overview 
comment in
@@ -167,8 +167,8 @@ class VirtualDeviceNode : public 
AttrsNode<VirtualDeviceNode> {
  private:
   /*!
    * \brief The \p DLDeviceType (represented as an int) of the virtual device. 
If \p target is
-   * known then this will be equal to \p target->kind->device_type. If \p 
target is null then the
-   * target is to be determined later.
+   * known then this will be equal to \p target->GetTargetDeviceType(). If \p 
target is null then
+   * the target is to be determined later.
    *
    * This is needed to support the legacy "on_device" and "device_copy" calls 
which only allow
    * a \p DLDeviceTypes (as an integer) to be given.
@@ -263,7 +263,7 @@ class VirtualDevice : public ObjectRef {
   /*!
    * \brief Construct a virtual device.
    * \param device_type The device type for the virtual device, or \p 
kInvalidDeviceType if
-   * unconstrained.  If \p target is defined then must match its \p 
target->kind->device_type.
+   * unconstrained.  If \p target is defined then must match its \p 
target->GetTargetDeviceType().
    * \param virtual_device_id The device id for the virtual device, or -1 if 
unconstrained.
    * \param target The target describing how to compile for the virtual 
device, or null if
    * unconstrained.
@@ -304,7 +304,7 @@ class VirtualDevice : public ObjectRef {
 
   /*! \brief Returns the \p VirtualDevice for \p target. */
   static VirtualDevice ForTarget(Target target) {
-    DLDeviceType device_type = 
static_cast<DLDeviceType>(target->kind->device_type);
+    DLDeviceType device_type = 
static_cast<DLDeviceType>(target->GetTargetDeviceType());
     return VirtualDevice(device_type, /*virtual_device_id=*/0, 
std::move(target));
   }
 
diff --git a/python/tvm/micro/model_library_format.py 
b/python/tvm/micro/model_library_format.py
index e220fa1ca5..1ba9f5e733 100644
--- a/python/tvm/micro/model_library_format.py
+++ b/python/tvm/micro/model_library_format.py
@@ -226,12 +226,12 @@ def _build_function_memory_map(function_metadata):
         for target in dict(finfo.workspace_sizes).keys():
             workspace_size = finfo.workspace_sizes[target]
             target_entry = {
-                "device": int(target.kind.device_type),
+                "device": int(target.get_target_device_type()),
                 "workspace_size_bytes": int(workspace_size),
             }
             target_local_entries[func_name].append(target_entry)
-            if workspace_size >= 
device_max_workspace.get(int(target.kind.device_type), 0):
-                device_max_workspace[int(target.kind.device_type)] = 
workspace_size
+            if workspace_size >= 
device_max_workspace.get(int(target.get_target_device_type()), 0):
+                device_max_workspace[int(target.get_target_device_type())] = 
workspace_size
 
     for func_name, target_entries_ in target_local_entries.items():
         func_entry = {
@@ -252,28 +252,28 @@ def _build_function_memory_map(function_metadata):
 
     for target in dict(main_func_metadata.workspace_sizes).keys():
         main_func_local_workspace = main_func_metadata.workspace_sizes[target]
-        target_main_entries[int(target.kind.device_type)] = 
_create_empty_entry(
-            int(target.kind.device_type)
+        target_main_entries[int(target.get_target_device_type())] = 
_create_empty_entry(
+            int(target.get_target_device_type())
         )
-        
target_main_entries[int(target.kind.device_type)]["workspace_size_bytes"] = int(
-            device_max_workspace.get(int(target.kind.device_type), 0)
+        
target_main_entries[int(target.get_target_device_type())]["workspace_size_bytes"]
 = int(
+            device_max_workspace.get(int(target.get_target_device_type()), 0)
         ) + int(main_func_local_workspace)
 
     for target in dict(main_func_metadata.constant_sizes).keys():
-        if int(target.kind.device_type) not in target_main_entries.keys():
-            target_main_entries[int(target.kind.device_type)] = 
_create_empty_entry(
-                int(target.kind.device_type)
+        if int(target.get_target_device_type()) not in 
target_main_entries.keys():
+            target_main_entries[int(target.get_target_device_type())] = 
_create_empty_entry(
+                int(target.get_target_device_type())
             )
-        
target_main_entries[int(target.kind.device_type)]["constants_size_bytes"] = int(
+        
target_main_entries[int(target.get_target_device_type())]["constants_size_bytes"]
 = int(
             main_func_metadata.constant_sizes[target]
         )
 
     for target in dict(main_func_metadata.io_sizes).keys():
-        if int(target.kind.device_type) not in target_main_entries.keys():
-            target_main_entries[int(target.kind.device_type)] = 
_create_empty_entry(
-                int(target.kind.device_type)
+        if int(target.get_target_device_type()) not in 
target_main_entries.keys():
+            target_main_entries[int(target.get_target_device_type())] = 
_create_empty_entry(
+                int(target.get_target_device_type())
             )
-        target_main_entries[int(target.kind.device_type)]["io_size_bytes"] = 
int(
+        
target_main_entries[int(target.get_target_device_type())]["io_size_bytes"] = 
int(
             main_func_metadata.io_sizes[target]
         )
 
@@ -483,7 +483,7 @@ def _write_tir_and_build_operator_memory_map(src_dir, 
targets, ir_module_by_targ
     memory_map = {}
     for target in targets:
         # TODO(mbs): The device type is not unique, better would be to use 
target.kind.name
-        target_device_type = target.kind.device_type
+        target_device_type = target.get_target_device_type()
         ir_mod = ir_module_by_target[target]
         printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, 
None, False)
         with open(src_dir / f"tir-{target_device_type}.txt", "w") as f:
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 6cdc79ceb5..112e5558fe 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -664,10 +664,10 @@ def create_executor(kind="debug", mod=None, device=None, 
target="llvm", params=N
     if mod is None:
         mod = IRModule()
     if device is not None:
-        assert device.device_type == raw_targets[0].kind.device_type
+        assert device.device_type == raw_targets[0].get_target_device_type()
     else:
         # Derive the default device from the first target.
-        device = _nd.device(raw_targets[0].kind.device_type, 0)
+        device = _nd.device(raw_targets[0].get_target_device_type(), 0)
 
     if params is not None:
         mod = IRModule.from_expr(bind_params_by_name(mod["main"], params))
diff --git a/python/tvm/relay/collage/collage.py 
b/python/tvm/relay/collage/collage.py
index 4dd59d56b4..632ab1746f 100644
--- a/python/tvm/relay/collage/collage.py
+++ b/python/tvm/relay/collage/collage.py
@@ -82,7 +82,7 @@ def vm_estimate_seconds(device, the_vm, func_name, args):
 def estimate_seconds(mod, target):
     """Returns the mean execution time of "main" in mod on target with params. 
The module
     may contain "Primitive" functions, possibly with "Compiler" attributes."""
-    device = tvm.device(target.kind.device_type)
+    device = tvm.device(target.get_target_device_type())
 
     try:
         # Build the module.
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 1e9e2e698c..7081f992af 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -234,6 +234,10 @@ class Target(Object):
         """
         return _ffi_api.TargetKindGetAttr(self.kind, attr_name)
 
+    def get_target_device_type(self):
+        """Returns the device_type for this target."""
+        return _ffi_api.TargetGetDeviceType(self)
+
     @staticmethod
     def list_kinds():
         """Returns the list of available target names."""
diff --git a/src/auto_scheduler/search_policy/utils.h 
b/src/auto_scheduler/search_policy/utils.h
index ffd4bf4f48..44b60de1d7 100644
--- a/src/auto_scheduler/search_policy/utils.h
+++ b/src/auto_scheduler/search_policy/utils.h
@@ -48,27 +48,24 @@ namespace auto_scheduler {
 
 /*! \brief Return whether the search task is targeting a CPU. */
 inline bool IsCPUTask(const SearchTask& task) {
-  return (task)->target->kind->device_type == kDLCPU;
+  return (task)->target->GetTargetDeviceType() == kDLCPU;
 }
 
 /*! \brief Return whether the search task is targeting a GPU. */
 inline bool IsGPUTask(const SearchTask& task) {
-  return (task)->target->kind->device_type == kDLCUDA ||
-         (task)->target->kind->device_type == kDLOpenCL ||
-         (task)->target->kind->device_type == kDLVulkan ||
-         (task)->target->kind->device_type == kDLMetal ||
-         (task)->target->kind->device_type == kDLROCM ||
-         (task)->target->kind->device_type == kOpenGL;
+  int device_type = (task)->target->GetTargetDeviceType();
+  return device_type == kDLCUDA || device_type == kDLOpenCL || device_type == 
kDLVulkan ||
+         device_type == kDLMetal || device_type == kDLROCM || device_type == 
kOpenGL;
 }
 
 /*! \brief Return whether the search task is targeting a CUDA GPU. */
 inline bool IsCUDATask(const SearchTask& task) {
-  return (task)->target->kind->device_type == kDLCUDA;
+  return (task)->target->GetTargetDeviceType() == kDLCUDA;
 }
 
 /*! \brief Return whether the search task is targeting a OpenCL GPU. */
 inline bool IsOpenCLTask(const SearchTask& task) {
-  return (task)->target->kind->device_type == kDLOpenCL;
+  return (task)->target->GetTargetDeviceType() == kDLOpenCL;
 }
 
 /*! \brief Argsort. Order: largest to smallest */
diff --git a/src/auto_scheduler/search_task.cc 
b/src/auto_scheduler/search_task.cc
index 262340099c..5c8c678e8c 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -54,7 +54,7 @@ HardwareParams::HardwareParams(int num_cores, int 
vector_unit_bytes, int cache_l
 HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& 
target,
                                                             const Target& 
target_host) {
   // There is no use of target_host so no updates here in the function.
-  const auto device_type = target->kind->device_type;
+  const auto device_type = target->GetTargetDeviceType();
   if (device_type == kDLCPU) {
     return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 
0, 0, 0, 0, 0);
   } else if (device_type == kDLCUDA || device_type == kDLROCM) {
@@ -91,7 +91,7 @@ HardwareParams 
HardwareParamsNode::GetDefaultHardwareParams(const Target& target
     int max_vthread_extent = warp_size / 4;
     return HardwareParams(-1, 16, 64, max_shared_memory_per_block, 
max_local_memory_per_block,
                           max_threads_per_block, max_vthread_extent, 
warp_size);
-  } else if (target->kind->device_type == kDLOpenCL) {
+  } else if (target->GetTargetDeviceType() == kDLOpenCL) {
     if (target->GetAttr<String>("device", "") == "mali") {
       // We cannot use device API to get hardware attributes like CUDA,
       // because like Mali target is normally on the remote machine.
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index b460557da0..b0af0fb65e 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -72,7 +72,7 @@ bool ShouldAnnotateEntryFunc(const IRModule mod) {
 
 /*! \return The default host target for a given device target */
 Target DefaultTargetHost(Target target) {
-  if (target.defined() && target->kind->device_type == kDLCPU) {
+  if (target.defined() && target->GetTargetDeviceType() == kDLCPU) {
     return target;
   } else {
     if (LLVMEnabled()) {
@@ -423,7 +423,8 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& 
inputs_arg,
 
   if (!target_host.defined()) {
     for (const auto& it : inputs) {
-      if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type 
== kDLMicroDev) {
+      if (it.first->GetTargetDeviceType() == kDLCPU ||
+          it.first->GetTargetDeviceType() == kDLMicroDev) {
         target_host = it.first;
         break;
       }
@@ -460,7 +461,8 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& 
inputs_arg,
       // unless they're supposed to. Here if we overrode the target host
       // to allow lowering previously we check that it's meant to be placed
       // back into the host Module.
-      bool overrides_host_target = target->kind->device_type == 
target_host->kind->device_type;
+      bool overrides_host_target =
+          target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
       bool non_host_target_kind = target->kind != target_host->kind;
       if (overrides_host_target && non_host_target_kind) {
         device_modules.push_back(codegen::Build(host_mod, it.first));
diff --git a/src/relay/backend/build_module.cc 
b/src/relay/backend/build_module.cc
index 7b39cb4443..bca524794a 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -359,7 +359,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (backend::IsAutoSchedulerEnabled() && 
config_->optional_homogeneous_target.defined()) {
       Pass major_pass = transform::AutoSchedulerLayoutRewrite();
       bool enable_layout_rewrite_targets =
-          config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
+          config_->optional_homogeneous_target->GetTargetDeviceType() == 
kDLCPU ||
           config_->optional_homogeneous_target->GetAttr<String>("device", "") 
== "mali";
       if (enable_layout_rewrite_targets && 
pass_ctx.PassEnabled(major_pass->Info())) {
         With<Target> tctx(config_->optional_homogeneous_target);
@@ -373,7 +373,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (backend::IsMetaScheduleEnabled() && 
config_->optional_homogeneous_target.defined()) {
       Pass major_pass = transform::MetaScheduleLayoutRewrite();
       bool enable_layout_rewrite_targets =
-          config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
+          config_->optional_homogeneous_target->GetTargetDeviceType() == 
kDLCPU ||
           config_->optional_homogeneous_target->GetAttr<String>("device", "") 
== "mali";
       if (enable_layout_rewrite_targets && 
pass_ctx.PassEnabled(major_pass->Info())) {
         With<Target> tctx(config_->optional_homogeneous_target);
diff --git a/src/relay/backend/contrib/uma/targets.cc 
b/src/relay/backend/contrib/uma/targets.cc
index a17f6694f7..ed2cc047cf 100644
--- a/src/relay/backend/contrib/uma/targets.cc
+++ b/src/relay/backend/contrib/uma/targets.cc
@@ -50,7 +50,7 @@ 
TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
       auto target_kind =
           ::tvm::TargetKindRegEntry::RegisterOrGet(target_name)
               .set_name()
-              .set_device_type(kDLCPU)
+              .set_default_device_type(kDLCPU)
               .add_attr_option<Array<String>>("keys")
               .add_attr_option<String>("tag")
               .add_attr_option<String>("device")
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 65a0fdc948..1019ecf358 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -477,7 +477,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const 
Expr& n)>,
 
     // TODO(mbs): Take this from the host_virtual_device.
     Device shape_device;
-    shape_device.device_type = 
static_cast<DLDeviceType>(prim_shape_target->kind->device_type);
+    shape_device.device_type = 
static_cast<DLDeviceType>(prim_shape_target->GetTargetDeviceType());
     shape_device.device_id = 0;
 
     // 'Compile' the TIR shape function to appropriate callable form.
@@ -1017,7 +1017,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> 
EvalFunction(IRModule mod, Expr expr, De
           << PrettyPrint(mod) << "and expression:" << std::endl
           << PrettyPrint(expr);
 
-  ICHECK_EQ(device.device_type, target->kind->device_type);
+  ICHECK_EQ(device.device_type, target->GetTargetDeviceType());
   Array<Target> raw_targets = {target};
   CompilationConfig config(transform::PassContext::Current(), raw_targets);
 
@@ -1106,7 +1106,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> 
EvalFunction(IRModule mod, Expr expr, De
 ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
                std::unordered_set<String> import_set, Device device, Target 
target,
                Map<String, ObjectRef> attrs) {
-  ICHECK_EQ(device.device_type, target->kind->device_type);
+  ICHECK_EQ(device.device_type, target->GetTargetDeviceType());
   Array<Target> raw_targets = {target};
   CompilationConfig config(transform::PassContext::Current(), raw_targets);
 
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index a8bd3df32a..b807f41959 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1067,7 +1067,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   if (backend::IsAutoSchedulerEnabled() && 
config_->optional_homogeneous_target.defined()) {
     Pass major_pass = transform::AutoSchedulerLayoutRewrite();
     bool enable_layout_rewrite_targets =
-        config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
+        config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU 
||
         config_->optional_homogeneous_target->GetAttr<String>("device", "") == 
"mali";
     if (enable_layout_rewrite_targets && 
pass_ctx.PassEnabled(major_pass->Info())) {
       With<Target> tctx(config_->optional_homogeneous_target);
@@ -1081,7 +1081,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   if (backend::IsMetaScheduleEnabled() && 
config_->optional_homogeneous_target.defined()) {
     Pass major_pass = transform::MetaScheduleLayoutRewrite();
     bool enable_layout_rewrite_targets =
-        config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
+        config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU 
||
         config_->optional_homogeneous_target->GetAttr<String>("device", "") == 
"mali";
     if (enable_layout_rewrite_targets && 
pass_ctx.PassEnabled(major_pass->Info())) {
       With<Target> tctx(config_->optional_homogeneous_target);
@@ -1164,7 +1164,7 @@ void VMCompiler::Codegen() {
   // Only the PrimFuncs will appear in per_target_modules, and there may 
legitimately be none.
   Map<Target, IRModule> per_tvm_target_modules = 
tec::GetPerTargetModules(context_.module);
   for (const auto& kv : per_tvm_target_modules) {
-    ICHECK(kv.first->kind->device_type != kDLExtDev);
+    ICHECK(kv.first->GetTargetDeviceType() != kDLExtDev);
   }
 
   // Retrieve all external runtime modules accumulated by external codegen 
(both function-at-a-time
diff --git a/src/runtime/vulkan/vulkan_device.h 
b/src/runtime/vulkan/vulkan_device.h
index a1257a732a..59ebf430e6 100644
--- a/src/runtime/vulkan/vulkan_device.h
+++ b/src/runtime/vulkan/vulkan_device.h
@@ -67,7 +67,7 @@ struct VulkanQueueInsertDebugUtilsLabelFunctions {
  * \brief Stores the capabilities/limits queried from the physical device.
  *
  * The member variables here have a 1-1 mapping to Target parameters,
- * if target->kind->device_type==kDLVulkan.  A separate struct is used
+ * if target->GetTargetDeviceType()==kDLVulkan.  A separate struct is used
  * to maintain the boundary between the Vulkan runtime in
  * libtvm_runtime.so, and the Target object in libtvm.so.
  */
diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc
index 5e001921b0..a7f708f12a 100644
--- a/src/target/compilation_config.cc
+++ b/src/target/compilation_config.cc
@@ -42,13 +42,13 @@ Target 
CompilationConfigNode::FindPrimitiveTargetForDeviceOrFail(DLDeviceType de
   ICHECK_GT(device_type, 0) << "Invalid device type";
   auto itr = std::find_if(
       primitive_targets.begin(), primitive_targets.end(),
-      [device_type](const Target& target) { return target->kind->device_type 
== device_type; });
+      [device_type](const Target& target) { return 
target->GetTargetDeviceType() == device_type; });
   if (itr == primitive_targets.end()) {
     std::stringstream msg;
     msg << "No target is specified for device type " << device_type
         << ". The available device types and targets are:" << std::endl;
     for (const auto& target : primitive_targets) {
-      msg << "  " << target->kind->device_type << "-> " << 
target->ToDebugString() << std::endl;
+      msg << "  " << target->GetTargetDeviceType() << "-> " << 
target->ToDebugString() << std::endl;
     }
     LOG(FATAL) << msg.str();
   }
@@ -137,7 +137,7 @@ void CompilationConfigNode::Init(const 
transform::PassContext& pass_ctx,
   auto hosting_itr = std::find_if(raw_targets.begin(), raw_targets.end(), 
[](const Target& target) {
     // TODO(tvm-team): The kDLHexagon device can act as a host. We can remove 
kDLHexagon
     // here once we refactored kDLHexagon to kDLCPU.
-    return target->kind->device_type == kDLCPU || target->kind->device_type == 
kDLHexagon;
+    return target->GetTargetDeviceType() == kDLCPU || 
target->GetTargetDeviceType() == kDLHexagon;
   });
 
   // Any targets with their host field set?
@@ -149,23 +149,24 @@ void CompilationConfigNode::Init(const 
transform::PassContext& pass_ctx,
     // targets.
     host_target = Target((*has_host_itr)->GetHost().value(), 
/*host=*/Target());
     VLOG(1) << "The target " << (*has_host_itr)->ToDebugString() << " supplies 
a host target "
-            << host_target->ToDebugString() << " of device type " << 
host_target->kind->device_type;
+            << host_target->ToDebugString() << " of device type "
+            << host_target->GetTargetDeviceType();
   } else if (hosting_itr != raw_targets.end()) {
     // RULE B: If any raw target is for a device which could be a host then 
use the first such as
     // the host.
     host_target = Target(*hosting_itr, /*host=*/Target());
     VLOG(1) << "Using target " << host_target->ToDebugString() << " of 
CPU-like device type "
-            << host_target->kind->device_type << " as the host target";
+            << host_target->GetTargetDeviceType() << " as the host target";
   } else {
     // RULE C: Otherwise, create a default CPU host target.
     host_target = MakeDefaultCPUTarget();
     VLOG(1) << "Created a default target " << host_target->ToDebugString() << 
" of device type "
-            << host_target->kind->device_type << " for the host target";
+            << host_target->GetTargetDeviceType() << " for the host target";
   }
   ICHECK(host_target.defined());
   ICHECK(!host_target->host.defined());
 
-  if (host_target->kind->device_type != kDLCPU) {
+  if (host_target->GetTargetDeviceType() != kDLCPU) {
     // I think we're on thin ice here until we've audited the code base for 
assumed CPU hosts.
     VLOG(1) << "The host target is not a CPU. This is probably not going to 
work.";
   }
@@ -174,7 +175,7 @@ void CompilationConfigNode::Init(const 
transform::PassContext& pass_ctx,
   // Establish the host VirtualDevice.
   //
   host_virtual_device = virtual_device_cache_.Unique(
-      VirtualDevice(static_cast<DLDeviceType>(host_target->kind->device_type),
+      
VirtualDevice(static_cast<DLDeviceType>(host_target->GetTargetDeviceType()),
                     /*virtual_device_id=*/0, host_target));
   ICHECK(host_virtual_device.defined());
   ICHECK(host_virtual_device->target.defined());
@@ -205,7 +206,7 @@ void CompilationConfigNode::Init(const 
transform::PassContext& pass_ctx,
   std::unordered_set<DLDeviceType> primitive_target_device_types;
   std::unordered_set<std::string> kind_names;
   for (const auto& target : primitive_targets) {
-    
primitive_target_device_types.emplace(static_cast<DLDeviceType>(target->kind->device_type));
+    
primitive_target_device_types.emplace(static_cast<DLDeviceType>(target->GetTargetDeviceType()));
     CHECK(kind_names.emplace(target->kind->name).second) << "Multiple targets 
have been given"
                                                             "for the same 
device kind '"
                                                          << target->kind->name 
<< "'";
@@ -213,7 +214,7 @@ void CompilationConfigNode::Init(const 
transform::PassContext& pass_ctx,
   for (DLDeviceType device_type : primitive_target_device_types) {
     Target first_primitive_target;
     for (const auto& current_primitive_target : primitive_targets) {
-      if (current_primitive_target->kind->device_type != device_type) {
+      if (current_primitive_target->GetTargetDeviceType() != device_type) {
         continue;
       }
       if (!first_primitive_target.defined()) {
@@ -290,7 +291,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "Primitive targets:";
       for (const auto& target : node->primitive_targets) {
         p->stream << std::endl
-                  << "  " << target->kind->device_type << " |-> " << 
target->ToDebugString();
+                  << "  " << target->GetTargetDeviceType() << " |-> " << 
target->ToDebugString();
       }
       p->stream << std::endl
                 << "Default primitive virtual device: " << 
node->default_primitive_virtual_device;
diff --git a/src/target/spirv/spirv_support.cc 
b/src/target/spirv/spirv_support.cc
index a91a2a3384..81b5cd8b8a 100644
--- a/src/target/spirv/spirv_support.cc
+++ b/src/target/spirv/spirv_support.cc
@@ -32,7 +32,7 @@ namespace tvm {
 namespace codegen {
 
 SPIRVSupport::SPIRVSupport(tvm::Target target) {
-  ICHECK_EQ(target->kind->device_type, kDLVulkan)
+  ICHECK_EQ(target->GetTargetDeviceType(), kDLVulkan)
       << "SPIRVSupport can only be checked for vulkan device type";
 
   if (target->GetAttr<Integer>("vulkan_api_version")) {
diff --git a/src/target/target.cc b/src/target/target.cc
index e3e9354a61..cbebd0e10c 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -622,7 +622,7 @@ bool Target::IsExternalCodegen() const {
 }
 
 bool Target::IsExternalCodegenFor(const Target& that) const {
-  return get()->kind->device_type == that->kind->device_type && 
IsExternalCodegen() &&
+  return get()->GetTargetDeviceType() == that->GetTargetDeviceType() && 
IsExternalCodegen() &&
          !that.IsExternalCodegen();
 }
 
@@ -665,6 +665,13 @@ Optional<Target> TargetNode::GetHost() const {
   return GetRef<Optional<Target>>(this->host.as<TargetNode>());
 }
 
+int TargetNode::GetTargetDeviceType() const {
+  if (Optional<Integer> device_type = GetAttr<Integer>("target_device_type")) {
+    return Downcast<Integer>(device_type)->value;
+  }
+  return kind->default_device_type;
+}
+
 String TargetNode::ToDebugString() const {
   std::ostringstream os;
   os << "Target(";
@@ -974,7 +981,7 @@ std::unordered_map<String, ObjectRef> 
TargetInternal::QueryDevice(int device_id,
                                                                   const 
TargetNode* target) {
   std::unordered_map<String, ObjectRef> output;
 
-  Device device{static_cast<DLDeviceType>(target->kind->device_type), 
device_id};
+  Device device{static_cast<DLDeviceType>(target->GetTargetDeviceType()), 
device_id};
 
   auto api = runtime::DeviceAPI::Get(device, true);
   if (!api) {
@@ -1042,6 +1049,9 @@ 
TVM_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::Exi
 TVM_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current);
 
TVM_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export);
 
TVM_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost);
+TVM_REGISTER_GLOBAL("target.TargetGetDeviceType").set_body_typed([](const 
Target& target) {
+  return target->GetTargetDeviceType();
+});
 TVM_REGISTER_GLOBAL("target.TargetGetFeature")
     .set_body_typed([](const Target& target, const String& feature_key) {
       return target->GetFeature<ObjectRef>(feature_key);
diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc
index ef01a2afda..39bb11ff15 100644
--- a/src/target/virtual_device.cc
+++ b/src/target/virtual_device.cc
@@ -68,9 +68,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 VirtualDevice::VirtualDevice(DLDeviceType device_type, int virtual_device_id, 
Target target,
                              MemoryScope memory_scope) {
-  ICHECK(!target.defined() || device_type == target->kind->device_type)
-      << "target " << target->ToDebugString() << " has device type " << 
target->kind->device_type
-      << " but virtual device has device type " << device_type;
+  ICHECK(!target.defined() || device_type == target->GetTargetDeviceType())
+      << "target " << target->ToDebugString() << " has device type "
+      << target->GetTargetDeviceType() << " but virtual device has device type 
" << device_type;
   auto node = make_object<VirtualDeviceNode>();
   node->device_type_int = device_type;
   node->virtual_device_id = virtual_device_id;
@@ -151,7 +151,7 @@ VirtualDevice VirtualDevice::Default(const VirtualDevice& 
lhs, const VirtualDevi
     defaulted_target = lhs->target;
   } else {
     // We can only default to the rhs's target if it is consistent with the 
device type
-    if (rhs->target.defined() && rhs->target->kind->device_type == 
defaulted_device_type) {
+    if (rhs->target.defined() && rhs->target->GetTargetDeviceType() == 
defaulted_device_type) {
       defaulted_target = rhs->target;
     }
     // else: leave as null
diff --git a/src/tir/analysis/verify_memory.cc 
b/src/tir/analysis/verify_memory.cc
index 6ee30e0470..80d6897011 100644
--- a/src/tir/analysis/verify_memory.cc
+++ b/src/tir/analysis/verify_memory.cc
@@ -186,7 +186,7 @@ std::vector<String> VerifyMemory_(const PrimFunc& func) {
 
   if (func->GetAttr<Integer>(tvm::attr::kCallingConv, 
Integer(CallingConv::kDefault)) ==
       CallingConv::kDefault) {
-    MemoryAccessVerifier v(func, target.value()->kind->device_type);
+    MemoryAccessVerifier v(func, target.value()->GetTargetDeviceType());
     v.Run();
     return v.Errors();
   } else {
diff --git a/src/tir/transforms/make_packed_api.cc 
b/src/tir/transforms/make_packed_api.cc
index bf7ff09c86..5b9bac03ab 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -145,7 +145,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func) {
 
   auto target = func->GetAttr<Target>(tvm::attr::kTarget);
   ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
-  int target_device_type = target.value()->kind->device_type;
+  int target_device_type = target.value()->GetTargetDeviceType();
 
   std::string name_hint = global_symbol.value();
 
diff --git a/src/tir/transforms/make_unpacked_api.cc 
b/src/tir/transforms/make_unpacked_api.cc
index 87e8f38895..e44eb34068 100644
--- a/src/tir/transforms/make_unpacked_api.cc
+++ b/src/tir/transforms/make_unpacked_api.cc
@@ -50,7 +50,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) {
   auto* func_ptr = func.CopyOnWrite();
 
   // Setup device context
-  int target_device_type = target.value()->kind->device_type;
+  int target_device_type = target.value()->GetTargetDeviceType();
   Integer device_type(target_device_type);
   Integer device_id(0);
   PrimExpr node = StringImm("default");
diff --git a/tests/python/relay/collage/demo_collage_partitioner.py 
b/tests/python/relay/collage/demo_collage_partitioner.py
index 76db459d4c..c5a18c3832 100644
--- a/tests/python/relay/collage/demo_collage_partitioner.py
+++ b/tests/python/relay/collage/demo_collage_partitioner.py
@@ -280,7 +280,7 @@ def collage(model):
             logging.info("-------------- BEGIN PARTITIONED --------------")
             logging.info(partitioned_model["mod"])
             logging.info("-------------- END PARTITIONED ----------------")
-            dev = tvm.device(CUDA.kind.device_type)
+            dev = tvm.device(CUDA.get_target_device_type())
             compile_and_benchmark("collage", partitioned_model, targets, dev, 
tmp_dir)
 
 
@@ -309,7 +309,7 @@ def just_tensorrt(model):
         targets = []
         targets.append(CUDA)
         targets.append(trt_target)
-        dev = tvm.device(CUDA.kind.device_type)
+        dev = tvm.device(CUDA.get_target_device_type())
         compile_and_benchmark("just_tensorrt", partitioned_model, targets, 
dev, tmp_dir)
 
 
@@ -333,7 +333,7 @@ def just_cutlass(model):
             targets = []
             targets.append(CUDA)
             targets.append(tvm.target.Target(f"cutlass -tmp_dir={tmp_dir}", 
HOST))
-            dev = tvm.device(CUDA.kind.device_type)
+            dev = tvm.device(CUDA.get_target_device_type())
             compile_and_benchmark("just_cutlass", partitioned_model, targets, 
dev, tmp_dir)
 
 
@@ -346,7 +346,7 @@ def just_tvm(model):
     tmp_dir = tempfile.mkdtemp()
     autotvm_tune_module(model["mod"], CUDA, TUNING_LOG)
     with optional_tuning_records(TUNING_LOG):
-        dev = tvm.device(CUDA.kind.device_type)
+        dev = tvm.device(CUDA.get_target_device_type())
         compile_and_benchmark("just_tvm", model, CUDA, dev, tmp_dir)
 
 
@@ -360,7 +360,7 @@ def tvm_with_libs(model):
     cuda_target = tvm.target.Target("cuda -libs=cudnn,cublas", HOST)
     autotvm_tune_module(model["mod"], cuda_target, TUNING_LOG)
     with optional_tuning_records(TUNING_LOG):
-        dev = tvm.device(cuda_target.kind.device_type)
+        dev = tvm.device(cuda_target.get_target_device_type())
         compile_and_benchmark("tvm_with_libs", model, cuda_target, dev, 
tmp_dir)
 
 
diff --git a/tests/python/unittest/test_target_target.py 
b/tests/python/unittest/test_target_target.py
index d0dfa3942f..2b0f1b2dd7 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -58,7 +58,9 @@ def test_all_targets_device_type_verify():
         if tgt.kind.name not in tvm._ffi.runtime_ctypes.Device.STR2MASK:
             raise KeyError("Cannot find target kind: %s in Device.STR2MASK" % 
tgt.kind.name)
 
-        assert tgt.kind.device_type == 
tvm._ffi.runtime_ctypes.Device.STR2MASK[tgt.kind.name]
+        assert (
+            tgt.get_target_device_type() == 
tvm._ffi.runtime_ctypes.Device.STR2MASK[tgt.kind.name]
+        )
 
 
 def test_target_dispatch():
diff --git a/tests/scripts/release/PRERELEASE_NOTES.md 
b/tests/scripts/release/PRERELEASE_NOTES.md
new file mode 100644
index 0000000000..933d8d2720
--- /dev/null
+++ b/tests/scripts/release/PRERELEASE_NOTES.md
@@ -0,0 +1,24 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+Notable changes since last release
+----------------------------------
+
+* PR12509:
+   - Changed `TargetKind::device_type` to `TargetKind::default_device_type`.
+   - Introduced "target_device_type" attribute that overrides the default 
device type.
+   - Added `Target::GetTargetDeviceType` to return the effective device type 
for the target.


Reply via email to