This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d1ede36cad [Target][Device] Auto detect target and create device from str in torch style (#15714)
d1ede36cad is described below

commit d1ede36cadb38b42a4558034ef38917f843cdb22
Author: Lesheng Jin <[email protected]>
AuthorDate: Wed Sep 13 11:43:04 2023 -0700

    [Target][Device] Auto detect target and create device from str in torch style (#15714)
    
    - Target auto detection: `Target.auto_detect()`.
    - Target created from device: `Target.from_device("cuda")` or
      `Target.from_device(tvm.cuda())`
    - create device from str: `tvm.device("cuda:0")` or
      `tvm.device("cuda", 0)`
---
 python/tvm/runtime/ndarray.py               |  19 ++++-
 python/tvm/target/detect_target.py          | 114 ++++++++++++++++++++++++++++
 python/tvm/target/target.py                 |  24 ++++++
 tests/python/unittest/test_device.py        |  71 +++++++++++++++++
 tests/python/unittest/test_target_target.py |  45 +++++++++++
 5 files changed, 271 insertions(+), 2 deletions(-)

diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index a78c68ee67..6f0d1f440a 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -324,11 +324,26 @@ def device(dev_type, dev_id=0):
       assert tvm.device("cpu", 1) == tvm.cpu(1)
       assert tvm.device("cuda", 0) == tvm.cuda(0)
     """
+    if not isinstance(dev_id, int):
+        raise ValueError(f"Invalid device id: {dev_id}")
+
     if isinstance(dev_type, string_types):
         dev_type = dev_type.split()[0]
+        if dev_type.count(":") == 0:
+            pass
+        elif dev_type.count(":") == 1:
+            # It will override the dev_id passed by the user.
+            dev_type, dev_id = dev_type.split(":")
+            if not dev_id.isdigit():
+                raise ValueError(f"Invalid device id: {dev_id}")
+            dev_id = int(dev_id)
+        else:
+            raise ValueError(f"Invalid device string: {dev_type}")
+
         if dev_type not in Device.STR2MASK:
-            raise ValueError(f"Unknown device type {dev_type}")
-        dev_type = Device.STR2MASK[dev_type]
+            raise ValueError(f"Unknown device type: {dev_type}")
+
+        return Device(Device.STR2MASK[dev_type], dev_id)
     return Device(dev_type, dev_id)
 
 
diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py
new file mode 100644
index 0000000000..5c139cc949
--- /dev/null
+++ b/python/tvm/target/detect_target.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Detect target."""
+from typing import Union
+
+from . import Target
+from .._ffi import get_global_func
+from .._ffi.runtime_ctypes import Device
+from ..runtime.ndarray import device
+
+
+def _detect_metal(dev: Device) -> Target:
+    """Build a Metal Target from the properties reported by `dev`."""
+    attrs = {"kind": "metal"}
+    # Not queried from the device: a fixed value of 32768 is used here.
+    attrs["max_shared_memory_per_block"] = 32768
+    attrs["max_threads_per_block"] = dev.max_threads_per_block
+    attrs["thread_warp_size"] = dev.warp_size
+    target = Target(attrs)
+    return target
+
+
+def _detect_cuda(dev: Device) -> Target:
+    """Build a CUDA Target from the properties reported by `dev`."""
+    attrs = {"kind": "cuda"}
+    attrs["max_shared_memory_per_block"] = dev.max_shared_memory_per_block
+    attrs["max_threads_per_block"] = dev.max_threads_per_block
+    attrs["thread_warp_size"] = dev.warp_size
+    # A compute version such as "8.6" maps to the arch string "sm_86".
+    attrs["arch"] = "sm_" + dev.compute_version.replace(".", "")
+    target = Target(attrs)
+    return target
+
+
+def _detect_rocm(dev: Device) -> Target:
+    """Build a ROCm Target from the properties reported by `dev`."""
+    # NOTE(review): LLVM's AMDGPU triple is "amdgcn-amd-amdhsa"; "and" looks like
+    # a typo. Fix together with the mtriple assert in test_target_from_device_rocm.
+    attrs = {"kind": "rocm", "mtriple": "amdgcn-and-amdhsa-hcc"}
+    attrs["max_shared_memory_per_block"] = dev.max_shared_memory_per_block
+    attrs["max_threads_per_block"] = dev.max_threads_per_block
+    attrs["thread_warp_size"] = dev.warp_size
+    target = Target(attrs)
+    return target
+
+
+def _detect_vulkan(dev: Device) -> Target:
+    """Build a Vulkan Target, querying feature flags via the Vulkan device API."""
+    f_get_target_property = get_global_func("device_api.vulkan.get_target_property")
+    attrs = {"kind": "vulkan", "max_threads_per_block": dev.max_threads_per_block}
+    attrs["max_shared_memory_per_block"] = dev.max_shared_memory_per_block
+    attrs["thread_warp_size"] = dev.warp_size
+    features = [
+        "supports_float16",
+        "supports_int16",
+        "supports_int8",
+        "supports_16bit_buffer",
+    ]
+    attrs.update({name: f_get_target_property(dev, name) for name in features})
+    return Target(attrs)
+
+
+def detect_target_from_device(dev: Union[str, Device]) -> Target:
+    """Detect the Target for `dev`. Raises ValueError if the device type is not
+    supported or the device does not exist.
+
+    Parameters
+    ----------
+    dev : Union[str, Device]
+        The device, or a device string such as "cuda:0", to detect the target for.
+        Supported device types: ["cuda", "metal", "rocm", "vulkan"]
+
+    Returns
+    -------
+    target : Target
+        The detected target.
+    """
+    if isinstance(dev, str):
+        dev = device(dev)
+    device_type = Device.MASK2STR[dev.device_type]
+    if device_type not in SUPPORT_DEVICE:
+        raise ValueError(
+            f"Auto detection for device `{device_type}` is not supported. "
+            f"Currently only supports: {list(SUPPORT_DEVICE)}"
+        )
+    if not dev.exist:
+        raise ValueError(
+            f"Cannot detect device `{dev}`. Please make sure the device and its driver "
+            "are installed properly, and TVM is compiled with the driver"
+        )
+    # SUPPORT_DEVICE maps the device-type string to its per-backend detector.
+    return SUPPORT_DEVICE[device_type](dev)
+
+
+SUPPORT_DEVICE = {  # Device.MASK2STR name -> function building that device's Target
+    "cuda": _detect_cuda,
+    "metal": _detect_metal,
+    "vulkan": _detect_vulkan,
+    "rocm": _detect_rocm,
+}
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index b027b99b17..ec74cbcdb6 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -18,9 +18,11 @@
 import json
 import re
 import warnings
+from typing import Union
 
 import tvm._ffi
 from tvm._ffi import register_func as _register_func
+from tvm._ffi.runtime_ctypes import Device
 from tvm.runtime import Object, convert
 from tvm.runtime.container import String
 from tvm.ir.container import Map, Array
@@ -148,6 +150,28 @@ class Target(Object):
     def with_host(self, host=None):
         return _ffi_api.WithHost(self, Target(host))
 
+    @staticmethod
+    def from_device(device: Union[str, Device]) -> "Target":
+        """Detect the Target associated with the given device. Raises ValueError if
+        the device type is not supported or the device does not exist.
+
+        Parameters
+        ----------
+        device : Union[str, Device]
+            The device, or a device string such as "cuda:0", to detect the target for.
+            Supported device types: ["cuda", "metal", "rocm", "vulkan"]
+
+        Returns
+        -------
+        target : Target
+            The detected target.
+        """
+        # Lazy import: detect_target.py itself does `from . import Target`.
+        from .detect_target import (  # pylint: disable=import-outside-toplevel
+            detect_target_from_device,
+        )
+        return detect_target_from_device(device)
+
     @staticmethod
     def current(allow_none=True):
         """Returns the current target.
diff --git a/tests/python/unittest/test_device.py b/tests/python/unittest/test_device.py
new file mode 100644
index 0000000000..9d10251e15
--- /dev/null
+++ b/tests/python/unittest/test_device.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+import tvm.testing
+from tvm._ffi.runtime_ctypes import Device
+
+
+@pytest.mark.parametrize(
+    "dev_str, expected_device_type, expect_device_id",
+    [
+        ("cpu", Device.kDLCPU, 0),
+        ("cuda", Device.kDLCUDA, 0),
+        ("cuda:0", Device.kDLCUDA, 0),
+        ("cuda:3", Device.kDLCUDA, 3),
+        ("metal:2", Device.kDLMetal, 2),
+    ],
+)
+def test_device(dev_str, expected_device_type, expect_device_id):
+    """A device string, optionally suffixed with ":<id>", maps to type and id."""
+    dev = tvm.device(dev_str)
+    assert (dev.device_type, dev.device_id) == (expected_device_type, expect_device_id)
+
+
+@pytest.mark.parametrize(
+    "dev_type, dev_id, expected_device_type, expect_device_id",
+    [
+        ("cpu", 0, Device.kDLCPU, 0),
+        ("cuda", 0, Device.kDLCUDA, 0),
+        (Device.kDLCUDA, 0, Device.kDLCUDA, 0),
+        ("cuda", 3, Device.kDLCUDA, 3),
+        (Device.kDLMetal, 2, Device.kDLMetal, 2),
+    ],
+)
+def test_device_with_dev_id(dev_type, dev_id, expected_device_type, expect_device_id):
+    """Both string and integer device types accept an explicit dev_id."""
+    dev = tvm.device(dev_type=dev_type, dev_id=dev_id)
+    assert (dev.device_type, dev.device_id) == (expected_device_type, expect_device_id)
+
+
+@pytest.mark.parametrize(
+    "dev_type, dev_id",
+    [
+        ("cpu:0:0", 0),  # dev_id is a valid int, so the string itself must be rejected
+        ("cpu:?", 0),
+        ("cpu:", 0),
+        (Device.kDLCUDA, "?"),  # non-int dev_id is rejected up front
+    ],
+)
+def test_device_error(dev_type, dev_id):
+    with pytest.raises(ValueError):
+        tvm.device(dev_type=dev_type, dev_id=dev_id)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 2b0f1b2dd7..da1bbc2c21 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -488,5 +488,50 @@ def test_target_features():
     assert not target_with_features.features.is_missing
 
 
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize("input_device", ["cuda", tvm.cuda()])
+def test_target_from_device_cuda(input_device):
+    """Target.from_device mirrors the properties reported by tvm.cuda()."""
+    dev = tvm.cuda()
+    target = Target.from_device(input_device)
+    assert target.kind.name == "cuda"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+    assert target.arch == "sm_" + dev.compute_version.replace(".", "")
+
+
+@tvm.testing.requires_rocm
+@pytest.mark.parametrize("input_device", ["rocm", tvm.rocm()])
+def test_target_from_device_rocm(input_device):
+    """Target.from_device mirrors the properties reported by tvm.rocm()."""
+    dev = tvm.rocm()
+    target = Target.from_device(input_device)
+    assert target.kind.name == "rocm"
+    assert target.attrs["mtriple"] == "amdgcn-and-amdhsa-hcc"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+
+
+@tvm.testing.requires_vulkan
+@pytest.mark.parametrize("input_device", ["vulkan", tvm.vulkan()])
+def test_target_from_device_vulkan(input_device):
+    """Target.from_device mirrors tvm.vulkan() properties and feature flags."""
+    target = Target.from_device(input_device)
+    f_get_target_property = tvm.get_global_func("device_api.vulkan.get_target_property")
+    dev = tvm.vulkan()
+    assert target.kind.name == "vulkan"
+    assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
+    assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
+    assert target.thread_warp_size == dev.warp_size
+    assert target.attrs["supports_float16"] == f_get_target_property(dev, "supports_float16")
+    assert target.attrs["supports_int16"] == f_get_target_property(dev, "supports_int16")
+    assert target.attrs["supports_int8"] == f_get_target_property(dev, "supports_int8")
+    assert target.attrs["supports_16bit_buffer"] == f_get_target_property(
+        dev, "supports_16bit_buffer"
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to