This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 796e71a  Add Python representation for VirtualDevice (#9812)
796e71a is described below

commit 796e71aa7499c7264e3e11bb98600b2187d9cd4b
Author: Christopher Sidebottom <[email protected]>
AuthorDate: Tue Jan 4 15:49:07 2022 +0000

    Add Python representation for VirtualDevice (#9812)
    
    * Add Python representation for VirtualDevice
    
    This adds a Python class to represent the VirtualDevice so that the
    behaviour of `device_type()` can be partially replicated in Python (as a
    `device_type` property backed by `device_type_int`).
    
    These tests were not actually being run, and were broken, so I've added
    them to the integration script.
    
    * Update other references to make_virtual_device
---
 python/tvm/relay/op/annotation/annotation.py          |  4 ++--
 python/tvm/relay/op/tensor.py                         |  4 ++--
 python/tvm/target/__init__.py                         |  2 +-
 python/tvm/target/virtual_device.py                   | 18 +++++++++++++++---
 tests/python/relay/test_pass_dead_code_elimination.py |  2 +-
 tests/python/relay/test_pass_plan_devices.py          | 10 +++++-----
 tests/python/target/test_virtual_device.py            |  7 +++----
 tests/scripts/task_python_integration.sh              |  3 +++
 8 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/python/tvm/relay/op/annotation/annotation.py 
b/python/tvm/relay/op/annotation/annotation.py
index f2ce6c5..5582ac9 100644
--- a/python/tvm/relay/op/annotation/annotation.py
+++ b/python/tvm/relay/op/annotation/annotation.py
@@ -25,9 +25,9 @@ from .. import op as reg
 
 def _make_virtual_device(device):
     if isinstance(device, _Device):
-        return target.make_virtual_device(device)
+        return target.VirtualDevice(device)
     if isinstance(device, str):
-        return target.make_virtual_device(_nd.device(device))
+        return target.VirtualDevice(_nd.device(device))
     raise ValueError("expecting a Device or device name, but received a %s" % 
(type(device)))
 
 
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index 963bb3d..0c930dd 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -29,9 +29,9 @@ from . import op as reg
 
 def _make_virtual_device(device):
     if isinstance(device, _Device):
-        return target.make_virtual_device(device)
+        return target.VirtualDevice(device)
     if isinstance(device, str):
-        return target.make_virtual_device(_nd.device(device))
+        return target.VirtualDevice(_nd.device(device))
     raise ValueError("expecting a Device or device name, but received a %s" % 
(type(device)))
 
 
diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py
index 6c13ced..cd667ce 100644
--- a/python/tvm/target/__init__.py
+++ b/python/tvm/target/__init__.py
@@ -71,7 +71,7 @@ from .target import (
     riscv_cpu,
     hexagon,
 )
-from .virtual_device import make_virtual_device
+from .virtual_device import VirtualDevice
 from .compilation_config import make_compilation_config
 from .tag import list_tags
 from .generic_func import GenericFunc
diff --git a/python/tvm/target/virtual_device.py 
b/python/tvm/target/virtual_device.py
index a88d405..9ab864c 100644
--- a/python/tvm/target/virtual_device.py
+++ b/python/tvm/target/virtual_device.py
@@ -15,11 +15,23 @@
 # specific language governing permissions and limitations
 # under the License.
 """Python bindings for creating VirtualDevices."""
+
+import tvm
+from tvm.runtime import Object
+
 from . import _ffi_api
 
 
-# TODO(mbs): We need an official Python class representation given the 
importance of this structure.
+@tvm._ffi.register_object
+class VirtualDevice(Object):
+    """A compile time representation for where data is to be stored at runtime,
+    and how to compile code to compute it."""
 
+    def __init__(self, device, target=None, memory_scope="") -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.VirtualDevice_ForDeviceTargetAndMemoryScope, device, 
target, memory_scope
+        )
 
-def make_virtual_device(device, target=None, memory_scope=""):
-    return _ffi_api.VirtualDevice_ForDeviceTargetAndMemoryScope(device, 
target, memory_scope)
+    @property
+    def device_type(self) -> int:
+        return self.device_type_int
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py 
b/tests/python/relay/test_pass_dead_code_elimination.py
index bc19bcd..bcbbfaa 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -19,7 +19,7 @@ from tvm.relay import Function, transform
 from tvm.relay.testing import inception_v3
 import pytest
 
-cpu_scope = tvm.target.make_virtual_device(tvm.cpu(), 
tvm.target.Target("llvm"))
+cpu_scope = tvm.target.VirtualDevice(tvm.cpu(), tvm.target.Target("llvm"))
 metatable = {"VirtualDevice": [cpu_scope]}
 core = tvm.IRModule()
 core.import_from_std("core.rly")
diff --git a/tests/python/relay/test_pass_plan_devices.py 
b/tests/python/relay/test_pass_plan_devices.py
index 82e40af..6232f9c 100644
--- a/tests/python/relay/test_pass_plan_devices.py
+++ b/tests/python/relay/test_pass_plan_devices.py
@@ -41,13 +41,13 @@ TARGETS = {
     tvm.tir.IntImm("int32", GPU_DEVICE.device_type): GPU_TARGET,
 }
 
-HOST = tvm.target.make_virtual_device(HOST_DEVICE, HOST_TARGET)  # 
device_type=1
-CPU = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET)  # device_type=1
-GPU = tvm.target.make_virtual_device(GPU_DEVICE, GPU_TARGET)  # device_type=2
+HOST = tvm.target.VirtualDevice(HOST_DEVICE, HOST_TARGET)  # device_type=1
+CPU = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET)  # device_type=1
+GPU = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET)  # device_type=2
 DEFAULT = GPU
 
