This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a9d86e61b6 [Metaschedule] Support tuning on rocm and vulkan target (#11017)
a9d86e61b6 is described below

commit a9d86e61b650733128bbef9f2f3ddae01211dafb
Author: Masahiro Masuda <[email protected]>
AuthorDate: Fri Apr 15 15:11:41 2022 +0900

    [Metaschedule] Support tuning on rocm and vulkan target (#11017)
---
 python/tvm/meta_schedule/tune.py | 6 +++---
 src/target/target_kind.cc        | 5 +++++
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index 31130f67af..201434665a 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -411,7 +411,7 @@ class Parse:
         # pylint: disable=protected-access
         if target.kind.name == "llvm":
             return DefaultLLVM._sch_rules()
-        if target.kind.name == "cuda":
+        if target.kind.name in ["cuda", "rocm", "vulkan"]:
             return DefaultCUDA._sch_rules()
         # pylint: enable=protected-access
         raise ValueError(f"Unsupported target: {target}")
@@ -425,7 +425,7 @@ class Parse:
         # pylint: disable=protected-access
         if target.kind.name == "llvm":
             return DefaultLLVM._postproc()
-        if target.kind.name == "cuda":
+        if target.kind.name in ["cuda", "rocm", "vulkan"]:
             return DefaultCUDA._postproc()
         # pylint: enable=protected-access
         raise ValueError(f"Unsupported target: {target}")
@@ -444,7 +444,7 @@ class Parse:
         # pylint: disable=protected-access
         if target.kind.name == "llvm":
             return DefaultLLVM._mutator_probs()
-        if target.kind.name == "cuda":
+        if target.kind.name in ["cuda", "rocm", "vulkan"]:
             return DefaultCUDA._mutator_probs()
         # pylint: enable=protected-access
         raise ValueError(f"Unsupported target: {target}")
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 96c193d34a..2ad75259d6 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -308,7 +308,11 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
     .add_attr_option<String>("mtriple")
     .add_attr_option<Array<String>>("mattr")
     .add_attr_option<Bool>("system-lib")
+    // TODO(masahi): Support querying from a target device
+    // On RDNA cards, thread_warp_size should be 32
     .add_attr_option<Integer>("max_num_threads", Integer(256))
+    .add_attr_option<Integer>("max_threads_per_block", Integer(256))
+    .add_attr_option<Integer>("max_shared_memory_per_block", Integer(65536))
     .add_attr_option<Integer>("thread_warp_size", Integer(64))
     .set_default_keys({"rocm", "gpu"})
     .set_attrs_preprocessor(UpdateROCmAttrs);
@@ -350,6 +354,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
     .add_attr_option<Integer>("supported_subgroup_operations")
     // Physical device limits
     .add_attr_option<Integer>("max_num_threads", Integer(256))
+    .add_attr_option<Integer>("max_threads_per_block", Integer(256))
     .add_attr_option<Integer>("thread_warp_size", Integer(1))
     .add_attr_option<Integer>("max_block_size_x")
     .add_attr_option<Integer>("max_block_size_y")

Reply via email to