This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 2226a1f558 [Unity] Multi-device support for Relax (#15447)
2226a1f558 is described below
commit 2226a1f5581cb77e599ce16d61227ed1088ff058
Author: Yong Wu <[email protected]>
AuthorDate: Fri Aug 11 12:25:29 2023 -0700
[Unity] Multi-device support for Relax (#15447)
---
include/tvm/ir/global_info.h | 50 ++++++++++-
include/tvm/relax/attrs/op.h | 8 ++
include/tvm/relax/struct_info.h | 17 +++-
include/tvm/relay/transform.h | 12 +++
include/tvm/script/printer/ir_docsifier.h | 7 ++
include/tvm/target/target.h | 10 ---
include/tvm/target/target_kind.h | 27 +-----
include/tvm/tir/usmp/utils.h | 1 +
python/tvm/ir/__init__.py | 2 +-
python/tvm/ir/global_info.py | 12 +++
python/tvm/ir/json_compact.py | 17 ++++
python/tvm/relax/op/base.py | 21 +++++
python/tvm/relax/struct_info.py | 14 ++--
python/tvm/script/ir_builder/ir/__init__.py | 2 +
python/tvm/script/ir_builder/ir/ir.py | 38 ++++++++-
python/tvm/script/ir_builder/relax/ir.py | 2 +
python/tvm/script/parser/ir/__init__.py | 2 +
python/tvm/script/parser/relax/dist.py | 8 +-
python/tvm/script/parser/relax/entry.py | 21 ++++-
src/driver/driver_api.cc | 17 ++++
src/ir/global_info.cc | 13 +++
src/relax/analysis/struct_info_analysis.cc | 25 ++++--
src/relax/backend/vm/vm_builtin_lower.cc | 21 +++++
src/relax/ir/expr.cc | 2 +-
src/relax/ir/struct_info.cc | 12 +--
src/relax/ir/struct_info_functor.cc | 7 +-
src/relax/op/op.cc | 28 +++++++
src/relax/transform/convert_layout.cc | 6 +-
src/relax/transform/to_mixed_precision.cc | 6 +-
src/relay/backend/contrib/cmsisnn/target.cc | 4 +-
src/relay/backend/contrib/codegen_c/target.cc | 2 +-
src/relay/backend/contrib/cutlass/target.cc | 2 +-
src/relay/backend/contrib/ethosu/codegen.cc | 4 +-
.../backend/contrib/example_target_hooks/target.cc | 6 +-
src/relay/backend/contrib/tensorrt/target.cc | 2 +-
src/relay/backend/contrib/uma/targets.cc | 28 +++----
src/runtime/relax_vm/builtin.cc | 6 ++
src/script/ir_builder/ir/ir.cc | 30 +++++++
src/script/printer/ir/ir.cc | 10 +++
src/script/printer/ir_docsifier.cc | 6 ++
src/script/printer/relax/struct_info.cc | 7 ++
src/script/printer/relax/utils.h | 16 ++++
src/target/codegen.cc | 3 +
src/target/target.cc | 15 +---
tests/cpp/target_test.cc | 2 +-
.../python/relax/test_json_compact.py | 50 ++++++-----
tests/python/relax/test_relax_operators.py | 30 +++++++
tests/python/relax/test_tvmscript_parser.py | 97 +++++++++++++++++++++-
.../relax/test_tvmscript_parser_op_manipulate.py | 15 ++++
tests/python/relax/test_vm_build.py | 35 ++++++++
tests/python/relax/test_vm_codegen_only.py | 27 ++++++
51 files changed, 680 insertions(+), 125 deletions(-)
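
At a glance: an IRModule can now declare a list of virtual devices under the "vdevice" global info, TensorStructInfo gains an optional vdevice field, and the new relax.to_vdevice op copies tensors between devices. A minimal TVMScript sketch distilled from the tests added below (illustrative only, not part of the commit):

    import tvm
    from tvm.script.parser import ir as I, relax as R

    @I.ir_module
    class MultiDeviceModule:
        # Register the virtual devices this module may place tensors on.
        I.module_global_infos({"vdevice": [I.vdevice("llvm"), I.vdevice("cuda", 0)]})

        @R.function
        def main(
            x: R.Tensor((2, 3), "float32", "llvm"),  # tensor pinned to the llvm vdevice
        ) -> R.Tensor((2, 3), "float32"):
            # Copy across devices for heterogeneous execution.
            y = R.to_vdevice(x, tvm.ir.VDevice("cuda", 0, "global"))
            return y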
diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h
index bd6006864a..57bc57fd09 100644
--- a/include/tvm/ir/global_info.h
+++ b/include/tvm/ir/global_info.h
@@ -25,10 +25,16 @@
#ifndef TVM_IR_GLOBAL_INFO_H_
#define TVM_IR_GLOBAL_INFO_H_
-#include "tvm/ir/expr.h"
+#include <tvm/ir/expr.h>
+#include <tvm/target/target.h>
namespace tvm {
+/*!
+ * \brief Abstract label for an area of memory.
+ */
+using MemoryScope = String;
+
/*!
 * \brief GlobalInfo are globally static object that are referred by the IR itself.
* Base node for all global info that can appear in the IR
@@ -50,6 +56,48 @@ class GlobalInfo : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(GlobalInfo, ObjectRef, GlobalInfoNode);
};
+/*!
+ * \brief A global info subclass for virtual devices.
+ */
+class VDeviceNode : public GlobalInfoNode {
+ public:
+ /*! \brief The \p Target describing how to compile for the virtual device. */
+ Target target;
+ /*! \brief The device identifier for the virtual device. This enables us to
+   * differentiate between distinct devices with the same Target, such as multiple GPUs.
+ */
+ int vdevice_id;
+ MemoryScope memory_scope;
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("target", &target);
+ v->Visit("vdevice_id", &vdevice_id);
+ v->Visit("memory_scope", &memory_scope);
+ }
+
+  TVM_DLL bool SEqualReduce(const VDeviceNode* other, SEqualReducer equal) const {
+    return equal(target, other->target) && equal(vdevice_id, other->vdevice_id) &&
+           equal(memory_scope, other->memory_scope);
+ }
+
+ TVM_DLL void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(target);
+ hash_reduce(vdevice_id);
+ hash_reduce(memory_scope);
+ }
+ static constexpr const char* _type_key = "VDevice";
+ TVM_DECLARE_FINAL_OBJECT_INFO(VDeviceNode, GlobalInfoNode);
+};
+
+/*!
+ * \brief Managed reference to VDeviceNode.
+ * \sa VDeviceNode
+ */
+class VDevice : public GlobalInfo {
+ public:
+ TVM_DLL explicit VDevice(Target tgt, int dev_id, MemoryScope mem_scope);
+ TVM_DEFINE_OBJECT_REF_METHODS(VDevice, GlobalInfo, VDeviceNode);
+};
+
/*!
* \brief A dummy global info sub-class for testing purpose.
*/
diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h
index 8c0ee6abc6..f33f3a9701 100644
--- a/include/tvm/relax/attrs/op.h
+++ b/include/tvm/relax/attrs/op.h
@@ -57,6 +57,14 @@ struct CallTIRInplaceAttrs : public tvm::AttrsNode<CallTIRInplaceAttrs> {
}
}; // struct CallTIRInplaceAttrs
+/*! \brief Attributes used in to_vdevice */
+struct ToVDeviceAttrs : public tvm::AttrsNode<ToVDeviceAttrs> {
+ VDevice dst_vdevice;
+ TVM_DECLARE_ATTRS(ToVDeviceAttrs, "relax.attrs.ToVDeviceAttrs") {
+    TVM_ATTR_FIELD(dst_vdevice).describe("The destination device where the data is copied to.");
+ }
+}; // struct ToVDeviceAttrs
+
} // namespace relax
} // namespace tvm
diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index 385c320db1..d2bf525225 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -156,6 +156,10 @@ class TensorStructInfoNode : public StructInfoNode {
   * \note shape must be normalized: it can only be NullOpt or ShapeExpr or Var.
*/
Optional<Expr> shape;
+  /*! \brief The virtual device, which indicates where the tensor
+   * is expected to be executed.
+   */
+ Optional<VDevice> vdevice;
/*! \brief The content data type, use void to denote the dtype is unknown. */
DataType dtype;
/*!
@@ -180,17 +184,20 @@ class TensorStructInfoNode : public StructInfoNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
+ v->Visit("vdevice", &vdevice);
v->Visit("ndim", &ndim);
v->Visit("span", &span);
}
  bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) const {
-    return equal(shape, other->shape) && equal(ndim, other->ndim) && equal(dtype, other->dtype);
+    return equal(shape, other->shape) && equal(ndim, other->ndim) &&
+           equal(vdevice, other->vdevice) && equal(dtype, other->dtype);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(shape);
hash_reduce(dtype);
+ hash_reduce(vdevice);
hash_reduce(ndim);
}
@@ -208,19 +215,23 @@ class TensorStructInfo : public StructInfo {
* \brief Construction with a known shape expression.
* \param shape The shape of the tensor.
* \param dtype The data type of tensor's elements.
+ * \param vdevice The virtual device.
* \param span The span of the AST.
*
* \note shape must already be normalized.
*/
- TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Span span = Span());
+  TVM_DLL TensorStructInfo(Expr shape, DataType dtype, VDevice vdevice = VDevice(),
+                           Span span = Span());
/*!
* \brief Construction with an unknown shape expression.
* \param dtype The data type of tensor's elements.
* \param ndim The number of dimensions
+ * \param vdevice The virtual device.
* \param span The span of the AST.
*/
- TVM_DLL TensorStructInfo(DataType dtype, int ndim, Span span = Span());
+  TVM_DLL TensorStructInfo(DataType dtype, int ndim, VDevice vdevice = VDevice(),
+                           Span span = Span());
  TVM_DEFINE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode);
};
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 7a0e003038..f4286512e5 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -47,6 +47,18 @@ using PassInfoNode = tvm::transform::PassInfoNode;
using PassContext = tvm::transform::PassContext;
using PassContextNode = tvm::transform::PassContextNode;
using Sequential = tvm::transform::Sequential;
+using FTVMRelayToTIR = tvm::transform::Pass;
+/*!
+ * \brief TIRToRuntime conversion specific to a TargetKind
+ *
+ * This function is responsible for scanning an IRModule for appropriate Target-specific functions
+ * and generating a Runtime module representing the compiled output
+ *
+ * \param ir_module Unified IRModule
+ * \param target Target to filter on or retrieve arguments from
+ * \return Runtime Module containing compiled functions
+ */
+using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;
/*!
* \brief Create a function pass.
diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h
index 156daebf00..1163464738 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -145,6 +145,8 @@ class IRDocsifierNode : public Object {
  std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
/*! \brief Metadata printing */
std::unordered_map<String, Array<ObjectRef>> metadata;
+ /*! \brief GlobalInfo printing */
+ std::unordered_map<String, Array<GlobalInfo>> global_infos;
/*! \brief The variable names used already */
std::unordered_set<String> defined_names;
/*! \brief Common prefixes of variable usages */
@@ -206,6 +208,11 @@ class IRDocsifierNode : public Object {
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj) const;
/*! \brief Add a TVM object to the metadata section*/
ExprDoc AddMetadata(const ObjectRef& obj);
+ /*! \brief Add a GlobalInfo to the global_infos map.
+ * \param name The name of key of global_infos.
+ * \param ginfo The GlobalInfo to be added.
+ */
+ void AddGlobalInfo(const String& name, const GlobalInfo& ginfo);
/*!
* \brief Check if a variable exists in the table.
* \param obj The variable object.
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 56d6a596b9..d47ac94e06 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -25,7 +25,6 @@
#define TVM_TARGET_TARGET_H_
#include <tvm/ir/expr.h>
-#include <tvm/ir/module.h>
#include <tvm/node/node.h>
#include <tvm/support/with.h>
#include <tvm/target/target_kind.h>
@@ -284,14 +283,5 @@ class Target : public ObjectRef {
*/
void CheckAndUpdateHostConsistency(Target* target, Target* host);
-/*!
- * \brief Check and update host field of the given legacy heterogeneous targets and
- * target host.Note that this function is for legacy target api compatibility issue only,
- * not recommended for other use.
- * \param ir_modules The pointer to a Map objects with keys being Target objects
- * \param host The Target typed object for target host to be updated
- */
-void CheckAndUpdateHostConsistency(Map<Target, IRModule>* ir_modules, Target* host);
-
} // namespace tvm
#endif // TVM_TARGET_TARGET_H_
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index 19bcce3116..10808fd12d 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -24,7 +24,6 @@
#ifndef TVM_TARGET_TARGET_KIND_H_
#define TVM_TARGET_TARGET_KIND_H_
-#include <tvm/ir/transform.h>
#include <tvm/node/attr_registry_map.h>
#include <tvm/node/node.h>
@@ -50,31 +49,7 @@ using TargetFeatures = Map<String, ObjectRef>;
* \return The transformed Target JSON object.
*/
using TargetJSON = Map<String, ObjectRef>;
-using FTVMTargetParser = TypedPackedFunc<TargetJSON(TargetJSON)>;
-
-/*!
- * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind
- *
- * Called before the default lowering passes.
- *
- * \param mod The module that an optimization pass runs on.
- * \param pass_ctx The pass context that can provide information for the optimization.
- *
- * \return The transformed module.
- */
-using FTVMRelayToTIR = transform::Pass;
-
-/*!
- * \brief TIRToRuntime conversion specific to a TargetKind
- *
- * This function is responsible for scanning an IRModule for appropriate Target-specific functions
- and generating a Runtime module representing the compiled output
- *
- * \param ir_module Unified IRModule
- * \param target Target to filter on or retrieve arguments from
- * \return Runtime Module containing compiled functions
- */
-using FTVMTIRToRuntime = runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;
+using FTVMTargetParser = tvm::runtime::TypedPackedFunc<TargetJSON(TargetJSON)>;
namespace detail {
template <typename, typename, typename>
diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index f49e9ceef7..a67350a2bb 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -27,6 +27,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/memory_pools.h>
+#include <tvm/ir/module.h>
#include <tvm/runtime/device_api.h>
#include <tvm/target/target.h>
#include <tvm/tir/stmt.h>
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 49c2cf6348..939a5f6383 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -35,7 +35,7 @@ from .base import (
from .container import Array, Map
from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
from .function import BaseFunc, CallingConv
-from .global_info import GlobalInfo, DummyGlobalInfo
+from .global_info import GlobalInfo, DummyGlobalInfo, VDevice
from .memory_pools import (
ConstantMemoryPools,
ConstantPoolInfo,
diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py
index 17011e76a6..458a16717b 100644
--- a/python/tvm/ir/global_info.py
+++ b/python/tvm/ir/global_info.py
@@ -40,3 +40,15 @@ class DummyGlobalInfo(GlobalInfo):
self.__init_handle_by_constructor__(
_ffi_api.DummyGlobalInfo,
)
+
+
+class VDevice(GlobalInfo):
+ def __init__(
+ self,
+ target=None,
+ vdevice_id: int = 0,
+ memory_scope: str = "global",
+ ) -> None:
+ if isinstance(target, (dict, str)):
+ target = tvm.target.Target(tvm.runtime.convert(target))
+        self.__init_handle_by_constructor__(_ffi_api.VDevice, target, vdevice_id, memory_scope)
diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py
index 6ce2a8b9e2..224932b00c 100644
--- a/python/tvm/ir/json_compact.py
+++ b/python/tvm/ir/json_compact.py
@@ -57,6 +57,21 @@ def create_updater(node_map, from_ver, to_ver):
return _updater
+def create_updater_13_to_14():
+ """Create an update to upgrade json from v0.13 to v0.14 for TVM Unity"""
+
+ def _update_vdevice(item, _):
+ if "vdevice" not in item["attrs"]:
+ item["attrs"]["vdevice"] = "0"
+ return item
+
+ node_map = {
+ "relax.TensorStructInfo": _update_vdevice,
+ }
+
+ return create_updater(node_map, "0.13", "0.14")
+
+
def create_updater_08_to_09():
"""
Create an update to upgrade json from v0.8 to v0.9
@@ -259,6 +274,8 @@ def upgrade_json(json_str):
data = create_updater_08_to_09()(create_updater_07_to_08()(data))
elif from_version.startswith("0.8"):
data = create_updater_08_to_09()(data)
+ elif from_version.startswith("0.13"):
+ data = create_updater_13_to_14()(data)
else:
raise ValueError(f"Cannot update from version {from_version}")
return json.dumps(data, indent=2)
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 25c70e0493..1d49c00ea8 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -690,3 +690,24 @@ def invoke_pure_closure(
sinfo_args = [sinfo_args]
    return _ffi_api.invoke_pure_closure(closure, args, sinfo_args)  # type: ignore
+
+
+def to_vdevice(data, dst_vdevice) -> Expr:
+ """Copy data to the destination device. This
+ operator helps data transferring between difference devices for
+ heterogeneous execution.
+
+ Parameters
+ ----------
+ data : Expr
+ The tensor to be copied.
+
+    dst_vdevice : VDevice
+        The destination virtual device where the data is copied to.
+
+ Returns
+ -------
+ result : Expr
+ The copied result.
+ """
+ return _ffi_api.to_vdevice(data, dst_vdevice) # type: ignore
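
A minimal usage sketch, mirroring the new parser test in tests/python/relax/test_tvmscript_parser_op_manipulate.py further down:

    import tvm
    from tvm import relax

    x = relax.Var("x", relax.TensorStructInfo((), "int32"))
    # Build a call node that copies x to the llvm (CPU) virtual device.
    call = relax.op.to_vdevice(x, tvm.ir.VDevice("llvm", 0, "global"))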
diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py
index 3dcc3dc9a0..e78e1cf69a 100644
--- a/python/tvm/relax/struct_info.py
+++ b/python/tvm/relax/struct_info.py
@@ -16,14 +16,14 @@
# under the License.
# pylint: disable=invalid-name, unused-import
"""The struct info nodes of the Relax language."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
import tvm._ffi
import tvm
-from tvm.ir import Span, Node, EnvFunc, Array, Type
+from tvm.ir import Span, EnvFunc, Array, VDevice
from tvm.tir import PrimExpr
-from .expr import StructInfo, Var, Expr, ShapeExpr
+from .expr import StructInfo, Expr, ShapeExpr
from . import _ffi_api, ty, expr
@@ -93,6 +93,9 @@ class TensorStructInfo(StructInfo):
dtype : Optional[str]
The content data type.
+    vdevice : Optional[VDevice]
+ The virtual device.
+
ndim : Optional[int]
The number of dimensions of the tensor.
@@ -103,6 +106,7 @@ class TensorStructInfo(StructInfo):
shape: Optional[Expr]
dtype: str
+ vdevice: Optional[VDevice]
ndim: int
span: Span
@@ -110,14 +114,14 @@ class TensorStructInfo(StructInfo):
self,
shape: Union[Optional[Expr], List[PrimExpr]] = None,
dtype: str = "float32",
+ vdevice: Union[Optional[VDevice], str] = None,
ndim: int = -1,
span: Span = None,
) -> None:
if isinstance(shape, (list, tuple, Array)):
shape = ShapeExpr(shape)
-
self.__init_handle_by_constructor__(
- _ffi_api.TensorStructInfo, shape, dtype, ndim, span # type: ignore
+            _ffi_api.TensorStructInfo, shape, dtype, ndim, vdevice, span  # type: ignore
)
diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py
index 68eda2cfee..fdf44b2b79 100644
--- a/python/tvm/script/ir_builder/ir/__init__.py
+++ b/python/tvm/script/ir_builder/ir/__init__.py
@@ -22,5 +22,7 @@ from .ir import (
ir_module,
module_attrs,
module_global_infos,
+ lookup_vdevice,
+ vdevice,
dummy_global_info,
)
diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py
index 53c48b4cc5..0d3523ec7d 100644
--- a/python/tvm/script/ir_builder/ir/ir.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -18,7 +18,7 @@
from typing import Dict, List
-from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, DummyGlobalInfo
+from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, VDevice, DummyGlobalInfo
from tvm.runtime import Object as tvm_Object
@@ -104,3 +104,39 @@ def dummy_global_info() -> DummyGlobalInfo:
The result dummy global info.
"""
    return DummyGlobalInfo()  # type: ignore[attr-defined]  # pylint: disable=no-member
+
+
+def vdevice(target=None, vdevice_id: int = 0, memory_scope: str = "global") -> VDevice:
+    """Create a virtual device global info.
+
+    Parameters
+ ----------
+ target
+ The target.
+ vdevice_id: int
+ The virtual device index.
+ memory_scope: str
+        The memory scope, default is "global".
+
+ Returns
+ -------
+ res : VDevice
+ The result virtual device.
+ """
+    return VDevice(target, vdevice_id, memory_scope)  # type: ignore[attr-defined]  # pylint: disable=no-member
+
+
+def lookup_vdevice(target_kind: str = None, device_index: int = -1) -> VDevice:
+ """Retrieve a virtual device from the globalinfo vdevice list.
+ Parameters
+ ----------
+ target_kind: str
+ The target device kind, for example 'llvm' or 'cuda'.
+ device_index: int
+ The virtual device index.
+
+ Returns
+ -------
+ res : VDevice
+ The result virtual device.
+ """
+    return _ffi_api.LookupVDevice(target_kind, device_index)  # type: ignore[attr-defined]  # pylint: disable=no-member
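
The two helpers pair up: vdevice() creates the entries registered under the module's "vdevice" global info, while lookup_vdevice() resolves a target kind plus a per-kind index against that list (the C++ side is LookupVDevice in src/script/ir_builder/ir/ir.cc below). A small sketch, assuming the two VDevice objects compare structurally equal:

    import tvm
    from tvm.script.ir_builder.ir import vdevice

    vdev = vdevice("cuda", 0, "global")
    # Expected to be structurally equal to a directly constructed VDevice.
    assert tvm.ir.structural_equal(vdev, tvm.ir.VDevice("cuda", 0, "global"))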
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 5bb0374d35..8a538c1868 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -140,6 +140,7 @@ from tvm.relax.op import (
tanh,
erf,
tile,
+ to_vdevice,
tril,
triu,
unique,
@@ -693,6 +694,7 @@ __all__ = [
"tan",
"tanh",
"tile",
+ "to_vdevice",
"tril",
"triu",
"tuple",
diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py
index ec518f8573..3a8196288d 100644
--- a/python/tvm/script/parser/ir/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -27,4 +27,6 @@ __all__ = [
"module_global_infos",
"dummy_global_info",
"Range",
+ "lookup_vdevice",
+ "vdevice",
]
diff --git a/python/tvm/script/parser/relax/dist.py b/python/tvm/script/parser/relax/dist.py
index 120d57ca56..f9c78f980f 100644
--- a/python/tvm/script/parser/relax/dist.py
+++ b/python/tvm/script/parser/relax/dist.py
@@ -78,7 +78,11 @@ def DTensor(
raise ValueError(f"shape must be a list or tuple, but got: {shape}")
if isinstance(device_mesh, str):
if not IRBuilder.is_in_scope():
-            return (DTensorProxy(TensorProxy(shape, dtype, ndim), DeviceMesh([], Range(0, 1)), ""),)
+            return (
+                DTensorProxy(
+                    TensorProxy(shape, dtype, None, ndim), DeviceMesh([], Range(0, 1)), ""
+ ),
+ )
name, index = device_mesh.split("[")
index = int(index[:-1])
frames = IRBuilder.current().frames
@@ -89,7 +93,7 @@ def DTensor(
assert isinstance(device_mesh, DeviceMesh)
if isinstance(placement, str):
placement = Placement.from_text(placement)
-    return DTensorProxy(TensorProxy(shape, dtype, ndim), device_mesh, placement)
+    return DTensorProxy(TensorProxy(shape, dtype, None, ndim), device_mesh, placement)
__all__ = ["DTensor", "device_mesh"]
diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py
index ff237a5600..1c18d75be4 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -37,6 +37,7 @@ from tvm.runtime import ObjectGeneric
from tvm.tir import PrimExpr
from .._core import parse, utils
+from ..ir import lookup_vdevice
FType = TypeVar("FType", bound=_Callable)
@@ -103,12 +104,14 @@ def _eval_shape(expr: Union[str, PrimExpr], dict_globals: Optional[Dict[str, Any
class TensorProxy(StructInfoProxy):
shape: Optional[List[Union[str, PrimExpr]]]
dtype: str
+ vdevice: Optional[str]
ndim: int
def __init__(
self,
shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
dtype: Optional[str] = None,
+ vdevice: Optional[str] = None,
ndim: int = -1,
) -> None:
if isinstance(shape, Expr):
@@ -124,6 +127,7 @@ class TensorProxy(StructInfoProxy):
)
self.shape = shape
self.dtype = dtype
+ self.vdevice = vdevice
self.ndim = ndim
def get_symbolic_vars(self) -> Set[str]:
@@ -133,10 +137,18 @@ class TensorProxy(StructInfoProxy):
        return {s for s in self.shape if isinstance(s, str) and s.isidentifier()}
    def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo:
+ vdev = self.vdevice
+ if isinstance(self.vdevice, str):
+ if ":" in self.vdevice:
+ split_vdev = self.vdevice.split(":")
+ vdev = lookup_vdevice(split_vdev[0], int(split_vdev[1]))
+ else:
+ vdev = lookup_vdevice(self.vdevice, 0)
+
if self.shape is None:
- return TensorStructInfo(None, self.dtype, self.ndim)
+ return TensorStructInfo(None, self.dtype, vdev, self.ndim)
elif isinstance(self.shape, (ShapeExpr, Var)):
- return TensorStructInfo(self.shape, self.dtype, self.ndim)
+ return TensorStructInfo(self.shape, self.dtype, vdev, self.ndim)
else:
            if dict_globals is None and any([isinstance(s, str) for s in self.shape]):
raise ValueError(
@@ -144,12 +156,13 @@ class TensorProxy(StructInfoProxy):
"and return annotations for TVMScript."
)
shape = [_eval_shape(s, dict_globals) for s in self.shape]
- return TensorStructInfo(shape, self.dtype, self.ndim)
+ return TensorStructInfo(shape, self.dtype, vdev, self.ndim)
def Tensor(
shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
dtype: Optional[str] = None,
+ vdevice: Optional[str] = None,
ndim: int = -1,
) -> TensorProxy:
# scalar tensor case
@@ -161,7 +174,7 @@ def Tensor(
    if shape is not None and not isinstance(shape, (tuple, list)) and not isinstance(shape, Expr):
        raise ValueError(f"shape must be a list/tuple or an Expr, but got: {shape}")
- return TensorProxy(shape, dtype, ndim)
+ return TensorProxy(shape, dtype, vdevice, ndim)
############################## R.Callable ##############################
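
With this change a Tensor annotation can carry a vdevice string: either a bare target kind ("llvm") or "kind:index" selecting the n-th vdevice of that kind ("cuda:1"). A sketch of a function signature using it; this must sit inside an I.ir_module that registers the "vdevice" global info, as in the parser tests near the end of this diff:

    @R.function
    def foo(
        a: R.Tensor((128, 128), "float32", "cuda:1"),  # second cuda entry in the vdevice list
        b: R.Tensor((128, 128), "float32", "llvm"),    # first llvm entry
    ) -> R.Tensor((128, 128), "float32"):
        return R.add(a, b)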
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index d46fab7168..b7ba0ffe44 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -431,6 +431,23 @@ std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target&
return {host_mod, device_mod};
}
+/*!
+ * \brief Check and update the host field of the given legacy heterogeneous targets and
+ * target host. Note that this function is for legacy target api compatibility issues only,
+ * not recommended for other use.
+ * \param targets The pointer to a Map object with keys being Target objects
+ * \param host The Target typed object for target host to be updated
+ */
+void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target* host) {
+ Map<Target, IRModule> new_targets;
+ for (auto& it : *targets) {
+ auto target = it.first;
+ CheckAndUpdateHostConsistency(&target, host);
+ new_targets.Set(target, it.second);
+ }
+ *targets = new_targets;
+}
+
runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
const Target& target_host_arg) {
std::vector<runtime::Module> device_modules;
diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc
index 48f56d60d6..f1ecc8cd04 100644
--- a/src/ir/global_info.cc
+++ b/src/ir/global_info.cc
@@ -29,4 +29,17 @@
TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() {
auto n = DummyGlobalInfo(make_object<DummyGlobalInfoNode>());
return n;
});
+
+VDevice::VDevice(Target tgt = {}, int dev_id = -1, MemoryScope mem_scope = {}) {
+ ObjectPtr<VDeviceNode> n = make_object<VDeviceNode>();
+ n->target = std::move(tgt);
+ n->vdevice_id = std::move(dev_id);
+ n->memory_scope = std::move(mem_scope);
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(VDeviceNode);
+TVM_REGISTER_GLOBAL("ir.VDevice").set_body_typed([](Target tgt, int dev_id,
MemoryScope mem_scope) {
+ return VDevice(tgt, dev_id, mem_scope);
+});
} // namespace tvm
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index 7006f71198..82ccdf33ea 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -149,19 +149,24 @@ class WellDefinedEraser : public StructInfoMutator,
std::swap(has_undefined_, has_undefined);
}
+ VDevice vdev = VDevice();
+ if (op->vdevice.defined()) {
+ vdev = op->vdevice.value();
+ }
+
// erase symbolic shape if we have undefined.
if (!has_undefined) {
if (shape.same_as(op->shape)) {
return GetRef<StructInfo>(op);
} else {
if (shape.defined()) {
- return TensorStructInfo(shape.value(), op->dtype, op->span);
+ return TensorStructInfo(shape.value(), op->dtype, vdev, op->span);
} else {
- return TensorStructInfo(op->dtype, op->ndim, op->span);
+ return TensorStructInfo(op->dtype, op->ndim, vdev, op->span);
}
}
} else {
- return TensorStructInfo(op->dtype, op->ndim, op->span);
+ return TensorStructInfo(op->dtype, op->ndim, vdev, op->span);
}
}
@@ -767,6 +772,16 @@ class StructInfoLCAFinder
// find the target dtype and ndim.
DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void();
int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim;
+ VDevice vdev = VDevice();
+ if (lhs->vdevice.defined() && rhs->vdevice.defined()) {
+    if (lhs->vdevice.value().same_as(rhs->vdevice.value())) {
+ vdev = lhs->vdevice.value();
+ }
+ } else if (lhs->vdevice.defined()) {
+ vdev = lhs->vdevice.value();
+ } else if (rhs->vdevice.defined()) {
+ vdev = rhs->vdevice.value();
+ }
// if ndim mismatch or one side of shape is missing
// then we cannot keep in symbolic shape
  if (lhs->ndim != rhs->ndim || !lhs->shape.defined() || !rhs->shape.defined() ||
@@ -775,12 +790,12 @@ class StructInfoLCAFinder
if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim) {
return GetRef<StructInfo>(lhs);
} else {
- return TensorStructInfo(dtype, ndim, lhs->span);
+ return TensorStructInfo(dtype, ndim, vdev, lhs->span);
}
}
// symbolic shape match but dtype mismatch
if (lhs->dtype != dtype) {
- return TensorStructInfo(lhs->shape.value(), dtype, lhs->span);
+ return TensorStructInfo(lhs->shape.value(), dtype, vdev, lhs->span);
} else {
return GetRef<StructInfo>(lhs);
}
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc
index 6087c2bb25..784b3c9fd5 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -21,6 +21,7 @@
* \brief Lowers most builtin functions and packed calls.
*/
#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/op.h>
#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/type.h>
@@ -46,6 +47,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
return Reshape(call);
} else if (call->op == shape_of_op_) {
return ShapeOf(call);
+ } else if (call->op == to_vdevice_op_) {
+ return ToDevice(call);
} else if (call->op == make_closure_op_) {
return MakeClosure(call);
} else if (call->op == invoke_closure_op_) {
@@ -156,6 +159,22 @@ class VMBuiltinLowerMutator : public ExprMutator {
    return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}
+ Expr ToDevice(const Call& call_node) {
+ // TODO(yongwww): replace ToVDeviceAttrs with related Expr
+ ICHECK(call_node->args.size() == 1);
+ ICHECK(call_node->struct_info_.defined());
+ auto attrs = call_node->attrs.as<ToVDeviceAttrs>();
+ Array<Expr> args;
+ args.push_back(call_node->args[0]);
+ // Get the DLDeviceType and device_id from VDevice
+ VDevice vdev = attrs->dst_vdevice;
+ int dev_type = vdev->target->GetTargetDeviceType();
+ int dev_id = vdev->vdevice_id;
+ args.push_back(PrimValue::Int64(dev_type));
+ args.push_back(PrimValue::Int64(dev_id));
+    return Call(builtin_to_device_, args, call_node->attrs, {GetStructInfo(call_node)});
+ }
+
Expr MakeClosure(const Call& call_node) {
ICHECK(call_node->args.size() == 2);
ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
@@ -198,6 +217,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& shape_of_op_ = Op::Get("relax.shape_of");
+ const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
@@ -214,6 +234,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
+ const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
};
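
Net effect of this mutator: a relax.to_vdevice call becomes a packed call to vm.builtin.to_device, with the destination's device type and id extracted from the VDevice attrs and passed as PrimValues. Roughly, as a hand-written sketch of the lowered form (not actual compiler output):

    # before:  y = R.to_vdevice(x, vdev)   # vdev targets llvm -> kDLCPU (1), device 0
    # after:   y = R.call_packed("vm.builtin.to_device", x, 1, 0,
    #                            sinfo_args=R.Tensor((3, 4), dtype="float32"))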
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index ccff18cd40..ac04096aaf 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -292,7 +292,7 @@ Constant::Constant(runtime::NDArray data, Optional<StructInfo> struct_info_annot
n->struct_info_ = struct_info_annotation.value();
n->checked_type_ = GetStaticType(struct_info_annotation.value());
} else {
-    TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), span);
+    TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), VDevice(), span);
n->struct_info_ = tinfo;
n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype);
}
diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc
index c290711dcd..31784af000 100644
--- a/src/relax/ir/struct_info.cc
+++ b/src/relax/ir/struct_info.cc
@@ -92,7 +92,7 @@ TVM_REGISTER_GLOBAL("relax.ShapeStructInfo")
});
// Tensor
-TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) {
+TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, VDevice vdevice, Span span) {
ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
// assign ndim before move
Optional<ShapeStructInfo> sinfo = MatchStructInfo<ShapeStructInfo>(shape);
@@ -104,15 +104,17 @@ TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) {
// assign rest of the fields.
n->shape = std::move(shape);
n->dtype = dtype;
+ n->vdevice = vdevice;
n->span = span;
data_ = std::move(n);
}
-TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) {
+TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, VDevice vdevice, Span span) {
ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
  CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim;
n->ndim = ndim;
n->dtype = dtype;
+ n->vdevice = vdevice;
n->span = span;
data_ = std::move(n);
}
@@ -120,12 +122,12 @@ TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) {
TVM_REGISTER_NODE_TYPE(TensorStructInfoNode);
TVM_REGISTER_GLOBAL("relax.TensorStructInfo")
-    .set_body_typed([](Optional<Expr> shape, DataType dtype, int ndim, Span span) {
+    .set_body_typed([](Optional<Expr> shape, DataType dtype, int ndim, VDevice vdevice, Span span) {
if (shape.defined()) {
        CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim";
- return TensorStructInfo(shape.value(), dtype, span);
+ return TensorStructInfo(shape.value(), dtype, vdevice, span);
} else {
- return TensorStructInfo(dtype, ndim, span);
+ return TensorStructInfo(dtype, ndim, vdevice, span);
}
});
diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc
index 72ea623e07..c998d8c0b2 100644
--- a/src/relax/ir/struct_info_functor.cc
+++ b/src/relax/ir/struct_info_functor.cc
@@ -94,10 +94,15 @@ StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) {
shape = this->VisitStructInfoExprField(op->shape.value());
}
+ VDevice vdev = VDevice();
+ if (op->vdevice.defined()) {
+ vdev = op->vdevice.value();
+ }
+
if (shape.same_as(op->shape)) {
return GetRef<StructInfo>(op);
} else {
- return TensorStructInfo(shape.value(), op->dtype, op->span);
+ return TensorStructInfo(shape.value(), op->dtype, vdev, op->span);
}
}
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index af93b43dcf..1f4f7d5c34 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -866,5 +866,33 @@ Expr MakeStopLiftParams(Expr x) {
TVM_REGISTER_GLOBAL("relax.op.builtin.stop_lift_params").set_body_typed(MakeStopLiftParams);
+// to_vdevice
+TVM_REGISTER_NODE_TYPE(ToVDeviceAttrs);
+
+StructInfo InferToVDeviceStructInfo(const Call& call, const BlockBuilder& ctx) {
+ ICHECK(call->args.size() == 1);
+ ICHECK(call->args[0]->struct_info_.defined());
+ TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ return data_sinfo;
+}
+
+RELAY_REGISTER_OP("relax.to_vdevice")
+ .set_num_inputs(1)
+ .set_attrs_type<ToVDeviceAttrs>()
+ .add_argument("data", "Expr", "The input expression to be copied")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferToVDeviceStructInfo)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+Expr MakeToVDevice(Expr data, VDevice dst_vdevice) {
+ static const Op& op = Op::Get("relax.to_vdevice");
+ // TODO(@yongwww): replace Attr with TensorStructInfo
+ ObjectPtr<ToVDeviceAttrs> attrs = make_object<ToVDeviceAttrs>();
+ attrs->dst_vdevice = dst_vdevice;
+
+ return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.to_vdevice").set_body_typed(MakeToVDevice);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc
index 91dcd5d8e8..dd09dd67b8 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -267,7 +267,11 @@ class LayoutConvertMutator : public ExprMutator {
new_shape.push_back(
shape->values[from.LeafValue()->layout.IndexOf(to.LeafValue()->layout[i])]);
}
-      return TensorStructInfo(ShapeExpr(new_shape), tsinfo->dtype, tsinfo->span);
+ VDevice vdev = VDevice();
+ if (tsinfo->vdevice.defined()) {
+ vdev = tsinfo->vdevice.value();
+ }
+      return TensorStructInfo(ShapeExpr(new_shape), tsinfo->dtype, vdev, tsinfo->span);
};
StructInfo new_struct_info = TransformTupleLeaf<LayoutDecision>(
        binding->struct_info, std::array<NLayout, 2>({from_layout, input_layout}), fvisitleaf);
diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc
index 64763276d0..d12d1080b9 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -289,7 +289,11 @@ class ToMixedPrecisionRewriter : public ExprMutator {
if (fp16_input_names_.count(var->name_hint())) {
auto sinfo = GetStructInfo(var);
if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
-        TensorStructInfo fp16_sinfo(tensor_sinfo->shape.value(), DataType::Float(16),
+ VDevice vdev = VDevice();
+ if (tensor_sinfo->vdevice.defined()) {
+ vdev = tensor_sinfo->vdevice.value();
+ }
+        TensorStructInfo fp16_sinfo(tensor_sinfo->shape.value(), DataType::Float(16), vdev,
                                     tensor_sinfo->span);
Var fp16_var(var->vid, fp16_sinfo, var->span);
var_remap_[var->vid] = fp16_var;
diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc
index f14c106703..527fba98c0 100644
--- a/src/relay/backend/contrib/cmsisnn/target.cc
+++ b/src/relay/backend/contrib/cmsisnn/target.cc
@@ -37,8 +37,8 @@ TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<String>("mcpu")
.add_attr_option<Bool>("debug_last_error")
- .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
- .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
+    .set_attr<relay::transform::FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
+ .set_attr<relay::transform::FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
.set_target_parser(tvm::target::parsers::cpu::ParseTarget);
} // namespace cmsisnn
diff --git a/src/relay/backend/contrib/codegen_c/target.cc b/src/relay/backend/contrib/codegen_c/target.cc
index 623057ac17..cd1e0283df 100644
--- a/src/relay/backend/contrib/codegen_c/target.cc
+++ b/src/relay/backend/contrib/codegen_c/target.cc
@@ -34,7 +34,7 @@ namespace contrib {
*/
TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
- .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, CCompilerPass())
+    .set_attr<relay::transform::FTVMRelayToTIR>(tvm::attr::kRelayToTIR, CCompilerPass())
// Value is prepended to every output CModule.
.add_attr_option<String>("header", String(""));
diff --git a/src/relay/backend/contrib/cutlass/target.cc b/src/relay/backend/contrib/cutlass/target.cc
index 7b377f340a..50c8b84a90 100644
--- a/src/relay/backend/contrib/cutlass/target.cc
+++ b/src/relay/backend/contrib/cutlass/target.cc
@@ -40,7 +40,7 @@ namespace cutlass {
*/
TVM_REGISTER_TARGET_KIND("cutlass", kDLCUDA)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
- .set_attr<FTVMRelayToTIR>("RelayToTIR", CompileForCutlass())
+ .set_attr<tvm::transform::Pass>("RelayToTIR", CompileForCutlass())
    // An integer specifying the compute capability. For example, 75 for Turing and
// 80 or 86 for Ampere.
.add_attr_option<Integer>("sm", Integer(80))
diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc
index f35d4c6d48..2e635455e9 100644
--- a/src/relay/backend/contrib/ethosu/codegen.cc
+++ b/src/relay/backend/contrib/ethosu/codegen.cc
@@ -320,8 +320,8 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU)
.set_attr<Bool>("use_device_api", Bool(true))
- .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
- .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);
+    .set_attr<relay::transform::FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
+    .set_attr<relay::transform::FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);
} // namespace ethosu
} // namespace contrib
diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc
index b01c23ed80..275efaa933 100644
--- a/src/relay/backend/contrib/example_target_hooks/target.cc
+++ b/src/relay/backend/contrib/example_target_hooks/target.cc
@@ -33,8 +33,10 @@ runtime::Module TIRToRuntime(IRModule mod, Target target);
TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
.set_attr<Bool>("use_device_api", Bool(true))
-    .set_attr<FTVMRelayToTIR>(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR())
-    .set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime)
+    .set_attr<relay::transform::FTVMRelayToTIR>(attr::kRelayToTIR,
+                                                relay::contrib::example_target_hooks::RelayToTIR())
+    .set_attr<relay::transform::FTVMTIRToRuntime>(
+        "TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime)
.add_attr_option<Integer>("example_attribute", Integer(0));
} // namespace tvm
diff --git a/src/relay/backend/contrib/tensorrt/target.cc b/src/relay/backend/contrib/tensorrt/target.cc
index 2e4581d30a..0277787a8c 100644
--- a/src/relay/backend/contrib/tensorrt/target.cc
+++ b/src/relay/backend/contrib/tensorrt/target.cc
@@ -39,7 +39,7 @@ namespace tensorrt {
*/
TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
- .set_attr<FTVMRelayToTIR>("RelayToTIR", CompileForTensorRT())
+ .set_attr<tvm::transform::Pass>("RelayToTIR", CompileForTensorRT())
    // A array of three integers given the major, minor, and patch numbers for the supported
    // TensorRT compiler version. If empty will be auto-detected from linked library. Default empty.
.add_attr_option<Array<Integer>>("tensorrt_version", Array<Integer>())
diff --git a/src/relay/backend/contrib/uma/targets.cc b/src/relay/backend/contrib/uma/targets.cc
index e2fe644cb9..d01f5b4c73 100644
--- a/src/relay/backend/contrib/uma/targets.cc
+++ b/src/relay/backend/contrib/uma/targets.cc
@@ -46,20 +46,20 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
}
}
-  auto target_kind =
-      TargetKindRegEntry::RegisterOrGet(target_name)
-          .set_name()
-          .set_default_device_type(kDLCPU)
-          .add_attr_option<Array<String>>("keys")
-          .add_attr_option<String>("tag")
-          .add_attr_option<String>("device")
-          .add_attr_option<String>("model")
-          .add_attr_option<Array<String>>("libs")
-          .add_attr_option<Target>("host")
-          .add_attr_option<Integer>("from_device")
-          .set_attr<FTVMRelayToTIR>(attr::kRelayToTIR, relay::contrib::uma::RelayToTIR(target_name))
-          .set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::uma::TIRToRuntime);
+  auto target_kind = TargetKindRegEntry::RegisterOrGet(target_name)
+                         .set_name()
+                         .set_default_device_type(kDLCPU)
+                         .add_attr_option<Array<String>>("keys")
+                         .add_attr_option<String>("tag")
+                         .add_attr_option<String>("device")
+                         .add_attr_option<String>("model")
+                         .add_attr_option<Array<String>>("libs")
+                         .add_attr_option<Target>("host")
+                         .add_attr_option<Integer>("from_device")
+                         .set_attr<relay::transform::FTVMRelayToTIR>(
+                             attr::kRelayToTIR, relay::contrib::uma::RelayToTIR(target_name))
+                         .set_attr<relay::transform::FTVMTIRToRuntime>(
+                             "TIRToRuntime", relay::contrib::uma::TIRToRuntime);
// target kind attrs inventory
auto kind = TargetKind::Get(target_name).value();
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 86f0152ce7..8b27bb2d9e 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -327,6 +327,12 @@ TVM_REGISTER_GLOBAL("vm.builtin.null_value").set_body([](TVMArgs args, TVMRetVal
*rv = nullptr;
});
+TVM_REGISTER_GLOBAL("vm.builtin.to_device")
+ .set_body_typed([](NDArray data, int dev_type, int dev_id) {
+ Device dst_device = {(DLDeviceType)dev_type, dev_id};
+ return data.CopyTo(dst_device);
+ });
+
/*!
* \brief Load the scalar value in cond and return the result value.
* \param cond The condition
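
The builtin can also be exercised directly through the packed-function registry; a quick sketch (device_type 1 is kDLCPU):

    import numpy as np
    import tvm

    to_device = tvm.get_global_func("vm.builtin.to_device")
    x = tvm.nd.array(np.zeros((2, 2), "float32"))
    y = to_device(x, 1, 0)  # copy to device_type=1 (CPU), device_id=0
    assert str(y.device) == "cpu(0)"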
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index fb51886a7d..2f2785ca44 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -110,11 +110,41 @@ void ModuleGlobalInfos(Map<String, Array<GlobalInfo>> global_infos) {
}
}
+VDevice LookupVDevice(String target_kind, int device_index) {
+ if (IRBuilder::IsInScope()) {
+ IRModuleFrame frame = FindModuleFrame();
+ if (frame->global_infos.empty()) {
+ LOG(FATAL) << "ValueError: The GlobalInfos in the IRModule is not
defined.";
+ }
+ Array<GlobalInfo> vdevices = frame->global_infos["vdevice"];
+ if (vdevices.empty() || device_index < 0 ||
+ static_cast<size_t>(device_index) >= vdevices.size()) {
+ LOG(FATAL) << "ValueError: The target VDevice in the GlobalInfos was not
found.";
+ }
+ if (target_kind == "vdevice") {
+ return Downcast<VDevice>(vdevices[device_index]);
+ }
+ int count = 0;
+ for (auto vdevice : vdevices) {
+ auto vdev = Downcast<VDevice>(vdevice);
+ if (vdev->target->kind->name == target_kind) {
+ if (count == device_index) {
+ return vdev;
+ }
+ count++;
+ }
+ }
+ }
+ LOG(WARNING) << "The annotated device was not found, please check your
vdevice list.";
+ return VDevice();
+}
+
TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.LookupVDevice").set_body_typed(LookupVDevice);
} // namespace ir
} // namespace ir_builder
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index ecf92897a5..a239481d03 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -118,6 +118,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return IR(d, "dummy_global_info")->Call({});
});
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<VDevice>("", [](VDevice vdev, ObjectPath p, IRDocsifier d) -> Doc {
+ d->AddGlobalInfo("vdevice", vdev);
+ Map<String, ObjectRef> config = vdev->target->Export();
+ return IR(d, "vdevice")
+ ->Call({d->AsDoc<ExprDoc>(config, p),
+ LiteralDoc::Int(vdev->vdevice_id, p->Attr("vdevice_id")),
+                  LiteralDoc::Str(vdev->memory_scope, p->Attr("memory_scope"))});
+ });
+
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<Op>("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc {
return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))});
diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc
index 521ab07359..a424863495 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -63,6 +63,12 @@ ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) {
return IdDoc("metadata")[{LiteralDoc::Str(key,
NullOpt)}][{LiteralDoc::Int(index, NullOpt)}];
}
+void IRDocsifierNode::AddGlobalInfo(const String& name, const GlobalInfo& ginfo) {
+ ICHECK(ginfo.defined()) << "TypeError: Cannot add nullptr to global_infos";
+ Array<GlobalInfo>& array = global_infos[name];
+ array.push_back(ginfo);
+}
+
bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); }
void IRDocsifierNode::RemoveVar(const ObjectRef& obj) {
diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc
index 49162bb824..7fab5b59a2 100644
--- a/src/script/printer/relax/struct_info.cc
+++ b/src/script/printer/relax/struct_info.cc
@@ -111,6 +111,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
kwargs_keys.push_back("ndim");
        kwargs_values.push_back(LiteralDoc::Int(n->ndim, n_p->Attr("ndim")));
}
+ if (n->vdevice.defined()) {
+ kwargs_keys.push_back("vdevice");
+ std::string dev_kind = n->vdevice.value()->target->kind->name;
+          int dev_index = FindVDeviceIndexByTargetKind(n->vdevice.value(), d);
+ kwargs_values.push_back(
+            LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index), n_p->Attr("vdevice")));
+ }
if (args.empty() && kwargs_keys.empty()) {
return Relax(d, "Tensor");
}
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
index 88fc7491c2..e0b5348d73 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/script/printer/relax/utils.h
@@ -97,6 +97,22 @@ Array<StmtDoc> PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, cons
ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const
IRDocsifier& d);
+inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifier& d) {
+ Array<GlobalInfo> vdevices = d->global_infos["vdevice"];
+ int kind_index = 0;
+ for (size_t i = 0; i < vdevices.size(); ++i) {
+ auto vdev = Downcast<VDevice>(vdevices[i]);
+ if (vdev.same_as(vdevice)) {
+ return kind_index;
+ }
+ if (vdev->target->kind->name == vdevice->target->kind->name) {
+ kind_index++;
+ }
+ }
+ LOG(WARNING) << "The VDevice was not found in the global_infos map: " <<
vdevice;
+ return -1;
+}
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index bbb2c15a64..55af8889e1 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -38,6 +38,9 @@
#include <vector>
namespace tvm {
+
+using FTVMTIRToRuntime = runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;
+
namespace codegen {
runtime::Module Build(IRModule mod, Target target) {
diff --git a/src/target/target.cc b/src/target/target.cc
index 2f585188d0..cd2e3714e4 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -21,6 +21,7 @@
* \file src/target/target.cc
*/
#include <dmlc/thread_local.h>
+#include <tvm/ir/transform.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
@@ -91,16 +92,6 @@ void CheckAndUpdateHostConsistency(Target* target, Target* host) {
*host = (*target)->GetHost().value_or(Target());
}
-void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target* host) {
- Map<Target, IRModule> new_targets;
- for (auto& it : *targets) {
- auto target = it.first;
- CheckAndUpdateHostConsistency(&target, host);
- new_targets.Set(target, it.second);
- }
- *targets = new_targets;
-}
-
static std::vector<String> DeduplicateKeys(const std::vector<String>& keys) {
std::vector<String> new_keys;
for (size_t i = 0; i < keys.size(); ++i) {
@@ -614,8 +605,8 @@ Target::Target(TargetKind kind, Optional<ObjectRef> host, String tag, Array<Stri
bool Target::IsExternalCodegen() const {
TargetKindAttrMap<Bool> is_external_codegen_map =
TargetKind::GetAttrMap<Bool>(tvm::attr::kIsExternalCodegen);
- TargetKindAttrMap<FTVMRelayToTIR> relay_to_tir_map =
- TargetKind::GetAttrMap<FTVMRelayToTIR>(tvm::attr::kRelayToTIR);
+ TargetKindAttrMap<tvm::transform::Pass> relay_to_tir_map =
+ TargetKind::GetAttrMap<tvm::transform::Pass>(tvm::attr::kRelayToTIR);
return is_external_codegen_map.get(get()->kind, Bool(false)) ||
relay_to_tir_map.count(get()->kind);
}
diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc
index 37a8eeb448..9d05023e79 100644
--- a/tests/cpp/target_test.cc
+++ b/tests/cpp/target_test.cc
@@ -458,7 +458,7 @@ TVM_REGISTER_TARGET_KIND("test_external_codegen_2",
kDLMetal)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));
TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU)
-    .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, tvm::relay::transform::InferType());
+    .set_attr<tvm::transform::Pass>(tvm::attr::kRelayToTIR, tvm::relay::transform::InferType());
TEST(Target, ExternalCodegen) {
Target regular("cuda");
diff --git a/python/tvm/ir/global_info.py b/tests/python/relax/test_json_compact.py
similarity index 53%
copy from python/tvm/ir/global_info.py
copy to tests/python/relax/test_json_compact.py
index 17011e76a6..1320ff1cd6 100644
--- a/python/tvm/ir/global_info.py
+++ b/tests/python/relax/test_json_compact.py
@@ -14,29 +14,37 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Global Info."""
-import tvm
-from tvm.runtime.object import Object
-from . import _ffi_api
-
-class GlobalInfo(Object):
- """Base node for all global info that can appear in the IR"""
-
- def __eq__(self, other):
- """Compare two struct info for structural equivalence."""
- return tvm.ir.structural_equal(self, other)
+import tvm
+import tvm.testing
+from tvm import relax
+import json
- def __ne__(self, other):
- return not self.__eq__(other)
- def same_as(self, other):
- """Overload with structural equality."""
- return super().__eq__(other)
+# 0.13 BACKWARDS COMPATIBILITY TESTS
+def test_vdevice():
+ nodes = [
+ {"type_key": ""},
+ {
+ "type_key": "relax.TensorStructInfo",
+ "attrs": {
+ "dtype": "float32",
+ "ndim": "-1",
+ "shape": "0",
+ "span": "0",
+ },
+ },
+ ]
+ data = {
+ "root": 1,
+ "nodes": nodes,
+ "attrs": {"tvm_version": "0.13.0"},
+ "b64ndarrays": [],
+ }
+ tsinfo = tvm.ir.load_json(json.dumps(data))
+ assert isinstance(tsinfo, relax.TensorStructInfo)
+ assert not tsinfo.vdevice
-class DummyGlobalInfo(GlobalInfo):
- def __init__(self) -> None:
- self.__init_handle_by_constructor__(
- _ffi_api.DummyGlobalInfo,
- )
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py
index 90608df4b6..e0904477d4 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -247,5 +247,35 @@ def test_op_call_pure_packed():
assert (copy_found.numpy() == arr).all()
+def test_op_to_device():
+ @tvm.script.ir_module
+ class CallToDevice:
+ @R.function
+ def to_dev(x: R.Tensor((3, 4), "float32")):
+ z = R.call_pure_packed(
+ "vm.builtin.to_device",
+ x,
+ 1,
+ 0,
+ sinfo_args=(R.Tensor((3, 4), dtype="float32")),
+ )
+ return z
+
+ np.random.seed(0) # to avoid flakiness
+ arr = np.random.rand(3, 4).astype("float32")
+ copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr))
+ assert (copy_found.numpy() == arr).all()
+
+
+def test_op_to_vdevice():
+ @tvm.script.ir_module
+ class ToVDevice:
+ @R.function
+ def to_vdev(x: R.Tensor((3, 4), "float32")):
+ dst_vdev = tvm.ir.VDevice("llvm", 0, "global")
+ ret = R.to_vdevice(x, dst_vdev)
+ return ret
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index bc324fe364..39a4d33ca6 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -22,7 +22,7 @@ import tvm
import tvm.script
import tvm.testing
from tvm import IRModule, relax, tir, topi
-from tvm.ir import DummyGlobalInfo
+from tvm.ir import VDevice, DummyGlobalInfo
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T
@@ -303,6 +303,56 @@ def test_module_with_attr_and_global_info():
_check(TestModule, mod)
+def test_global_info_vdevice():
+ vdevices = [
+ VDevice("llvm"),
+ VDevice("cuda", 0),
+ VDevice("cuda -arch=sm_80", 0),
+ VDevice("metal", 0, "global"),
+ ]
+
+ @I.ir_module
+ class TestModule:
+ I.module_attrs({"attr": 10})
+ I.module_global_infos(
+ {
+ "vdevice": [
+ I.vdevice("llvm"),
+ I.vdevice("cuda", 0),
+ I.vdevice("cuda -arch=sm_80", 0),
+ I.vdevice("metal", 0, "global"),
+ ]
+ }
+ )
+
+ @T.prim_func(private=True)
+ def tir_func(
+ x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ for i, j in T.grid(T.int64(128), T.int64(128)):
+ with T.block():
+ vi, vj = T.axis.remap("SS", [i, j])
+ y[vi, vj] = x[vi, vj] + 1.0
+
+ @R.function
+ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128),
"float32"):
+ cls = TestModule
+            gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), dtype="float32"))
+ return gv0
+
+ x = relax.Var("x", R.Tensor((128, 128), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", (x,)):
+ out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
+ bb.emit_func_output(out)
+ mod = bb.get()
+ mod.update_global_info("vdevice", vdevices)
+ mod = mod.with_attr("attr", tvm.tir.IntImm("int32", 10))
+ _check(TestModule, mod)
+
+
def test_relax_tensor_op():
@R.function
def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):
@@ -714,6 +764,51 @@ def test_tensor_type_without_args():
_check(foo, bb.get()["foo"])
+def test_tensor_with_vdevice():
+ vdevices = [
+ VDevice("llvm"),
+ VDevice("cuda", 0),
+ VDevice("metal", 0, "global"),
+ VDevice("cuda -arch=sm_80", 0),
+ ]
+
+ @I.ir_module
+ class TestModule:
+ I.module_attrs({"attr": 10})
+ I.module_global_infos(
+ {
+ "vdevice": [
+ I.vdevice("llvm"),
+ I.vdevice("cuda", 0),
+ I.vdevice("metal", 0, "global"),
+ I.vdevice("cuda -arch=sm_80", 0),
+ ]
+ }
+ )
+
+ @R.function
+ def foo(
+ a: R.Tensor((128, 128), "float32", "cuda:1"), # noqa: F722
+ b: R.Tensor((128, 128), "float32", "llvm"),
+ c: R.Tensor((128, 128), "float32", "vdevice:3"), # noqa: F722
+ ) -> R.Tensor((128, 128), "float32"):
+ s = R.add(a, c)
+ return s
+
+ a = relax.Var("a", R.Tensor((128, 128), "float32", vdevices[3]))
+ b = relax.Var("b", R.Tensor((128, 128), "float32", vdevices[0]))
+ c = relax.Var("c", R.Tensor((128, 128), "float32", vdevices[3]))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", (a, b, c)):
+ out = bb.emit(relax.op.add(a, c))
+ bb.emit_func_output(out)
+ mod = bb.get()
+ mod = mod.with_attr("attr", tvm.tir.IntImm("int32", 10))
+ mod.update_global_info("vdevice", vdevices)
+
+ _check(TestModule, mod)
+
+
def test_direct_return():
@R.function
def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index a80e8aad37..25f4f08520 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -403,5 +403,20 @@ def test_flip():
_check(foo, bb.get()["foo"])
+def test_to_vdevice():
+ @R.function
+ def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ tensor = R.to_vdevice(x, tvm.ir.VDevice("llvm", 0, "global"))
+ return tensor
+
+ x = relax.Var("x", R.Tensor((), "int32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", (x,)):
+ tensor = bb.emit(relax.op.to_vdevice(x, tvm.ir.VDevice("llvm", 0,
"global")))
+ bb.emit_func_output(tensor)
+
+ _check(foo, bb.get()["foo"])
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index 0b34d24540..085b6137ac 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -725,6 +725,41 @@ def test_recursion(exec_mode):
    tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), rtol=1e-7, atol=1e-7)
[email protected]_gpu
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_to_device(exec_mode):
+ @tvm.script.ir_module
+ class TestToVDevice:
+ @R.function
+ def foo1(
+ x: R.Tensor((2, 3), "float32"),
+ ) -> R.Tensor((2, 3), "float32"):
+ copied = R.to_vdevice(x, tvm.ir.VDevice("cuda", 0, "global"))
+ return copied
+
+ @R.function
+ def foo2(
+ x: R.Tensor((2, 3), "float32"),
+ ) -> R.Tensor((2, 3), "float32"):
+ copied = R.to_vdevice(x, tvm.ir.VDevice("llvm", 0, "global"))
+ return copied
+
+ mod = TestToVDevice
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = relax.build(mod, target, exec_mode=exec_mode)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+ res_1 = check_saved_func(vm, "foo1", x_inp)
+ res_2 = check_saved_func(vm, "foo2", x_inp)
+
+ # check the copied tensor's device
+ assert str(res_1.device) == "cuda(0)"
+ assert str(res_2.device) == "cpu(0)"
+
+ tvm.testing.assert_allclose(res_1.numpy(), x_inp.numpy())
+ tvm.testing.assert_allclose(res_2.numpy(), x_inp.numpy())
+
+
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_vm_closure(exec_mode):
@tvm.script.ir_module
diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py
index ffa9837d02..d9fb130f3c 100644
--- a/tests/python/relax/test_vm_codegen_only.py
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -57,6 +57,33 @@ def test_vm_copy(exec_mode):
tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_to_device(exec_mode):
+ @tvm.script.ir_module
+ class TestVMToDevice:
+ @R.function
+ def foo(x: R.Tensor((3, 4), "float32")):
+ R.func_attr({"global_symbol": "foo"})
+ # Copy x to the first cpu: device_type=1 and device_id=0.
+            # For more device info, please take a look at python/tvm/_ffi/runtime_ctypes.py
+ z = R.call_packed(
+ "vm.builtin.to_device", x, 1, 0, sinfo_args=(R.Tensor((3, 4),
dtype="float32"))
+ )
+ return z
+
+ mod = TestVMToDevice
+ target = tvm.target.Target("llvm", host="llvm")
+ ex = codegen(mod, target, exec_mode)
+ inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ res = check_saved_func(vm, "foo", inp)
+ tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
+ # check the resulting tensor is on cpu:0
+ assert str(res.device) == "cpu(0)"
+ assert res.device.device_type == 1
+ assert res.device.device_id == 0
+
+
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_if_cond_const(exec_mode):
@tvm.script.ir_module