-CPU_SCOPE_A = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET, 
memory_scope="scopeA")
-CPU_SCOPE_B = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET, 
memory_scope="scopeB")
+CPU_SCOPE_A = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET, 
memory_scope="scopeA")
+CPU_SCOPE_B = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET, 
memory_scope="scopeB")
 
 CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": 
DEFAULT.device_type_int})
 
diff --git a/tests/python/target/test_virtual_device.py 
b/tests/python/target/test_virtual_device.py
index eec77bc..392e418 100644
--- a/tests/python/target/test_virtual_device.py
+++ b/tests/python/target/test_virtual_device.py
@@ -14,13 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import numpy as np
 import pytest
 import tvm
 
 
 def test_make_virtual_device_for_device():
-    virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"))
+    virtual_device = tvm.target.VirtualDevice(tvm.device("cuda"))
     assert virtual_device.device_type == 2
     # ie kDLCUDA
     assert virtual_device.virtual_device_id == 0
@@ -30,7 +29,7 @@ def test_make_virtual_device_for_device():
 
 def test_make_virtual_device_for_device_and_target():
     target = tvm.target.Target("cuda")
-    virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"), target)
+    virtual_device = tvm.target.VirtualDevice(tvm.device("cuda"), target)
     assert virtual_device.device_type == 2  # ie kDLCUDA
     assert virtual_device.target == target
     assert virtual_device.memory_scope == ""
@@ -39,7 +38,7 @@ def test_make_virtual_device_for_device_and_target():
 def test_make_virtual_device_for_device_target_and_memory_scope():
     target = tvm.target.Target("cuda")
     scope = "local"
-    virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"), 
target, scope)
+    virtual_device = tvm.target.VirtualDevice(tvm.device("cuda"), target, 
scope)
     assert virtual_device.device_type == 2  # ie kDLCUDA
     assert virtual_device.target == target
     assert virtual_device.memory_scope == scope
diff --git a/tests/scripts/task_python_integration.sh 
b/tests/scripts/task_python_integration.sh
index 55f5b96..4992bfa 100755
--- a/tests/scripts/task_python_integration.sh
+++ b/tests/scripts/task_python_integration.sh
@@ -72,6 +72,9 @@ TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm;cuda}" \
 # Command line driver test
 run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-driver tests/python/driver
 
+# Target test
+run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-target tests/python/target
+
 # Do not enable OpenGL
 # run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-webgl tests/webgl
 

Reply via email to