This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e571dc9262 [Unity] Add memory scope and nd allocation support in allocators (#15178)
e571dc9262 is described below

commit e571dc92627862f0274a38b4b53171802e027613
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Sat Jul 1 19:51:43 2023 +0530

    [Unity] Add memory scope and nd allocation support in allocators (#15178)
    
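    This patch adds nd-allocation support (the "NDAllocator" interface proposed
    in [this discussion](https://discuss.tvm.apache.org/t/memory-scope-for-vm-alloc-storage-builtins/15172))
    to the Relax VM allocators: a new `Allocator::Alloc(shape, dtype, mem_scope)`
    overload, which `vm.builtin.alloc_storage` now calls with the storage shape
    and memory scope.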
    
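    The naive allocator (`memory_cfg="naive"`) overrides this overload, so it can
    allocate storage in memory scopes other than "global" (e.g. "global.vtcm" on
    Hexagon) and allocate nd (e.g. 2-D) buffers. The default implementation used
    by other allocator types still supports only 1-D allocation in the "global"
    scope and reports an error otherwise.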
---
 include/tvm/runtime/relax_vm/memory_manager.h      |  7 ++
 python/tvm/relax/op/vm/vm.py                       | 16 ++--
 python/tvm/runtime/relax_vm.py                     |  1 +
 src/relax/backend/vm/codegen_vm.cc                 |  2 +-
 src/relax/backend/vm/vm_builtin_lower.cc           | 14 ++--
 src/relax/op/op.cc                                 |  9 ++-
 src/relax/op/op_common.h                           |  3 +-
 src/runtime/relax_vm/builtin.cc                    | 11 +--
 src/runtime/relax_vm/memory_manager.cc             | 19 +++++
 src/runtime/relax_vm/naive_allocator.h             | 21 +++++
 .../test_relax_2d_buffer_allocation.py             | 91 ++++++++++++++++++++++
 .../test_transform_static_plan_block_memory.py     |  6 +-
 tests/python/relax/test_tvmscript_parser.py        |  2 +-
 .../relax/test_vm_alloc_storage_with_scope.py      | 74 ++++++++++++++++++
 tests/python/relax/test_vm_codegen_only.py         |  6 +-
 tests/python/relax/test_vm_cuda_graph.py           |  6 +-
 tests/python/relax/test_vm_execbuilder.py          |  8 +-
 17 files changed, 262 insertions(+), 34 deletions(-)

diff --git a/include/tvm/runtime/relax_vm/memory_manager.h 
b/include/tvm/runtime/relax_vm/memory_manager.h
index 55952de3f8..ed939fb88f 100644
--- a/include/tvm/runtime/relax_vm/memory_manager.h
+++ b/include/tvm/runtime/relax_vm/memory_manager.h
@@ -71,6 +71,13 @@ class Allocator {
    *  \return A sized allocation in the form of a buffer.
    */
   virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) 
= 0;
+  /*! \brief Allocate a buffer given a size, alignment and type.
+   *  \param shape The shape of allocated tensor.
+   *  \param dtype A type hint to the allocator.
+   *  \param mem_scope The memory scope of allocated tensor.
+   *  \return A sized allocation in the form of a buffer.
+   */
+  virtual Buffer Alloc(ShapeTuple shape, DLDataType dtype, String mem_scope);
   /*! \brief Free a buffer allocated by the allocator.
    *  \param buffer The buffer to free.
    */
diff --git a/python/tvm/relax/op/vm/vm.py b/python/tvm/relax/op/vm/vm.py
index fdb1d0f7d9..3ed6b29648 100644
--- a/python/tvm/relax/op/vm/vm.py
+++ b/python/tvm/relax/op/vm/vm.py
@@ -17,23 +17,24 @@
 
 from typing import Union
 from . import _ffi_api
-from ...expr import Expr, Call, PrimValue, DataTypeImm, Tuple
+from ...expr import Expr, Call, PrimValue, DataTypeImm, Tuple, StringImm
 from ...utils import args_converter
 
 
 @args_converter.auto
 def alloc_storage(
-    size: Expr,
+    shape: Expr,
     runtime_device_index: Union[int, Expr],
     dtype: Union[str, Expr],
+    storage_scope: Union[str, StringImm] = "global",
 ) -> Call:
     """Construct a Call to allocate a storage with specific size,
     runtime_device_index, and dtype.
 
     Parameters
     ----------
-    size : Expr
-        The size of the storage to be allocated.
+    shape : Expr
+        The shape of the storage to be allocated.
 
     runtime_device_index : Union[int, Expr]
         The device index indicating on which device the tensor is to
@@ -42,6 +43,9 @@ def alloc_storage(
     dtype : Union[str, Expr]
         The datatype of the storage to be allocated.
 
+    storage_scope : Union[str, StringImm]
+        The storage scope of the storage to allocate. Default is global.
+
     Returns
     -------
     result : Call
@@ -49,9 +53,11 @@ def alloc_storage(
     """
     if isinstance(dtype, str):
         dtype = DataTypeImm(dtype)
+    if isinstance(storage_scope, str):
+        storage_scope = StringImm(storage_scope)
     if isinstance(runtime_device_index, int):
         runtime_device_index = PrimValue(runtime_device_index)
-    return _ffi_api.alloc_storage(size, runtime_device_index, dtype)  # type: 
ignore
+    return _ffi_api.alloc_storage(shape, runtime_device_index, dtype, 
storage_scope)  # type: ignore
 
 
 @args_converter.auto
diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py
index d5ea93f988..3856a1a1c9 100644
--- a/python/tvm/runtime/relax_vm.py
+++ b/python/tvm/runtime/relax_vm.py
@@ -326,6 +326,7 @@ class VirtualMachine(object):
             If the result is a tuple, it returns a list of the fields.
             The fields are potentially also tuples, so these can be arbitrily 
nested.
         """
+
         # to deal with potentially nested tuples, we need to query for arity 
recursively
         def get_output_rec(func_name, *idx):
             arity = self._get_output_arity(func_name, *idx)
diff --git a/src/relax/backend/vm/codegen_vm.cc 
b/src/relax/backend/vm/codegen_vm.cc
index 3fbe246cd3..711b2d4ba5 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -333,7 +333,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const 
Expr&)> {
   }
 
   void EmitAllocStorage(const Call& call_node, RegName dst_reg) {
-    ICHECK_EQ(call_node->args.size(), 3);
+    ICHECK_EQ(call_node->args.size(), 4);
     // Handle args of the call
     std::vector<Instruction::Arg> args;
     args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc 
b/src/relax/backend/vm/vm_builtin_lower.cc
index ad791424f6..6087c2bb25 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -85,9 +85,11 @@ class VMBuiltinLowerMutator : public ExprMutator {
     DataType dtype = output_dtype->value;
     Expr storage_size = ComputeStorageSize(output_shape, dtype);
     PrimValue runtime_device_index = Downcast<PrimValue>(call->args[2]);
-    Var storage = builder_->Emit(
-        Call(vm_alloc_storage_op_, {storage_size, runtime_device_index, 
output_dtype}, Attrs()),
-        "storage");
+    Var storage = builder_->Emit(Call(vm_alloc_storage_op_,
+                                      {storage_size, runtime_device_index,
+                                       DataTypeImm(DataType::UInt(8)), 
StringImm("global")},
+                                      Attrs()),
+                                 "storage");
     Expr shape = call->args[0];
     PrimValue offset = PrimValue::Int64(0);
     return Call(vm_alloc_tensor_op_, {storage, offset, shape, 
DataTypeImm(dtype)}, Attrs());
@@ -95,8 +97,10 @@ class VMBuiltinLowerMutator : public ExprMutator {
 
   Expr MakeMemAllocStorage(const Call& call) {
     PrimValue runtime_device_index = Downcast<PrimValue>(call->args[1]);
-    DataTypeImm output_dtype = Downcast<DataTypeImm>(call->args[3]);
-    return Call(vm_alloc_storage_op_, {call->args[0], runtime_device_index, 
output_dtype}, Attrs());
+    StringImm storage_scope = Downcast<StringImm>(call->args[2]);
+    DataTypeImm output_dtype = DataTypeImm(DataType::UInt(8));
+    return Call(vm_alloc_storage_op_,
+                {call->args[0], runtime_device_index, output_dtype, 
storage_scope}, Attrs());
   }
 
   Expr MakeMemAllocTensor(const Call& call) {
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index f1fb5c52bd..3e3cbd6b7d 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -564,19 +564,22 @@ 
TVM_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTen
 // vm alloc_storage
 
 RELAY_REGISTER_OP("relax.vm.alloc_storage")
-    .set_num_inputs(3)
+    .set_num_inputs(4)
     .add_argument("size", "Expr", "The size of the storage to allocate.")
     .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to 
allocate.")
     .add_argument("runtime_device_index", "PrimValue",
                   "The device index indicating on which device the tensor is "
                   "to be allocated at runtime.")
+    .add_argument("storage_scope", "StringImm",
+                  "The storage scope of the storage to allocate. Default is 
global.")
     .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo)
     // memory allocation isn't considered a "visible effect" as far as purity 
is concerned
     .set_attr<Bool>("FPurity", Bool(true));
 
-Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm 
dtype) {
+Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm 
dtype,
+                        StringImm storage_scope) {
   static const Op& op = Op::Get("relax.vm.alloc_storage");
-  return Call(op, {size, runtime_device_index, dtype}, Attrs(), {});
+  return Call(op, {size, runtime_device_index, dtype, storage_scope}, Attrs(), 
{});
 }
 
 
TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage);
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index a6b437111b..2ada5672b6 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -347,7 +347,8 @@ inline Optional<ShapeExpr> 
CheckNdimPerLayoutAndGetShape(const Call& call, const
   return NullOpt;
 }
 
-Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm 
dtype);
+Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm 
dtype,
+                        StringImm storage_scope = StringImm("global"));
 Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm 
dtype);
 
 }  // namespace relax
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 24550c83a6..86f0152ce7 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -243,12 +243,10 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.check_func_info").set_body_typed(CheckFuncInfo);
 //-------------------------------------------------
 //  Storage management.
 //-------------------------------------------------
-Storage VMAllocStorage(void* ctx_ptr, ShapeTuple buffer_size, Index 
device_index,
-                       DLDataType dtype_hint) {
+Storage VMAllocStorage(void* ctx_ptr, ShapeTuple buffer_shape, Index 
device_index,
+                       DLDataType dtype_hint, String mem_scope) {
   VirtualMachine* vm = static_cast<VirtualMachine*>(ctx_ptr);
 
-  ICHECK_EQ(buffer_size.size(), 1);
-  int alignment = runtime::kAllocAlignment;
   ICHECK_LT(device_index, vm->devices.size())
       << "The device index is out of VM physical devices list";
 
@@ -257,12 +255,11 @@ Storage VMAllocStorage(void* ctx_ptr, ShapeTuple 
buffer_size, Index device_index
     device_index = vm->devices.size() - 1;
   }
 
-  int64_t size_imm = buffer_size[0];
-
   auto storage_obj = runtime::SimpleObjAllocator().make_object<StorageObj>();
   auto* alloc = vm->allocators[device_index];
   ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?";
-  storage_obj->buffer = alloc->Alloc(size_imm, alignment, dtype_hint);
+
+  storage_obj->buffer = alloc->Alloc(buffer_shape, dtype_hint, mem_scope);
   Storage storage(storage_obj);
   return storage;
 }
diff --git a/src/runtime/relax_vm/memory_manager.cc 
b/src/runtime/relax_vm/memory_manager.cc
index 2391bdc284..04ea3afdee 100644
--- a/src/runtime/relax_vm/memory_manager.cc
+++ b/src/runtime/relax_vm/memory_manager.cc
@@ -176,6 +176,25 @@ void MemoryManager::Clear() {
   m->allocators_.clear();
 }
 
+Buffer Allocator::Alloc(ShapeTuple shape, DLDataType dtype, String mem_scope) {
+  ICHECK_EQ(shape.size(), 1) << "Allocator of type (" << type_
+                             << ") does not support nD allocation. Please use 
allocator type ("
+                             << AllocatorType::kNaive << ")";
+  CHECK_EQ(mem_scope, "global") << "Allocator of type (" << type_
+                                << ") does not support memory scope " << 
mem_scope
+                                << ". Please use allocator type (" << 
AllocatorType::kNaive << ")";
+
+  DLTensor temp;
+  temp.ndim = shape.size();
+  temp.dtype = dtype;
+  temp.shape = const_cast<int64_t*>(shape.data());
+  temp.strides = nullptr;
+  temp.byte_offset = 0;
+  size_t nbytes = GetDataSize(temp);
+
+  return Alloc(nbytes, runtime::kAllocAlignment, dtype);
+}
+
 runtime::NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice 
dev) {
   VerifyDataType(dtype);
   runtime::NDArray::Container* container =
diff --git a/src/runtime/relax_vm/naive_allocator.h 
b/src/runtime/relax_vm/naive_allocator.h
index 843a559602..dde4a22066 100644
--- a/src/runtime/relax_vm/naive_allocator.h
+++ b/src/runtime/relax_vm/naive_allocator.h
@@ -47,6 +47,27 @@ class NaiveAllocator final : public Allocator {
     return buf;
   }
 
+  Buffer Alloc(ShapeTuple shape, DLDataType dtype, String mem_scope) override {
+    DLTensor temp;
+    temp.data = nullptr;
+    temp.device = device_;
+    temp.ndim = shape.size();
+    temp.dtype = dtype;
+    temp.shape = const_cast<int64_t*>(shape.data());
+    temp.strides = nullptr;
+    temp.byte_offset = 0;
+    size_t nbytes = GetDataSize(temp);
+
+    Buffer buf;
+    buf.device = device_;
+    buf.size = nbytes;
+    buf.data = runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, 
shape.size(), shape.data(),
+                                                                dtype, 
mem_scope);
+    used_memory_.fetch_add(nbytes, std::memory_order_relaxed);
+    DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ 
<< " B";
+    return buf;
+  }
+
   void Free(const Buffer& buffer) override {
     runtime::DeviceAPI::Get(device_)->FreeDataSpace(buffer.device, 
buffer.data);
     used_memory_.fetch_sub(buffer.size, std::memory_order_relaxed);
diff --git 
a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py 
b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
new file mode 100644
index 0000000000..40de28cca0
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relax hexagon 2d VTCM allocation test."""
+
+import numpy as np
+
+import tvm
+import tvm.contrib.hexagon
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+# pylint: disable=missing-docstring,no-self-argument,invalid-name
[email protected]_module
+class Module:
+    @T.prim_func
+    def add(
+        arg0: T.Buffer((2, 2), "float32"),
+        arg1: T.Buffer((2, 2), "float32"),
+        output: T.Buffer((2, 2), "float32"),
+    ):
+        T.func_attr({"operator_name": "relax.add"})
+        for ax0 in range(2):
+            for ax1 in range(2):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(2, ax0)
+                    v_ax1 = T.axis.spatial(2, ax1)
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, 
v_ax1]
+
+    @R.function
+    def main(x: R.Tensor((2, 2), dtype="float32")):
+        cls = Module
+        # Try allocating 2d storage (2,2) in global.vtcm scope with nd 
allocator
+        storage = R.vm.alloc_storage(
+            R.shape([2, 2]), runtime_device_index=0, dtype="float32", 
storage_scope="global.vtcm"
+        )
+        alloc = R.vm.alloc_tensor(storage, offset=0, shape=R.shape([2, 2]), 
dtype="float32")
+        _: R.Tuple = cls.add(x, x, alloc)
+        out: R.Tensor((2, 2), dtype="float32") = alloc
+        storage2 = R.vm.alloc_storage(R.shape([4 * 2 * 2]), 
runtime_device_index=0, dtype="uint8")
+        alloc2 = R.vm.alloc_tensor(storage2, offset=0, shape=R.shape([2, 2]), 
dtype="float32")
+        _1: R.Tuple = cls.add(out, x, alloc2)
+        out2: R.Tensor((2, 2), dtype="float32") = alloc2
+        return out2
+
+
+# pylint: enable=missing-docstring,no-self-argument,invalid-name
+def test_alloc_storage_with_scope_global(hexagon_launcher):
+    """
+    Test 2d allocation to global.vtcm memory scope in a Relax Function
+    """
+    arg0 = np.random.uniform(size=(2, 2)).astype(np.float32)
+
+    output_ref = arg0 + arg0 + arg0
+
+    mod = Module
+
+    target_hexagon = tvm.target.hexagon("v69", vtcm_capacity=4 * 2**20)
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+    with tvm.transform.PassContext(opt_level=3):
+        lib = relax.build(mod, target, exec_mode="compiled")
+
+    with hexagon_launcher.create_session() as session:
+        dev = session.device
+        vm_mod = session.get_executor_from_factory(lib)
+        # This is the important line which tests nd allocator
+        vm_rt = relax.VirtualMachine(vm_mod, dev, memory_cfg="naive")
+        x = tvm.nd.array(arg0, dev)
+        vm_rt.set_input("main", x)
+        vm_rt.invoke_stateful("main")
+        hexagon_output = vm_rt.get_outputs("main").numpy()
+    tvm.testing.assert_allclose(output_ref, hexagon_output)
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py 
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 0f59278a8e..147fc049c9 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -159,12 +159,12 @@ def test_basic():
         def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
             R.func_attr({"relax.force_pure": True})
             cls = ExpectedLowered
-            storage: R.Object = R.vm.alloc_storage(R.shape([32]), 
R.prim_value(0), R.dtype("float32"))
+            storage: R.Object = R.vm.alloc_storage(R.shape([32]), 
R.prim_value(0), R.dtype("uint8"))
             alloc: R.Tensor((2, 4), dtype="float32") = 
R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
             _: R.Tuple = cls.exp(x, alloc)
             lv: R.Tensor((2, 4), dtype="float32") = alloc
             lv1: R.Tensor((8,), dtype="float32") = 
R.call_packed("vm.builtin.reshape", lv, R.shape([8]), 
sinfo_args=(R.Tensor((8,), dtype="float32"),))
-            storage1: R.Object = R.vm.alloc_storage(R.shape([40]), 
R.prim_value(0), R.dtype("float32"))
+            storage1: R.Object = R.vm.alloc_storage(R.shape([40]), 
R.prim_value(0), R.dtype("uint8"))
             alloc1: R.Tensor((8,), dtype="float32") = 
R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape([8]), R.dtype("float32"))
             _1: R.Tuple = cls.relu(lv1, alloc1)
             __1: R.Tuple = R.vm.kill_object(alloc)
@@ -178,7 +178,7 @@ def test_basic():
             _3: R.Tuple = cls.pad(lv3, alloc3)
             _3_1: R.Tuple = R.vm.kill_object(alloc2)
             lv4: R.Tensor((10,), dtype="float32") = alloc3
-            storage_1: R.Object = R.vm.alloc_storage(R.shape([40]), 
R.prim_value(0), R.dtype("float32"))
+            storage_1: R.Object = R.vm.alloc_storage(R.shape([40]), 
R.prim_value(0), R.dtype("uint8"))
             alloc4: R.Tensor((10,), dtype="float32") = 
R.vm.alloc_tensor(storage_1, R.prim_value(0), R.shape([10]), R.dtype("float32"))
             _4: R.Tuple = cls.log(lv4, alloc4)
             _4_1: R.Tuple = R.vm.kill_object(alloc3)
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index 564fe04692..c9aa16b9b2 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1208,7 +1208,7 @@ def test_vm_ops():
     def foo(x: R.Tensor(("m", "n"), dtype="float32")):
         m = T.int64()
         n = T.int64()
-        storage = R.vm.alloc_storage(R.shape([4 * m * n]), 
runtime_device_index=0, dtype="float32")
+        storage = R.vm.alloc_storage(R.shape([4 * m * n]), 
runtime_device_index=0, dtype="uint8")
         alloc = R.vm.alloc_tensor(storage, offset=0, shape=R.shape([m, n]), 
dtype="float32")
         tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", 
runtime_device_index=0)
         tir_dym = R.vm.call_tir_dyn("te_func", (x, tensor, R.ShapeExpr((m, 
n))))
diff --git a/tests/python/relax/test_vm_alloc_storage_with_scope.py 
b/tests/python/relax/test_vm_alloc_storage_with_scope.py
new file mode 100644
index 0000000000..ca1802b1f5
--- /dev/null
+++ b/tests/python/relax/test_vm_alloc_storage_with_scope.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test Naive allocator with memory scope for Relax VM"""
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
[email protected]_module
+class Module:
+    @T.prim_func
+    def add(
+        arg0: T.Buffer((2, 2), "float32"),
+        arg1: T.Buffer((2, 2), "float32"),
+        output: T.Buffer((2, 2), "float32"),
+    ):
+        T.func_attr({"operator_name": "relax.add"})
+        for ax0 in range(2):
+            for ax1 in range(2):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(2, ax0)
+                    v_ax1 = T.axis.spatial(2, ax1)
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, 
v_ax1]
+
+    @R.function
+    def main(x: R.Tensor((2, 2), dtype="float32")):
+        cls = Module
+        storage = R.vm.alloc_storage(
+            R.shape([2 * 2]), runtime_device_index=0, dtype="float32", 
storage_scope="global"
+        )
+        alloc = R.vm.alloc_tensor(storage, offset=0, shape=R.shape([2, 2]), 
dtype="float32")
+        _: R.Tuple = cls.add(x, x, alloc)
+        out: R.Tensor((2, 2), dtype="float32") = alloc
+        return out
+
+
+def test_alloc_storage_with_scope_global():
+    arg0 = np.random.uniform(size=(2, 2)).astype(np.float32)
+    output_ref = arg0 + arg0
+    mod = Module
+    target = "llvm"
+    with tvm.transform.PassContext(opt_level=3):
+        lib = relax.build(mod, target, exec_mode="compiled")
+
+    dev = tvm.cpu()
+    # This is the important line which tests nd allocator
+    vm_rt = relax.VirtualMachine(lib, dev, memory_cfg="naive")
+    x = tvm.nd.array(arg0, dev)
+    vm_rt.set_input("main", x)
+    vm_rt.invoke_stateful("main")
+    output = vm_rt.get_outputs("main").numpy()
+    tvm.testing.assert_allclose(output_ref, output)
diff --git a/tests/python/relax/test_vm_codegen_only.py 
b/tests/python/relax/test_vm_codegen_only.py
index d3a047b62b..ffa9837d02 100644
--- a/tests/python/relax/test_vm_codegen_only.py
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -360,9 +360,7 @@ def test_vm_kill_object(exec_mode):
         def main() -> R.Tensor((4,), dtype="float32"):
             R.func_attr({"global_symbol": "main"})
             cls = TestKillObject
-            storage: R.Object = R.vm.alloc_storage(
-                R.shape([16]), R.prim_value(0), R.dtype("float32")
-            )
+            storage: R.Object = R.vm.alloc_storage(R.shape([16]), 
R.prim_value(0), R.dtype("uint8"))
             alloc: R.Tensor((4,), dtype="float32") = R.vm.alloc_tensor(
                 storage, R.prim_value(0), R.shape([4]), R.dtype("float32")
             )
@@ -376,7 +374,7 @@ def test_vm_kill_object(exec_mode):
             _1_1: R.Tuple = R.vm.kill_object(alloc1)
             y: R.Tensor((4,), dtype="float32") = alloc1
             storage_1: R.Object = R.vm.alloc_storage(
-                R.shape([16]), R.prim_value(0), R.dtype("float32")
+                R.shape([16]), R.prim_value(0), R.dtype("uint8")
             )
             alloc2: R.Tensor((4,), dtype="float32") = R.vm.alloc_tensor(
                 storage_1, R.prim_value(0), R.shape([4]), R.dtype("float32")
diff --git a/tests/python/relax/test_vm_cuda_graph.py 
b/tests/python/relax/test_vm_cuda_graph.py
index bd4b3fe90f..8406b9df15 100644
--- a/tests/python/relax/test_vm_cuda_graph.py
+++ b/tests/python/relax/test_vm_cuda_graph.py
@@ -38,7 +38,7 @@ class Module:
         storage1: R.Object = gv[1]
         gv1: R.Tuple(R.Tensor(dtype="float32"), R.Object, R.Object) = (alloc, 
storage1, storage)
         gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) = 
R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", 
(cls.cuda_graph_capture, gv1, R.prim_value(0)), 
sinfo_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),))
-        storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)), 
R.prim_value(0), R.dtype("float32"))
+        storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)), 
R.prim_value(0), R.dtype("uint8"))
         alloc3: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage2, 
R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
         lv4: R.Tensor((16, 16), dtype="float32") = gv2[0]
         _3: R.Tuple = cls.add(lv4, alloc3)
@@ -58,8 +58,8 @@ class Module:
     @R.function
     def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
         R.func_attr({"global_symbol": "cuda_graph_alloc"})
-        storage: R.Object = R.vm.alloc_storage(R.shape((1024,)), 
R.prim_value(0), R.dtype("float32"))
-        storage1: R.Object = R.vm.alloc_storage(R.shape((1024,)), 
R.prim_value(0), R.dtype("float32"))
+        storage: R.Object = R.vm.alloc_storage(R.shape((1024,)), 
R.prim_value(0), R.dtype("uint8"))
+        storage1: R.Object = R.vm.alloc_storage(R.shape((1024,)), 
R.prim_value(0), R.dtype("uint8"))
         gv: R.Tuple(R.Object, R.Object) = (storage, storage1)
         return gv
 
diff --git a/tests/python/relax/test_vm_execbuilder.py 
b/tests/python/relax/test_vm_execbuilder.py
index 9a7cd0c879..5d9491dad7 100644
--- a/tests/python/relax/test_vm_execbuilder.py
+++ b/tests/python/relax/test_vm_execbuilder.py
@@ -169,7 +169,13 @@ def test_vm_storage():
     with ib.function("main", num_inputs=0):
         ib.emit_call(
             "vm.builtin.alloc_storage",
-            args=[ib.vm_state(), (24,), ib.convert_constant(0), dtype],
+            args=[
+                ib.vm_state(),
+                (24,),
+                ib.convert_constant(0),
+                dtype,
+                ib.convert_constant("global"),
+            ],
             dst=ib.r(1),
         )
         ib.emit_call(

Reply via email to