This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 78884ccd58 [Unity] Fix StructInfo Infer for `vm.alloc_tensor` (#14283)
78884ccd58 is described below

commit 78884ccd5845a7ff4eaaa43e194252070d64e442
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Mar 13 23:38:00 2023 +0800

    [Unity] Fix StructInfo Infer for `vm.alloc_tensor` (#14283)
    
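    A hotfix for the StructInfo deduction of `vm.alloc_tensor`: the output shape is now read from `args[2]` (the shape argument) instead of `args[1]` (the offset), and the unused `InferShapeVMAllocTensor` is removed.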
---
 src/relax/op/op.cc                 | 4 +---
 tests/python/relax/test_op_misc.py | 8 ++++++++
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index cf084d6d20..a603040394 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -444,15 +444,13 @@ TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStora
 
 // vm alloc_tensor
 
-Expr InferShapeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) { return call->args[1]; }
-
 StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) {
   DataType out_dtype;
   if (const auto* dtype_node = call->args[3].as<DataTypeImmNode>()) {
     const DataTypeImm dtype_imm = GetRef<DataTypeImm>(dtype_node);
     out_dtype = dtype_imm->value;
   }
-  if (const auto* output_shape = call->args[1].as<ShapeExprNode>()) {
+  if (const auto* output_shape = call->args[2].as<ShapeExprNode>()) {
     return TensorStructInfo(GetRef<Expr>(output_shape), out_dtype);
   }
   return TensorStructInfo(out_dtype, kUnknownNDim);
diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py
index fd23911533..a10a1b5fe9 100644
--- a/tests/python/relax/test_op_misc.py
+++ b/tests/python/relax/test_op_misc.py
@@ -95,5 +95,13 @@ def test_implicit_op():
     assert isinstance(x[1][0], rx.TupleGetItem)
 
 
+def test_vm_alloc_tensor():
+    bb = rx.BlockBuilder()
+    storage = rx.Var("storage", rx.TensorStructInfo(dtype="float32"))
+    alloc = rx.op.vm.alloc_tensor(storage, offset=0, shape=rx.ShapeExpr([4, 5]), dtype="float32")
+    alloc = bb.normalize(alloc)
+    tvm.ir.assert_structural_equal(alloc.struct_info, R.Tensor([4, 5], "float32"))
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to