This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 3f1215349e [Unity] ensure memory.alloc_tensor/storage roundtrippable (#14226)
3f1215349e is described below

commit 3f1215349eeed3e065a8da2bd479ef6d05b33133
Author: Yong Wu <[email protected]>
AuthorDate: Tue Mar 7 12:20:50 2023 -0800

    [Unity] ensure memory.alloc_tensor/storage roundtrippable (#14226)
    
    - make R.memory.alloc_tensor/alloc_storage roundtrippable
    - Expose relax.vm ops to python side
---
 python/tvm/relax/op/memory/memory.py               | 34 +++++++---
 python/tvm/relax/op/vm/__init__.py                 | 20 ++++++
 python/tvm/relax/op/vm/_ffi_api.py                 | 19 ++++++
 python/tvm/relax/op/{memory/memory.py => vm/vm.py} | 73 +++++++++++-----------
 python/tvm/relay/op/op_attrs.py                    | 10 ---
 python/tvm/script/ir_builder/relax/ir.py           |  5 +-
 src/relax/op/op.cc                                 | 39 ++++++------
 tests/python/relax/test_tvmscript_parser.py        | 27 ++++++--
 8 files changed, 145 insertions(+), 82 deletions(-)

diff --git a/python/tvm/relax/op/memory/memory.py 
b/python/tvm/relax/op/memory/memory.py
index b58b987d2a..7b84ffc48b 100644
--- a/python/tvm/relax/op/memory/memory.py
+++ b/python/tvm/relax/op/memory/memory.py
@@ -15,13 +15,19 @@
 # specific language governing permissions and limitations
 """Relax memory primitives."""
 
+from typing import Union
 from . import _ffi_api
-from ...expr import Expr, Call
+from ...expr import Expr, Call, PrimValue, DataTypeImm, StringImm
 from ...utils import args_converter
 
 
 @args_converter.auto
-def alloc_storage(size: Expr, virtual_device_index: int, storage_scope: str, 
dtype: str) -> Call:
+def alloc_storage(
+    size: Expr,
+    virtual_device_index: Union[int, Expr],
+    storage_scope: Union[str, Expr],
+    dtype: Union[str, Expr],
+) -> Call:
     """Construct a Call to allocate a storage with specific size, 
virtual_device_index,
     storage_scope and dtype.
 
@@ -30,14 +36,14 @@ def alloc_storage(size: Expr, virtual_device_index: int, 
storage_scope: str, dty
     size : Expr
         The size of the storage to be allocated.
 
-    virtual_device_index : int
+    virtual_device_index : Union[int, Expr]
         The virtual device index indicating on which device the storage is to 
be allocated.
         Index -1 is reserved for the host device.
 
-    storage_scope : str
+    storage_scope : Union[str, Expr]
         The storage scope to allocate the storage to.
 
-    dtype : str
+    dtype : Union[str, Expr]
         The datatype of the storage to be allocated.
 
     Returns
@@ -45,11 +51,19 @@ def alloc_storage(size: Expr, virtual_device_index: int, 
storage_scope: str, dty
     result : Call
         A relax Call, which gets the allocated storage.
     """
+    if isinstance(dtype, str):
+        dtype = DataTypeImm(dtype)
+    if isinstance(storage_scope, str):
+        storage_scope = StringImm(storage_scope)
+    if isinstance(virtual_device_index, int):
+        virtual_device_index = PrimValue(virtual_device_index)
     return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, 
dtype)  # type: ignore
 
 
 @args_converter.auto
-def alloc_tensor(storage: Expr, offset: int, shape: Expr, dtype: str) -> Call:
+def alloc_tensor(
+    storage: Expr, offset: Union[int, Expr], shape: Expr, dtype: Union[str, 
Expr]
+) -> Call:
     """Construct a Call to allocate a tensor on a certain storage starting 
from the given offset.
 
     Parameters
@@ -57,13 +71,13 @@ def alloc_tensor(storage: Expr, offset: int, shape: Expr, 
dtype: str) -> Call:
     storage : Expr
         The storage to allocate the tensor to.
 
-    offset : int
+    offset : Union[int, Expr]
         The storage offset to allocate the tensor.
 
     shape : Expr
         The shape of the tensor to be allocated.
 
-    dtype : str
+    dtype : Union[str, Expr]
         The datatype of the tensor to be allocated.
 
     Returns
@@ -71,6 +85,10 @@ def alloc_tensor(storage: Expr, offset: int, shape: Expr, 
dtype: str) -> Call:
     result : Call
         A relax Call, which gets the allocated tensor.
     """
+    if isinstance(offset, int):
+        offset = PrimValue(offset)
+    if isinstance(dtype, str):
+        dtype = DataTypeImm(dtype)
     return _ffi_api.alloc_tensor(storage, offset, shape, dtype)  # type: ignore
 
 
diff --git a/python/tvm/relax/op/vm/__init__.py 
b/python/tvm/relax/op/vm/__init__.py
new file mode 100644
index 0000000000..ecb2857a89
--- /dev/null
+++ b/python/tvm/relax/op/vm/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin
+"""Relax vm primitives."""
+
+from .vm import *
diff --git a/python/tvm/relax/op/vm/_ffi_api.py 
b/python/tvm/relax/op/vm/_ffi_api.py
new file mode 100644
index 0000000000..786b73c76c
--- /dev/null
+++ b/python/tvm/relax/op/vm/_ffi_api.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+"""FFI APIs for tvm.relax.op.vm"""
+import tvm._ffi
+
+tvm._ffi._init_api("relax.op.vm", __name__)
diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/vm/vm.py
similarity index 58%
copy from python/tvm/relax/op/memory/memory.py
copy to python/tvm/relax/op/vm/vm.py
index b58b987d2a..89d31b6581 100644
--- a/python/tvm/relax/op/memory/memory.py
+++ b/python/tvm/relax/op/vm/vm.py
@@ -13,31 +13,33 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-"""Relax memory primitives."""
+"""Relax vm primitives."""
 
+from typing import Union
 from . import _ffi_api
-from ...expr import Expr, Call
+from ...expr import Expr, Call, PrimValue, DataTypeImm, Tuple
 from ...utils import args_converter
 
 
 @args_converter.auto
-def alloc_storage(size: Expr, virtual_device_index: int, storage_scope: str, 
dtype: str) -> Call:
-    """Construct a Call to allocate a storage with specific size, 
virtual_device_index,
-    storage_scope and dtype.
+def alloc_storage(
+    size: Expr,
+    runtime_device_index: Union[int, Expr],
+    dtype: Union[str, Expr],
+) -> Call:
+    """Construct a Call to allocate a storage with specific size,
+    runtime_device_index, and dtype.
 
     Parameters
     ----------
     size : Expr
         The size of the storage to be allocated.
 
-    virtual_device_index : int
-        The virtual device index indicating on which device the storage is to 
be allocated.
-        Index -1 is reserved for the host device.
+    runtime_device_index : Union[int, Expr]
+        The device index indicating on which device the storage is to
+        be allocated at runtime. Index -1 is reserved for the host device.
 
-    storage_scope : str
-        The storage scope to allocate the storage to.
-
-    dtype : str
+    dtype : Union[str, Expr]
         The datatype of the storage to be allocated.
 
     Returns
@@ -45,11 +47,17 @@ def alloc_storage(size: Expr, virtual_device_index: int, 
storage_scope: str, dty
     result : Call
         A relax Call, which gets the allocated storage.
     """
-    return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, 
dtype)  # type: ignore
+    if isinstance(dtype, str):
+        dtype = DataTypeImm(dtype)
+    if isinstance(runtime_device_index, int):
+        runtime_device_index = PrimValue(runtime_device_index)
+    return _ffi_api.alloc_storage(size, runtime_device_index, dtype)  # type: 
ignore
 
 
 @args_converter.auto
-def alloc_tensor(storage: Expr, offset: int, shape: Expr, dtype: str) -> Call:
+def alloc_tensor(
+    storage: Expr, offset: Union[int, Expr], shape: Expr, dtype: Union[str, 
Expr]
+) -> Call:
     """Construct a Call to allocate a tensor on a certain storage starting 
from the given offset.
 
     Parameters
@@ -57,13 +65,13 @@ def alloc_tensor(storage: Expr, offset: int, shape: Expr, 
dtype: str) -> Call:
     storage : Expr
         The storage to allocate the tensor to.
 
-    offset : int
+    offset : Union[int, Expr]
         The storage offset to allocate the tensor.
 
     shape : Expr
         The shape of the tensor to be allocated.
 
-    dtype : str
+    dtype : Union[str, Expr]
         The datatype of the tensor to be allocated.
 
     Returns
@@ -71,38 +79,31 @@ def alloc_tensor(storage: Expr, offset: int, shape: Expr, 
dtype: str) -> Call:
     result : Call
         A relax Call, which gets the allocated tensor.
     """
+    if isinstance(offset, int):
+        offset = PrimValue(offset)
+    if isinstance(dtype, str):
+        dtype = DataTypeImm(dtype)
     return _ffi_api.alloc_tensor(storage, offset, shape, dtype)  # type: ignore
 
 
 @args_converter.auto
-def kill_storage(storage: Expr) -> Call:
+def call_tir_dyn(func: Expr, args: Tuple) -> Call:
-    """Construct a Call to kill a storage.
+    """Construct a Call to call_tir_dyn, which invokes ``func`` with ``args``.
 
     Parameters
     ----------
-    storage : Expr
+    func : Expr
-        The storage to be killed.
+        The function to be called.
 
-    Returns
-    -------
-    result : Call
-        A relax Call to kill a storage.
-    """
-    return _ffi_api.kill_storage(storage)  # type: ignore
-
-
-@args_converter.auto
-def kill_tensor(tensor: Expr) -> Call:
-    """Construct a Call to kill a tensor.
-
-    Parameters
-    ----------
-    tensor : Expr
-        The tensor to be killed.
+    args : Tuple
+        The input args, includes a list of tensors, and a ShapeExpr.
 
     Returns
     -------
     result : Call
-        A relax Call to kill a tensor.
+        A relax Call to call_tir_dyn.
     """
-    return _ffi_api.kill_tensor(tensor)  # type: ignore
+    if isinstance(args, (list, tuple)):
+        args = Tuple(args)
+
+    return _ffi_api.call_tir_dyn(func, args)  # type: ignore
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 4e9a9a4707..deae9e2f48 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -439,16 +439,6 @@ class AffineGridAttrs(Attrs):
     """Attributes used in affine_grid operators"""
 
 
-@tvm._ffi.register_object("relay.attrs.AllocStorageAttrs")
-class AllocStorageAttrs(Attrs):
-    """Attributes used in alloc_storage operators"""
-
-
-@tvm._ffi.register_object("relay.attrs.AllocTensorAttrs")
-class AllocTensorAttrs(Attrs):
-    """Attributes used in alloc_tensor operators"""
-
-
 @tvm._ffi.register_object("relay.attrs.CastHintAttrs")
 class CastHintAttrs(Attrs):
     """Attributes used in cast_hint annotation operators"""
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 14ef36307a..03f1c1db46 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, 
Callable
 import tvm
 from tvm import DataType, relax
 from tvm.ir import PrimExpr
-from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, VarBinding, 
const
+from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, ShapeExpr, Var, 
VarBinding, const
 from tvm.relax.block_builder import BlockBuilder as rx_bb
 
 ############################### Operators ###############################
@@ -113,6 +113,7 @@ from tvm.relax.op import (
     tril,
     triu,
     unique,
+    vm,
     where,
     zeros,
     zeros_like,
@@ -598,6 +599,7 @@ __all__ = [
     "round",
     "shape",
     "shape_of",
+    "ShapeExpr",
     "std",
     "str",
     "strided_slice",
@@ -621,6 +623,7 @@ __all__ = [
     "tuple",
     "unique",
     "variance",
+    "vm",
     "where",
     "zeros",
     "zeros_like",
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 21d692b6a4..cf6343b2fa 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -296,21 +296,18 @@ RELAY_REGISTER_OP("relax.memory.alloc_storage")
     .set_num_inputs(4)
     .add_argument("total_space", "Expr", "The total space of the storage to 
allocate.")
     .add_argument(
-        "virtual_device_index", "int64_t",
+        "virtual_device_index", "PrimValue",
         "The virtual device index indicating on which device the storage is to 
be allocated, "
         "Index -1 is reserved for the host device.")
-    .add_argument("storage_scope", "string",
+    .add_argument("storage_scope", "StringImm",
                   "The storage scope of the storage to allocate. Default is 
global.")
-    .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+    .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to 
allocate.")
     .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo);
 
-Expr MakeAllocStorage(Expr size, int64_t virtual_device_index, std::string 
storage_scope,
-                      DataType dtype) {
+Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm 
storage_scope,
+                      DataTypeImm dtype) {
   static const Op& op = Op::Get("relax.memory.alloc_storage");
-  return Call(
-      op,
-      {size, PrimValue::Int64(virtual_device_index), StringImm(storage_scope), 
DataTypeImm(dtype)},
-      Attrs(), {});
+  return Call(op, {size, virtual_device_index, storage_scope, dtype}, Attrs(), 
{});
 }
 
 
TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage);
@@ -331,14 +328,14 @@ StructInfo InferStructInfoMemAllocTensor(const Call& 
call, const BlockBuilder& c
 RELAY_REGISTER_OP("relax.memory.alloc_tensor")
     .set_num_inputs(4)
     .add_argument("storage", "Expr", "The storage to allocate the tensor to.")
-    .add_argument("offset", "int", "Storage offset to allocate the tensor.")
+    .add_argument("offset", "PrimValue", "Storage offset to allocate the 
tensor.")
     .add_argument("shape", "Expr", "The shape of the tensor to allocate.")
-    .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+    .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to 
allocate.")
     .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoMemAllocTensor);
 
-Expr MakeMemAllocTensor(Expr storage, int offset, Expr shape, DataType dtype) {
+Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape, 
DataTypeImm dtype) {
   static const Op& op = Op::Get("relax.memory.alloc_tensor");
-  return Call(op, {storage, PrimValue::Int64(offset), shape, 
DataTypeImm(dtype)}, Attrs(), {});
+  return Call(op, {storage, offset, shape, dtype}, Attrs(), {});
 }
 
 
TVM_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor);
@@ -376,15 +373,15 @@ 
TVM_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTen
 RELAY_REGISTER_OP("relax.vm.alloc_storage")
     .set_num_inputs(3)
     .add_argument("size", "Expr", "The size of the storage to allocate.")
-    .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
-    .add_argument("runtime_device_index", "int64_t",
+    .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to 
allocate.")
+    .add_argument("runtime_device_index", "PrimValue",
                   "The device index indicating on which device the tensor is "
                   "to be allocated at runtime.")
     .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo);
 
-Expr MakeVMAllocStorage(Expr size, int64_t runtime_device_index, DataType 
dtype) {
+Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm 
dtype) {
   static const Op& op = Op::Get("relax.vm.alloc_storage");
-  return Call(op, {size, PrimValue::Int64(runtime_device_index), 
DataTypeImm(dtype)}, Attrs(), {});
+  return Call(op, {size, runtime_device_index, dtype}, Attrs(), {});
 }
 
 
TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage);
@@ -408,14 +405,14 @@ StructInfo InferStructInfoVMAllocTensor(const Call& call, 
const BlockBuilder& ct
 RELAY_REGISTER_OP("relax.vm.alloc_tensor")
     .set_num_inputs(4)
     .add_argument("storage", "Expr", "The storage to allocate the tensor to.")
-    .add_argument("offset", "int", "Storage offset to allocate the tensor.")
+    .add_argument("offset", "PrimValue", "Storage offset to allocate the 
tensor.")
     .add_argument("shape", "Expr", "The shape of the tensor to allocate.")
-    .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.")
+    .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to 
allocate.")
     .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoVMAllocTensor);
 
-Expr MakeVMAllocTensor(Expr storage, int offset, Expr shape, DataType dtype) {
+Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm 
dtype) {
   static const Op& op = Op::Get("relax.vm.alloc_tensor");
-  return Call(op, {storage, PrimValue::Int64(offset), shape, 
DataTypeImm(dtype)}, Attrs(), {});
+  return Call(op, {storage, offset, shape, dtype}, Attrs(), {});
 }
 
 
TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor);
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index b885697c73..b4bb517859 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1065,20 +1065,35 @@ def test_arith_operators():
     _check(foo, bb.get()["foo"])
 
 
-# TODO(relax-team): enable this when vm ops are ready
[email protected]
-def test_vm_ops():
+def test_memory_ops():
     @R.function
     def foo(x: R.Tensor(("m", "n"), dtype="float32")):
         m = T.int64()
         n = T.int64()
-        storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32", 
runtime_device_index=0)
-        alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0, 
dtype="float32")
+        storage = R.memory.alloc_storage(
+            R.shape([4 * m * n]), virtual_device_index=0, 
storage_scope="global", dtype="float32"
+        )
+        alloc = R.memory.alloc_tensor(storage, offset=0, shape=R.shape([m, 
n]), dtype="float32")
         tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", 
runtime_device_index=0)
-        _ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n)))
         gv = tensor
         return alloc, gv
 
+    _check(foo)
+
+
+def test_vm_ops():
+    @R.function
+    def foo(x: R.Tensor(("m", "n"), dtype="float32")):
+        m = T.int64()
+        n = T.int64()
+        storage = R.vm.alloc_storage(R.shape([4 * m * n]), 
runtime_device_index=0, dtype="float32")
+        alloc = R.vm.alloc_tensor(storage, offset=0, shape=R.shape([m, n]), 
dtype="float32")
+        tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", 
runtime_device_index=0)
+        tir_dym = R.vm.call_tir_dyn("te_func", (x, tensor, R.ShapeExpr((m, 
n))))
+        return alloc, tir_dym
+
+    _check(foo)
+
 
 def test_prim_value():
     @R.function

Reply via email to