This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a9d86e61b6 [Metaschedule] Support tuning on rocm and vulkan target (#11017)
a9d86e61b6 is described below
commit a9d86e61b650733128bbef9f2f3ddae01211dafb
Author: Masahiro Masuda <[email protected]>
AuthorDate: Fri Apr 15 15:11:41 2022 +0900
[Metaschedule] Support tuning on rocm and vulkan target (#11017)
---
python/tvm/meta_schedule/tune.py | 6 +++---
src/target/target_kind.cc | 5 +++++
2 files changed, 8 insertions(+), 3 deletions(-)
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index 31130f67af..201434665a 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -411,7 +411,7 @@ class Parse:
# pylint: disable=protected-access
if target.kind.name == "llvm":
return DefaultLLVM._sch_rules()
- if target.kind.name == "cuda":
+ if target.kind.name in ["cuda", "rocm", "vulkan"]:
return DefaultCUDA._sch_rules()
# pylint: enable=protected-access
raise ValueError(f"Unsupported target: {target}")
@@ -425,7 +425,7 @@ class Parse:
# pylint: disable=protected-access
if target.kind.name == "llvm":
return DefaultLLVM._postproc()
- if target.kind.name == "cuda":
+ if target.kind.name in ["cuda", "rocm", "vulkan"]:
return DefaultCUDA._postproc()
# pylint: enable=protected-access
raise ValueError(f"Unsupported target: {target}")
@@ -444,7 +444,7 @@ class Parse:
# pylint: disable=protected-access
if target.kind.name == "llvm":
return DefaultLLVM._mutator_probs()
- if target.kind.name == "cuda":
+ if target.kind.name in ["cuda", "rocm", "vulkan"]:
return DefaultCUDA._mutator_probs()
# pylint: enable=protected-access
raise ValueError(f"Unsupported target: {target}")
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 96c193d34a..2ad75259d6 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -308,7 +308,11 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<String>("mtriple")
.add_attr_option<Array<String>>("mattr")
.add_attr_option<Bool>("system-lib")
+ // TODO(masahi): Support querying from a target device
+ // On RDNA cards, thread_warp_size should be 32
.add_attr_option<Integer>("max_num_threads", Integer(256))
+ .add_attr_option<Integer>("max_threads_per_block", Integer(256))
+ .add_attr_option<Integer>("max_shared_memory_per_block", Integer(65536))
.add_attr_option<Integer>("thread_warp_size", Integer(64))
.set_default_keys({"rocm", "gpu"})
.set_attrs_preprocessor(UpdateROCmAttrs);
@@ -350,6 +354,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.add_attr_option<Integer>("supported_subgroup_operations")
// Physical device limits
.add_attr_option<Integer>("max_num_threads", Integer(256))
+ .add_attr_option<Integer>("max_threads_per_block", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(1))
.add_attr_option<Integer>("max_block_size_x")
.add_attr_option<Integer>("max_block_size_y")