This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d1ede36cad [Target][Device] Auto detect target and create device from str in torch style (#15714)
d1ede36cad is described below
commit d1ede36cadb38b42a4558034ef38917f843cdb22
Author: Lesheng Jin <[email protected]>
AuthorDate: Wed Sep 13 11:43:04 2023 -0700
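[Target][Device] Auto detect target and create device from str in torch style (#15714)
- Target auto detection: `Target.auto_detect()`. (Note: the diff below does not add `Target.auto_detect()`; detection is exposed via `Target.from_device` and `tvm.target.detect_target.detect_target_from_device`.)
- Target created from device: `Target.from_device("cuda")` or `Target.from_device(tvm.cuda())`
- Create device from str: `tvm.device("cuda:0")` or `tvm.device("cuda", 0)`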
---
python/tvm/runtime/ndarray.py | 19 ++++-
python/tvm/target/detect_target.py | 114 ++++++++++++++++++++++++++++
python/tvm/target/target.py | 24 ++++++
tests/python/unittest/test_device.py | 71 +++++++++++++++++
tests/python/unittest/test_target_target.py | 45 +++++++++++
5 files changed, 271 insertions(+), 2 deletions(-)
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index a78c68ee67..6f0d1f440a 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -324,11 +324,26 @@ def device(dev_type, dev_id=0):
       assert tvm.device("cpu", 1) == tvm.cpu(1)
       assert tvm.device("cuda", 0) == tvm.cuda(0)
     """
+    if not isinstance(dev_id, int):
+        raise ValueError(f"Invalid device id: {dev_id}")
+
     if isinstance(dev_type, string_types):
         dev_type = dev_type.split()[0]
+        if dev_type.count(":") == 0:
+            pass
+        elif dev_type.count(":") == 1:
+            # An explicit id in the string (e.g. "cuda:1") overrides the dev_id argument.
+            dev_type, dev_id = dev_type.split(":")
+            if not dev_id.isdigit():
+                raise ValueError(f"Invalid device id: {dev_id}")
+            dev_id = int(dev_id)
+        else:
+            raise ValueError(f"Invalid device string: {dev_type}")
+
         if dev_type not in Device.STR2MASK:
-            raise ValueError(f"Unknown device type {dev_type}")
-        dev_type = Device.STR2MASK[dev_type]
+            raise ValueError(f"Unknown device type: {dev_type}")
+
+        return Device(Device.STR2MASK[dev_type], dev_id)
     return Device(dev_type, dev_id)
 
 
diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py
new file mode 100644
index 0000000000..5c139cc949
--- /dev/null
+++ b/python/tvm/target/detect_target.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Detect target."""
+from typing import Union
+
+from . import Target
+from .._ffi import get_global_func
+from .._ffi.runtime_ctypes import Device
+from ..runtime.ndarray import device
+
+
+def _detect_metal(dev: Device) -> Target:
+    return Target(
+        {
+            "kind": "metal",
+            "max_shared_memory_per_block": 32768,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,
+        }
+    )
+
+
+def _detect_cuda(dev: Device) -> Target:
+    return Target(
+        {
+            "kind": "cuda",
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,
+            "arch": "sm_" + dev.compute_version.replace(".", ""),
+        }
+    )
+
+
+def _detect_rocm(dev: Device) -> Target:
+    return Target(
+        {
+            "kind": "rocm",
+            "mtriple": "amdgcn-amd-amdhsa-hcc",
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,
+        }
+    )
+
+
+def _detect_vulkan(dev: Device) -> Target:
+    f_get_target_property = get_global_func("device_api.vulkan.get_target_property")
+    return Target(
+        {
+            "kind": "vulkan",
+            "max_threads_per_block": dev.max_threads_per_block,
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "thread_warp_size": dev.warp_size,
+            "supports_float16": f_get_target_property(dev, "supports_float16"),
+            "supports_int16": f_get_target_property(dev, "supports_int16"),
+            "supports_int8": f_get_target_property(dev, "supports_int8"),
+            "supports_16bit_buffer": f_get_target_property(dev, "supports_16bit_buffer"),
+        }
+    )
+
+
+def detect_target_from_device(dev: Union[str, Device]) -> Target:
+    """Detects Target associated with the given device. If the device type is not
+    supported or the device does not exist, a ValueError is raised.
+
+    Parameters
+    ----------
+    dev : Union[str, Device]
+        The device to detect the target for.
+        Supported device types: ["cuda", "metal", "rocm", "vulkan"]
+
+    Returns
+    -------
+    target : Target
+        The detected target.
+    """
+    if isinstance(dev, str):
+        dev = device(dev)
+    device_type = Device.MASK2STR[dev.device_type]
+    if device_type not in SUPPORT_DEVICE:
+        raise ValueError(
+            f"Auto detection for device `{device_type}` is not supported. "
+            f"Currently only supports: {list(SUPPORT_DEVICE.keys())}"
+        )
+    if not dev.exist:
+        raise ValueError(
+            f"Cannot detect device `{dev}`. Please make sure the device and its driver "
+            "are installed properly, and TVM is compiled with the driver"
+        )
+    target = SUPPORT_DEVICE[device_type](dev)
+    return target
+
+
+SUPPORT_DEVICE = {
+    "cuda": _detect_cuda,
+    "metal": _detect_metal,
+    "vulkan": _detect_vulkan,
+    "rocm": _detect_rocm,
+}
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index b027b99b17..ec74cbcdb6 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -18,9 +18,11 @@
 import json
 import re
 import warnings
+from typing import Union
 
 import tvm._ffi
 from tvm._ffi import register_func as _register_func
+from tvm._ffi.runtime_ctypes import Device
 from tvm.runtime import Object, convert
 from tvm.runtime.container import String
 from tvm.ir.container import Map, Array
@@ -148,6 +150,28 @@ class Target(Object):
     def with_host(self, host=None):
         return _ffi_api.WithHost(self, Target(host))
 
+    @staticmethod
+    def from_device(device: Union[str, Device]) -> "Target":
+        """Detects Target associated with the given device. If the device type is not
+        supported or the device does not exist, a ValueError is raised.
+
+        Parameters
+        ----------
+        device : Union[str, Device]
+            The device to detect the target for.
+            Supported device types: ["cuda", "metal", "rocm", "vulkan"]
+
+        Returns
+        -------
+        target : Target
+            The detected target.
+        """
+        from .detect_target import (  # pylint: disable=import-outside-toplevel
+            detect_target_from_device,
+        )
+
+        return detect_target_from_device(device)
+
     @staticmethod
     def current(allow_none=True):
         """Returns the current target.
diff --git a/tests/python/unittest/test_device.py b/tests/python/unittest/test_device.py
new file mode 100644
index 0000000000..9d10251e15
--- /dev/null
+++ b/tests/python/unittest/test_device.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+import tvm.testing
+from tvm._ffi.runtime_ctypes import Device
+
+
+@pytest.mark.parametrize(
+    "dev_str, expected_device_type, expected_device_id",
+    [
+        ("cpu", Device.kDLCPU, 0),
+        ("cuda", Device.kDLCUDA, 0),
+        ("cuda:0", Device.kDLCUDA, 0),
+        ("cuda:3", Device.kDLCUDA, 3),
+        ("metal:2", Device.kDLMetal, 2),
+    ],
+)
+def test_device(dev_str, expected_device_type, expected_device_id):
+    dev = tvm.device(dev_str)
+    assert dev.device_type == expected_device_type
+    assert dev.device_id == expected_device_id
+
+
+@pytest.mark.parametrize(
+    "dev_type, dev_id, expected_device_type, expected_device_id",
+    [
+        ("cpu", 0, Device.kDLCPU, 0),
+        ("cuda", 0, Device.kDLCUDA, 0),
+        (Device.kDLCUDA, 0, Device.kDLCUDA, 0),
+        ("cuda", 3, Device.kDLCUDA, 3),
+        (Device.kDLMetal, 2, Device.kDLMetal, 2),
+    ],
+)
+def test_device_with_dev_id(dev_type, dev_id, expected_device_type, expected_device_id):
+    dev = tvm.device(dev_type=dev_type, dev_id=dev_id)
+    assert dev.device_type == expected_device_type
+    assert dev.device_id == expected_device_id
+
+
+@pytest.mark.parametrize(
+    "dev_type, dev_id",
+    [
+        ("cpu:0:0", 0),
+        ("cpu:?", 0),
+        ("cpu:", 0),
+        (Device.kDLCUDA, "?"),
+    ],
+)
+def test_device_error(dev_type, dev_id):
+    with pytest.raises(ValueError):
+        tvm.device(dev_type=dev_type, dev_id=dev_id)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 2b0f1b2dd7..da1bbc2c21 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -488,5 +488,50 @@ def test_target_features():
     assert not target_with_features.features.is_missing
 
 
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize("input_device", ["cuda", tvm.cuda()])
+def test_target_from_device_cuda(input_device):
+    target = Target.from_device(input_device)
+
+    dev = tvm.cuda()
+    assert target.kind.name == "cuda"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.attrs["max_shared_memory_per_block"] == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+    assert target.arch == "sm_" + dev.compute_version.replace(".", "")
+
+
+@tvm.testing.requires_rocm
+@pytest.mark.parametrize("input_device", ["rocm", tvm.rocm()])
+def test_target_from_device_rocm(input_device):
+    target = Target.from_device(input_device)
+
+    dev = tvm.rocm()
+    assert target.kind.name == "rocm"
+    assert target.attrs["mtriple"] == "amdgcn-amd-amdhsa-hcc"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.attrs["max_shared_memory_per_block"] == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+
+
+@tvm.testing.requires_vulkan
+@pytest.mark.parametrize("input_device", ["vulkan", tvm.vulkan()])
+def test_target_from_device_vulkan(input_device):
+    target = Target.from_device(input_device)
+
+    f_get_target_property = tvm.get_global_func("device_api.vulkan.get_target_property")
+    dev = tvm.vulkan()
+    assert target.kind.name == "vulkan"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.attrs["max_shared_memory_per_block"] == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+    assert target.attrs["supports_float16"] == f_get_target_property(dev, "supports_float16")
+    assert target.attrs["supports_int16"] == f_get_target_property(dev, "supports_int16")
+    assert target.attrs["supports_int8"] == f_get_target_property(dev, "supports_int8")
+    assert target.attrs["supports_16bit_buffer"] == f_get_target_property(
+        dev, "supports_16bit_buffer"
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()