This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f9b07ae9ed [Unity][VM] `kill_tensor` and `kill_storage` releasing 
NDArray in VM at runtime (#14753)
f9b07ae9ed is described below

commit f9b07ae9ed9be366e482ae30e9b4b27871792c66
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed May 3 09:55:33 2023 -0400

    [Unity][VM] `kill_tensor` and `kill_storage` releasing NDArray in VM at 
runtime (#14753)
    
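    This PR lowers the `memory.kill_tensor` and `memory.kill_storage` ops to
    a new VM op, `vm.kill_object`. Codegen turns each `vm.kill_object` into a
    call to the runtime builtin "`vm.builtin.null_value`", which overwrites
    the register holding the "tensor" or "storage" with null, so that the VM
    drops its reference to the corresponding NDArray at runtime.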
    
    Previously these two ops were simply ignored and skipped in
    VMBuiltinLower, so the "kills" did not take effect. This PR implements
    the "real kill": the VM now releases the NDArrays upon kill, which gives
    us better memory management.
---
 include/tvm/runtime/relax_vm/bytecode.h            |  4 +-
 python/tvm/relax/op/vm/vm.py                       | 17 ++++++
 src/relax/backend/vm/codegen_vm.cc                 | 16 +++++-
 src/relax/backend/vm/codegen_vm_tir.cc             | 22 ++++++++
 src/relax/backend/vm/vm_builtin_lower.cc           | 17 +++---
 src/relax/op/op.cc                                 | 14 +++++
 src/runtime/relax_vm/builtin.cc                    |  7 ++-
 src/runtime/relax_vm/vm.cc                         |  6 +-
 tests/python/relax/test_op_misc.py                 |  8 +++
 .../test_transform_static_plan_block_memory.py     | 59 ++++++++++++++++++++
 tests/python/relax/test_vm_codegen_only.py         | 65 ++++++++++++++++++++++
 11 files changed, 218 insertions(+), 17 deletions(-)

diff --git a/include/tvm/runtime/relax_vm/bytecode.h 
b/include/tvm/runtime/relax_vm/bytecode.h
index 91d1823258..fdafaac1e0 100644
--- a/include/tvm/runtime/relax_vm/bytecode.h
+++ b/include/tvm/runtime/relax_vm/bytecode.h
@@ -80,7 +80,7 @@ struct Instruction {
   static constexpr ExecWord kValueMaxLimit = (static_cast<ExecWord>(1) << 
(kValueBit - 1)) - 1;
   /*! \brief Minimum possible value, remove 1 slot to keep things symmetric. */
   static constexpr ExecWord kValueMinLimit = -kValueMaxLimit;
-  /*! \brief Begining of special register section. */
+  /*! \brief Beginning of special register section. */
   static constexpr RegName kBeginSpecialReg = static_cast<ExecWord>(1) << 54;
   /*! \brief Random magic number that represents void argument, indicate null 
value */
   static constexpr RegName kVoidRegister = kBeginSpecialReg + 0;
@@ -127,7 +127,7 @@ struct Instruction {
      */
     static Arg FuncIdx(Index index) { return Arg(ArgKind::kFuncIdx, index); }
     /*!
-     * \brief Get the kind of argument..
+     * \brief Get the kind of argument.
      * \return The kind of argument.
      */
     ArgKind kind() const {
diff --git a/python/tvm/relax/op/vm/vm.py b/python/tvm/relax/op/vm/vm.py
index a20407a4c9..fdb1d0f7d9 100644
--- a/python/tvm/relax/op/vm/vm.py
+++ b/python/tvm/relax/op/vm/vm.py
@@ -86,6 +86,23 @@ def alloc_tensor(
     return _ffi_api.alloc_tensor(storage, offset, shape, dtype)  # type: ignore
 
 
+def kill_object(obj: Expr) -> Call:
+    """Construct a Call to set the register corresponding to the input object 
to
+    null at runtime, in order to kill the input object.
+
+    Parameters
+    ----------
+    obj : Expr
+        The object to be killed.
+
+    Returns
+    -------
+    result : Call
+        CallNode that kills the input object.
+    """
+    return _ffi_api.kill_object(obj)  # type: ignore
+
+
 @args_converter.auto
 def call_tir_dyn(func: Expr, args: Tuple) -> Call:
     """Construct a Call to call_tir_dyn (invoke the given TIR PrimFunc)
diff --git a/src/relax/backend/vm/codegen_vm.cc 
b/src/relax/backend/vm/codegen_vm.cc
index b36b5ed4d6..09f21cf751 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -163,6 +163,8 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const 
Expr&)> {
         EmitAllocStorage(call, dst_reg);
       } else if (call_node->op == alloc_tensor_op_) {
         EmitAllocTensor(call, dst_reg);
+      } else if (call_node->op == kill_object_op_) {
+        dst_reg = EmitKillObject(call);
       } else {
         // every "normal" operator is lowered to a global var in the IRModule. 
The Attrs for those
         // ops are handled in a pass when lowering them to TIR.
@@ -352,6 +354,15 @@ class CodeGenVM : public 
ExprFunctor<Instruction::Arg(const Expr&)> {
     builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_reg);
   }
 
+  RegName EmitKillObject(const Call& call_node) {
+    ICHECK_EQ(call_node->args.size(), 1);
+    Instruction::Arg arg = this->VisitExpr(call_node->args[0]);
+    ICHECK(arg.kind() == Instruction::ArgKind::kRegister);
+    RegName dst_reg = arg.value();
+    builder_->EmitCall("vm.builtin.null_value", {}, dst_reg);
+    return dst_reg;
+  }
+
   void EmitCallBuiltinWithCtx(const Call& call_node, RegName dst_reg) {
     std::vector<Instruction::Arg> args;
     args.push_back(Instruction::Arg::Register(Instruction::kVMRegister));
@@ -401,6 +412,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const 
Expr&)> {
   /*! \brief Cache ops that need to be frequently used later to reduce lookup 
overhead. */
   const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
   const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
+  const Op& kill_object_op_ = Op::Get("relax.vm.kill_object");
   const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
   const Op& null_value_op_ = Op::Get("relax.null_value");
 };
@@ -410,7 +422,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const 
Expr&)> {
  *        and add them to exec_builder.
  * \param exec_builder Builder to collect executables.
  * \param mod Input module.
- * \return Left over IRModule that may contain otehr functions.
+ * \return Left over IRModule that may contain other functions.
  */
 IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) {
   return CodeGenVM::Run(exec_builder, mod);
@@ -419,7 +431,7 @@ IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) {
 TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen);
 
 /*!
- * \brief Link the libaries together.
+ * \brief Link the libraries together.
  */
 Module VMLink(ExecBuilder builder, Target target, Optional<Module> lib, 
Array<Module> ext_libs,
               Map<String, runtime::NDArray> params) {
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc 
b/src/relax/backend/vm/codegen_vm_tir.cc
index 2f63a50d37..276632a917 100644
--- a/src/relax/backend/vm/codegen_vm_tir.cc
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -230,6 +230,8 @@ class CodeGenVMTIR : public 
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
         EmitAllocStorage(call, dst_reg);
       } else if (call_node->op == alloc_tensor_op_) {
         EmitAllocTensor(call, dst_reg);
+      } else if (call_node->op == kill_object_op_) {
+        dst_reg = EmitKillObject(call);
       } else {
         // every "normal" operator is lowered to a global var in the IRModule. 
The Attrs for those
         // ops are handled in a pass when lowering them to TIR.
@@ -404,6 +406,25 @@ class CodeGenVMTIR : public 
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
     this->EmitCallPacked("vm.builtin.alloc_tensor", args, dst_reg);
   }
 
+  int64_t EmitKillObject(const Call& call_node) {
+    ICHECK_EQ(call_node->args.size(), 1);
+    PrimExpr arg = this->VisitExpr(call_node->args[0]).value();
+
+    // Check the arg is a register.
+    const auto* tir_call = arg.as<tir::CallNode>();
+    ICHECK(tir_call != nullptr);
+    ICHECK(tir_call->op == tir::builtin::anylist_getitem());
+    ICHECK(tir_call->args.size() == 2);
+    ICHECK(tir_call->args[0].same_as(reg_anylist_handle_));
+    const auto* p_dst_reg = tir_call->args[1].as<tir::IntImmNode>();
+    ICHECK(p_dst_reg != nullptr);
+    ICHECK(p_dst_reg->dtype == DataType::Int(32));
+
+    int64_t dst_reg = p_dst_reg->value;
+    this->EmitCallPacked("vm.builtin.null_value", {}, dst_reg);
+    return dst_reg;
+  }
+
   void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) {
     Array<PrimExpr> args;
     // if context is required, pass as first argument.
@@ -488,6 +509,7 @@ class CodeGenVMTIR : public 
ExprFunctor<Optional<PrimExpr>(const Expr&)> {
   /*! \brief Cache ops that need to be frequently used later to reduce lookup 
overhead. */
   const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
   const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
+  const Op& kill_object_op_ = Op::Get("relax.vm.kill_object");
   const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
   const Op& null_value_op_ = Op::Get("relax.null_value");
 };
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc 
b/src/relax/backend/vm/vm_builtin_lower.cc
index 5bf4194997..ad791424f6 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -36,15 +36,6 @@ class VMBuiltinLowerMutator : public ExprMutator {
  public:
   using ExprMutator::VisitExpr_;
 
-  // A workaround to remove the CallNodes of killing tensors and storages.
-  void VisitBinding_(const VarBindingNode* binding) final {
-    const auto* call = binding->value.as<CallNode>();
-    if (call != nullptr && (call->op == mem_kill_storage_op_ || call->op == 
mem_kill_tensor_op_)) {
-      return;
-    }
-    ExprMutator::VisitBinding_(binding);
-  }
-
   Expr VisitExpr_(const CallNode* call_node) final {
     // post-order mutation
     Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
@@ -65,6 +56,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
       return MakeMemAllocStorage(call);
     } else if (call->op == mem_alloc_tensor_op_) {
       return MakeMemAllocTensor(call);
+    } else if (call->op == mem_kill_storage_op_ || call->op == 
mem_kill_tensor_op_) {
+      return MakeMemKillObject(call);
     } else {
       return call;
     }
@@ -112,6 +105,11 @@ class VMBuiltinLowerMutator : public ExprMutator {
     return Call(vm_alloc_tensor_op_, {call->args[0], offset, call->args[2], 
dtype}, Attrs());
   }
 
+  Expr MakeMemKillObject(const Call& call) {
+    ICHECK_EQ(call->args.size(), 1);
+    return Call(vm_kill_object_op_, {call->args[0]}, Attrs());
+  }
+
   Expr CallTIRDyn(const Call& call_node) {
     ICHECK(call_node->args.size() == 2);
     ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
@@ -206,6 +204,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
   // functions to lower to
   const Op& vm_alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
   const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
+  const Op& vm_kill_object_op_ = Op::Get("relax.vm.kill_object");
   // Function to compute allocated shape.
   const ExternFunc 
builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
   const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 9d331e41dd..f2106f1550 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -526,6 +526,20 @@ Expr MakeVMAllocTensor(Expr storage, PrimValue offset, 
Expr shape, DataTypeImm d
 
 
TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor);
 
+// vm kill_object
+
+TVM_REGISTER_OP("relax.vm.kill_object")
+    .set_num_inputs(1)
+    .add_argument("obj", "Expr", "The object to be killed.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo);
+
+Expr MakeVMKillObject(Expr obj) {
+  static const Op& op = Op::Get("relax.vm.kill_object");
+  return Call(op, {std::move(obj)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.vm.kill_object").set_body_typed(MakeVMKillObject);
+
 // vm call_tir_dyn
 
 RELAY_REGISTER_OP("relax.vm.call_tir_dyn")
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index af0963bf41..24550c83a6 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -53,7 +53,7 @@ NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) {
   VirtualMachine* vm = static_cast<VirtualMachine*>(ctx_ptr);
   // use host allocator, which is always last element.
   size_t host_device_index = vm->devices.size() - 1;
-  // specialy handle hexagon on-device RT.
+  // specially handle hexagon on-device RT.
   // TODO(relax-team): visit and consider other possible choices.
   if (vm->devices[0].device_type == kDLHexagon) {
     host_device_index = 0;
@@ -325,6 +325,11 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.reshape").set_body_typed([](NDArray data, ShapeT
   return data.CreateView(new_shape, data->dtype);
 });
 
+TVM_REGISTER_GLOBAL("vm.builtin.null_value").set_body([](TVMArgs args, 
TVMRetValue* rv) {
+  CHECK_EQ(args.size(), 0);
+  *rv = nullptr;
+});
+
 /*!
  * \brief Load the scalar value in cond and return the result value.
  * \param cond The condition
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 2e6c341213..01497fdd7c 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -327,7 +327,7 @@ class VirtualMachineImpl : public VirtualMachine {
     } else {
       ICHECK_EQ(reg, Instruction::kVMRegister);
       // per convention, ctx ptr must be VirtualMachine* casted to void.
-      // this and VirtualMachine* may or maynot be the same
+      // this and VirtualMachine* may or may not be the same
       // do first cast to VirtualMachine* then to void*
       ret = static_cast<void*>(static_cast<VirtualMachine*>(this));
     }
@@ -870,7 +870,7 @@ void VirtualMachineImpl::RunLoop() {
   VMFrame* curr_frame = frames_.back().get();
 
   while (true) {
-    ICHECK_LT(static_cast<size_t>(pc_), exec_->instr_offset.size()) << "run 
into invalide section";
+    ICHECK_LT(static_cast<size_t>(pc_), exec_->instr_offset.size()) << "run 
into invalid section";
     Instruction instr = exec_->GetInstruction(pc_);
     switch (instr.op) {
       case Opcode::Call: {
@@ -1005,7 +1005,7 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
       std::unordered_map<std::string, ObjectRef> metrics;
       metrics["Argument Shapes"] = profiling::ShapeString(arrs);
 
-      // If a sutiable device is found, enable profiling.
+      // If a suitable device is found, enable profiling.
       if (dev) {
         profiling = true;
         prof_->StartCall(f_name, *dev, metrics);
diff --git a/tests/python/relax/test_op_misc.py 
b/tests/python/relax/test_op_misc.py
index 87c2a58a82..2eeaa15208 100644
--- a/tests/python/relax/test_op_misc.py
+++ b/tests/python/relax/test_op_misc.py
@@ -112,6 +112,14 @@ def test_vm_alloc_tensor_infer_struct_info():
     tvm.ir.assert_structural_equal(ret.struct_info, R.Tensor(dtype="float32", 
ndim=3))
 
 
+def test_vm_kill_object():
+    bb = rx.BlockBuilder()
+    storage = rx.Var("storage", rx.TensorStructInfo(dtype="float32"))
+    kill = rx.op.vm.kill_object(storage)
+    ret = bb.normalize(kill)
+    tvm.ir.assert_structural_equal(ret.struct_info, R.Tuple([]))
+
+
 def test_builtin_stop_lift_params():
     bb = rx.BlockBuilder()
     x = rx.Var("x", rx.TensorStructInfo(shape=[4, 5], dtype="float32"))
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py 
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 85036db347..e669f012f7 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -125,10 +125,69 @@ def test_basic():
             _11: R.Tuple() = R.memory.kill_storage(storage)
             _10: R.Tuple() = R.memory.kill_storage(storage1)
             return gv5
+
+    @I.ir_module
+    class ExpectedLowered:
+        @T.prim_func
+        def add(rxplaceholder: T.Buffer((T.int64(8),), "float32"), 
rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer((T.int64(8),), 
"float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), 
compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.Buffer((T.int64(10),), "float32"), compute: 
T.Buffer((T.int64(10),), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.Buffer((T.int64(8),), "float32"), PadInput: 
T.Buffer((T.int64(10),), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.Buffer((T.int64(8),), "float32"), compute: 
T.Buffer((T.int64(8),), "float32")):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), 
"float32"), T_reshape: T.Buffer((T.int64(8),), "float32")):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            cls = ExpectedLowered
+            storage: R.Object = R.vm.alloc_storage(R.shape([32]), 
R.prim_value(0), R.dtype("float32"))
+            alloc: R.Tensor((2, 4), dtype="float32") = 
R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((2, 4), dtype="float32") = alloc
+            lv1: R.Tensor((8,), dtype="float32") = 
R.call_packed("vm.builtin.reshape", lv, R.shape([8]), 
sinfo_args=(R.Tensor((8,), dtype="float32"),))
+            storage1: R.Object = R.vm.alloc_storage(R.shape([40]), 
R.prim_value(0), R.dtype("float32"))
+            alloc1: R.Tensor((8,), dtype="float32") = 
R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape([8]), R.dtype("float32"))
+            _1: R.Tuple = cls.relu(lv1, alloc1)
+            __1: R.Tuple = R.vm.kill_object(alloc)
+            _1_1: R.Tuple = R.vm.kill_object(lv1)
+            lv2: R.Tensor((8,), dtype="float32") = alloc1
+            alloc2: R.Tensor((8,), dtype="float32") = 
R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([8]), R.dtype("float32"))
+            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
+            _2_1: R.Tuple = R.vm.kill_object(alloc1)
+            lv3: R.Tensor((8,), dtype="float32") = alloc2
+            alloc3: R.Tensor((10,), dtype="float32") = 
R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape([10]), R.dtype("float32"))
+            _3: R.Tuple = cls.pad(lv3, alloc3)
+            _3_1: R.Tuple = R.vm.kill_object(alloc2)
+            lv4: R.Tensor((10,), dtype="float32") = alloc3
+            storage_1: R.Object = R.vm.alloc_storage(R.shape([40]), 
R.prim_value(0), R.dtype("float32"))
+            alloc4: R.Tensor((10,), dtype="float32") = 
R.vm.alloc_tensor(storage_1, R.prim_value(0), R.shape([10]), R.dtype("float32"))
+            _4: R.Tuple = cls.log(lv4, alloc4)
+            _4_1: R.Tuple = R.vm.kill_object(alloc3)
+            gv: R.Tensor((10,), dtype="float32") = alloc4
+            _5: R.Tuple = R.vm.kill_object(storage)
+            _6: R.Tuple = R.vm.kill_object(storage1)
+            return gv
     # fmt: on
 
     mod = relax.transform.StaticPlanBlockMemory()(Module)
     tvm.ir.assert_structural_equal(mod, Expected)
+    mod = relax.transform.VMBuiltinLower()(mod)
+    tvm.ir.assert_structural_equal(mod, ExpectedLowered)
 
 
 def test_different_dtype():
diff --git a/tests/python/relax/test_vm_codegen_only.py 
b/tests/python/relax/test_vm_codegen_only.py
index b9904429f3..d3a047b62b 100644
--- a/tests/python/relax/test_vm_codegen_only.py
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -25,6 +25,7 @@ import tvm.testing
 from tvm import relax
 from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode
 from tvm.relax.testing.vm import check_saved_func
+from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
 
@@ -331,5 +332,69 @@ def test_vm_builtin_reshape(exec_mode):
     tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7)
 
 
[email protected]("exec_mode", EXEC_MODE)
+def test_vm_kill_object(exec_mode):
+    @I.ir_module
+    class TestKillObject:
+        @T.prim_func
+        def full(T_full: T.Buffer((T.int64(4),), "float32")):
+            T.func_attr({"global_symbol": "full", "tir.noalias": T.bool(True)})
+            for ax0 in range(T.int64(4)):
+                with T.block("T_full"):
+                    v_ax0 = T.axis.spatial(T.int64(4), ax0)
+                    T.reads()
+                    T.writes(T_full[v_ax0])
+                    T_full[v_ax0] = T.float32(0)
+
+        @T.prim_func
+        def full1(T_full: T.Buffer((T.int64(4),), "float32")):
+            T.func_attr({"global_symbol": "full1", "tir.noalias": 
T.bool(True)})
+            for ax0 in range(T.int64(4)):
+                with T.block("T_full"):
+                    v_ax0 = T.axis.spatial(T.int64(4), ax0)
+                    T.reads()
+                    T.writes(T_full[v_ax0])
+                    T_full[v_ax0] = T.float32(1)
+
+        @R.function
+        def main() -> R.Tensor((4,), dtype="float32"):
+            R.func_attr({"global_symbol": "main"})
+            cls = TestKillObject
+            storage: R.Object = R.vm.alloc_storage(
+                R.shape([16]), R.prim_value(0), R.dtype("float32")
+            )
+            alloc: R.Tensor((4,), dtype="float32") = R.vm.alloc_tensor(
+                storage, R.prim_value(0), R.shape([4]), R.dtype("float32")
+            )
+            _: R.Tuple = cls.full(alloc)
+            __1: R.Tuple = R.vm.kill_object(alloc)
+            x: R.Tensor((4,), dtype="float32") = alloc
+            alloc1: R.Tensor((4,), dtype="float32") = R.vm.alloc_tensor(
+                storage, R.prim_value(0), R.shape([4]), R.dtype("float32")
+            )
+            _1: R.Tuple = cls.full(alloc1)
+            _1_1: R.Tuple = R.vm.kill_object(alloc1)
+            y: R.Tensor((4,), dtype="float32") = alloc1
+            storage_1: R.Object = R.vm.alloc_storage(
+                R.shape([16]), R.prim_value(0), R.dtype("float32")
+            )
+            alloc2: R.Tensor((4,), dtype="float32") = R.vm.alloc_tensor(
+                storage_1, R.prim_value(0), R.shape([4]), R.dtype("float32")
+            )
+            _2: R.Tuple = cls.full1(alloc2)
+            z: R.Tensor((4,), dtype="float32") = alloc2
+            _2_1: R.Tuple = R.vm.kill_object(storage)
+            return z
+
+    mod = TestKillObject
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = codegen(mod, target, exec_mode)
+    dev = tvm.cpu()
+    vm = relax.VirtualMachine(ex, dev)
+
+    res = vm["main"]()
+    tvm.testing.assert_allclose(res.numpy(), np.ones((4,), "float32"))
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to