This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 796e71a Add Python representation for VirtualDevice (#9812)
796e71a is described below
commit 796e71aa7499c7264e3e11bb98600b2187d9cd4b
Author: Christopher Sidebottom <[email protected]>
AuthorDate: Tue Jan 4 15:49:07 2022 +0000
Add Python representation for VirtualDevice (#9812)
* Add Python representation for VirtualDevice
This adds a Python class to represent the VirtualDevice so that the
behaviour for `device_type()` can be semi-replicated.
These tests were not actually being run, and were broken, so I've fixed them
and added them to the integration script.
* Update other references to make_virtual_device
---
python/tvm/relay/op/annotation/annotation.py | 4 ++--
python/tvm/relay/op/tensor.py | 4 ++--
python/tvm/target/__init__.py | 2 +-
python/tvm/target/virtual_device.py | 18 +++++++++++++++---
tests/python/relay/test_pass_dead_code_elimination.py | 2 +-
tests/python/relay/test_pass_plan_devices.py | 10 +++++-----
tests/python/target/test_virtual_device.py | 7 +++----
tests/scripts/task_python_integration.sh | 3 +++
8 files changed, 32 insertions(+), 18 deletions(-)
diff --git a/python/tvm/relay/op/annotation/annotation.py
b/python/tvm/relay/op/annotation/annotation.py
index f2ce6c5..5582ac9 100644
--- a/python/tvm/relay/op/annotation/annotation.py
+++ b/python/tvm/relay/op/annotation/annotation.py
@@ -25,9 +25,9 @@ from .. import op as reg
def _make_virtual_device(device):
if isinstance(device, _Device):
- return target.make_virtual_device(device)
+ return target.VirtualDevice(device)
if isinstance(device, str):
- return target.make_virtual_device(_nd.device(device))
+ return target.VirtualDevice(_nd.device(device))
raise ValueError("expecting a Device or device name, but received a %s" %
(type(device)))
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index 963bb3d..0c930dd 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -29,9 +29,9 @@ from . import op as reg
def _make_virtual_device(device):
if isinstance(device, _Device):
- return target.make_virtual_device(device)
+ return target.VirtualDevice(device)
if isinstance(device, str):
- return target.make_virtual_device(_nd.device(device))
+ return target.VirtualDevice(_nd.device(device))
raise ValueError("expecting a Device or device name, but received a %s" %
(type(device)))
diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py
index 6c13ced..cd667ce 100644
--- a/python/tvm/target/__init__.py
+++ b/python/tvm/target/__init__.py
@@ -71,7 +71,7 @@ from .target import (
riscv_cpu,
hexagon,
)
-from .virtual_device import make_virtual_device
+from .virtual_device import VirtualDevice
from .compilation_config import make_compilation_config
from .tag import list_tags
from .generic_func import GenericFunc
diff --git a/python/tvm/target/virtual_device.py
b/python/tvm/target/virtual_device.py
index a88d405..9ab864c 100644
--- a/python/tvm/target/virtual_device.py
+++ b/python/tvm/target/virtual_device.py
@@ -15,11 +15,23 @@
# specific language governing permissions and limitations
# under the License.
"""Python bindings for creating VirtualDevices."""
+
+import tvm
+from tvm.runtime import Object
+
from . import _ffi_api
-# TODO(mbs): We need an official Python class representation given the
importance of this structure.
+@tvm._ffi.register_object
+class VirtualDevice(Object):
+ """A compile time representation for where data is to be stored at runtime,
+ and how to compile code to compute it."""
+ def __init__(self, device, target=None, memory_scope="") -> None:
+ self.__init_handle_by_constructor__(
+ _ffi_api.VirtualDevice_ForDeviceTargetAndMemoryScope, device,
target, memory_scope
+ )
-def make_virtual_device(device, target=None, memory_scope=""):
- return _ffi_api.VirtualDevice_ForDeviceTargetAndMemoryScope(device,
target, memory_scope)
+ @property
+ def device_type(self) -> int:
+ return self.device_type_int
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py
b/tests/python/relay/test_pass_dead_code_elimination.py
index bc19bcd..bcbbfaa 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -19,7 +19,7 @@ from tvm.relay import Function, transform
from tvm.relay.testing import inception_v3
import pytest
-cpu_scope = tvm.target.make_virtual_device(tvm.cpu(),
tvm.target.Target("llvm"))
+cpu_scope = tvm.target.VirtualDevice(tvm.cpu(), tvm.target.Target("llvm"))
metatable = {"VirtualDevice": [cpu_scope]}
core = tvm.IRModule()
core.import_from_std("core.rly")
diff --git a/tests/python/relay/test_pass_plan_devices.py
b/tests/python/relay/test_pass_plan_devices.py
index 82e40af..6232f9c 100644
--- a/tests/python/relay/test_pass_plan_devices.py
+++ b/tests/python/relay/test_pass_plan_devices.py
@@ -41,13 +41,13 @@ TARGETS = {
tvm.tir.IntImm("int32", GPU_DEVICE.device_type): GPU_TARGET,
}
-HOST = tvm.target.make_virtual_device(HOST_DEVICE, HOST_TARGET) #
device_type=1
-CPU = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET) # device_type=1
-GPU = tvm.target.make_virtual_device(GPU_DEVICE, GPU_TARGET) # device_type=2
+HOST = tvm.target.VirtualDevice(HOST_DEVICE, HOST_TARGET) # device_type=1
+CPU = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET) # device_type=1
+GPU = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET) # device_type=2
DEFAULT = GPU
-CPU_SCOPE_A = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET,
memory_scope="scopeA")
-CPU_SCOPE_B = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET,
memory_scope="scopeB")
+CPU_SCOPE_A = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET,
memory_scope="scopeA")
+CPU_SCOPE_B = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET,
memory_scope="scopeB")
CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type":
DEFAULT.device_type_int})
diff --git a/tests/python/target/test_virtual_device.py
b/tests/python/target/test_virtual_device.py
index eec77bc..392e418 100644
--- a/tests/python/target/test_virtual_device.py
+++ b/tests/python/target/test_virtual_device.py
@@ -14,13 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import numpy as np
import pytest
import tvm
def test_make_virtual_device_for_device():
- virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"))
+ virtual_device = tvm.target.VirtualDevice(tvm.device("cuda"))
assert virtual_device.device_type == 2
# ie kDLCUDA
assert virtual_device.virtual_device_id == 0
@@ -30,7 +29,7 @@ def test_make_virtual_device_for_device():
def test_make_virtual_device_for_device_and_target():
target = tvm.target.Target("cuda")
- virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"), target)
+ virtual_device = tvm.target.VirtualDevice(tvm.device("cuda"), target)
assert virtual_device.device_type == 2 # ie kDLCUDA
assert virtual_device.target == target
assert virtual_device.memory_scope == ""
@@ -39,7 +38,7 @@ def test_make_virtual_device_for_device_and_target():
def test_make_virtual_device_for_device_target_and_memory_scope():
target = tvm.target.Target("cuda")
scope = "local"
- virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"),
target, scope)
+ virtual_device = tvm.target.VirtualDevice(tvm.device("cuda"), target,
scope)
assert virtual_device.device_type == 2 # ie kDLCUDA
assert virtual_device.target == target
assert virtual_device.memory_scope == scope
diff --git a/tests/scripts/task_python_integration.sh
b/tests/scripts/task_python_integration.sh
index 55f5b96..4992bfa 100755
--- a/tests/scripts/task_python_integration.sh
+++ b/tests/scripts/task_python_integration.sh
@@ -72,6 +72,9 @@ TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm;cuda}" \
# Command line driver test
run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-driver tests/python/driver
+# Target test
+run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-target tests/python/target
+
# Do not enable OpenGL
# run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-webgl tests/webgl