This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 78884ccd58 [Unity] Fix StructInfo Infer for `vm.alloc_tensor` (#14283)
78884ccd58 is described below
commit 78884ccd5845a7ff4eaaa43e194252070d64e442
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Mar 13 23:38:00 2023 +0800
[Unity] Fix StructInfo Infer for `vm.alloc_tensor` (#14283)
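A hotfix for the StructInfo deduction of `vm.alloc_tensor`: the output shape is now read from `call->args[2]` (the `shape` argument) instead of `call->args[1]` (the `offset` argument), and the stale `InferShapeVMAllocTensor` helper, which also returned `args[1]`, is removed.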
---
src/relax/op/op.cc | 4 +---
tests/python/relax/test_op_misc.py | 8 ++++++++
2 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index cf084d6d20..a603040394 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -444,15 +444,13 @@ TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStora
// vm alloc_tensor
-Expr InferShapeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) { return call->args[1]; }
-
StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) {
DataType out_dtype;
if (const auto* dtype_node = call->args[3].as<DataTypeImmNode>()) {
const DataTypeImm dtype_imm = GetRef<DataTypeImm>(dtype_node);
out_dtype = dtype_imm->value;
}
- if (const auto* output_shape = call->args[1].as<ShapeExprNode>()) {
+ if (const auto* output_shape = call->args[2].as<ShapeExprNode>()) {
return TensorStructInfo(GetRef<Expr>(output_shape), out_dtype);
}
return TensorStructInfo(out_dtype, kUnknownNDim);
diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py
index fd23911533..a10a1b5fe9 100644
--- a/tests/python/relax/test_op_misc.py
+++ b/tests/python/relax/test_op_misc.py
@@ -95,5 +95,13 @@ def test_implicit_op():
assert isinstance(x[1][0], rx.TupleGetItem)
+def test_vm_alloc_tensor():
+ bb = rx.BlockBuilder()
+ storage = rx.Var("storage", rx.TensorStructInfo(dtype="float32"))
+ alloc = rx.op.vm.alloc_tensor(storage, offset=0, shape=rx.ShapeExpr([4, 5]), dtype="float32")
+ alloc = bb.normalize(alloc)
+ tvm.ir.assert_structural_equal(alloc.struct_info, R.Tensor([4, 5], "float32"))
+
+
if __name__ == "__main__":
tvm.testing.main()