This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b906554af [OpenCL] Add OpenCL device for automatic target detection (#16854)
4b906554af is described below
commit 4b906554af2bad9859b405f694b1c59d77d74785
Author: Mengshiun Yu <[email protected]>
AuthorDate: Thu Apr 11 19:12:23 2024 +0800
[OpenCL] Add OpenCL device for automatic target detection (#16854)
This PR adds OpenCL device for automatic target detection.
---
python/tvm/target/detect_target.py | 14 +++++++++++++-
tests/python/target/test_target_target.py | 12 ++++++++++++
2 files changed, 25 insertions(+), 1 deletion(-)
diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py
index a2fe5e1f8b..b23baa0313 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -58,6 +58,17 @@ def _detect_rocm(dev: Device) -> Target:
)
+def _detect_opencl(dev: Device) -> Target:  # OpenCL Target built from limits queried on dev
+    return Target(
+        {
+            "kind": "opencl",
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,  # NOTE(review): confirm warp_size meaning on OpenCL
+        }
+    )
+
+
def _detect_vulkan(dev: Device) -> Target:
f_get_target_property =
get_global_func("device_api.vulkan.get_target_property")
return Target(
@@ -100,7 +111,7 @@ def detect_target_from_device(dev: Union[str, Device]) -> Target:
----------
dev : Union[str, Device]
The device to detect the target for.
- Supported device types: ["cuda", "metal", "rocm", "vulkan"]
+ Supported device types: ["cuda", "metal", "rocm", "vulkan", "opencl"]
Returns
-------
@@ -129,4 +140,5 @@ SUPPORT_DEVICE = {
"metal": _detect_metal,
"vulkan": _detect_vulkan,
"rocm": _detect_rocm,
+ "opencl": _detect_opencl,
}
diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py
index 83bd864970..e977ef10aa 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -547,5 +547,17 @@ def test_target_from_device_rocm(input_device):
)
+@tvm.testing.requires_opencl
+@pytest.mark.parametrize("input_device", ["opencl", tvm.opencl()])
+def test_target_from_device_opencl(input_device):  # str and Device inputs must detect alike
+    target = Target.from_device(input_device)
+
+    dev = tvm.opencl()  # reference device whose queried limits the target must mirror
+    assert target.kind.name == "opencl"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+
+
if __name__ == "__main__":
tvm.testing.main()