This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 36e3c121b7 [Relax] Validate StructInfo annotations in well-formed
check (#17331)
36e3c121b7 is described below
commit 36e3c121b7dcfae3d5d5098186a7ca96e7ff27fc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Sep 19 12:28:25 2024 -0500
[Relax] Validate StructInfo annotations in well-formed check (#17331)
* [Relax] Validate StructInfo annotations in well-formed check
Prior to this commit, the Relax well-formed checker verified that each
expression had a non-null `StructInfo` annotation, but did not perform
any validation on the contents of the `StructInfo` annotation.
This commit updates the Relax well-formed check to verify that the
`StructInfo` annotations are accurate by comparing against the
`StructInfo` that would be inferred for an expression. (This only
requires that the information is accurate, not that it is complete.
For example, an expression that is inferred to be
`R.Tensor(shape=[128,8], dtype="float32")` may have annotation of
`R.Tensor(ndim=2, dtype="float32")`, but may not have an annotation of
`R.Tensor(shape=[128,8], dtype="int32")`.)
* lint fix
* lint fix
---
src/relax/analysis/well_formed.cc | 43 +++++++++++
src/relax/op/op.cc | 21 ++++--
tests/python/relax/test_analysis_well_formed.py | 85 ++++++++++++++++++++++
tests/python/relax/test_ast_printer.py | 4 +-
tests/python/relax/test_frontend_from_fx.py | 10 +--
tests/python/relax/test_transform_decompose_ops.py | 4 +-
.../relax/test_transform_ipc_allreduce_rewrite.py | 4 +-
.../relax/test_transform_legalize_ops_ccl.py | 4 +-
.../test_transform_legalize_ops_create_datatype.py | 34 ++++-----
..._transform_legalize_ops_index_linear_algebra.py | 2 +-
.../test_transform_legalize_ops_manipulate.py | 51 +++++++------
.../python/relax/test_transform_legalize_ops_nn.py | 38 ++++++----
...st_transform_legalize_ops_search_statistical.py | 4 +-
.../python/relax/test_transform_realize_vdevice.py | 16 ++--
.../test_transform_static_plan_block_memory.py | 8 +-
.../relax/test_transform_to_mixed_precision.py | 12 +--
tests/python/relax/test_tvmscript_parser.py | 10 +--
tests/python/relax/test_vm_cuda_graph.py | 8 +-
tests/python/relax/test_vm_multi_device.py | 14 ++--
19 files changed, 268 insertions(+), 104 deletions(-)
diff --git a/src/relax/analysis/well_formed.cc
b/src/relax/analysis/well_formed.cc
index 7688c4a642..7873d5ce20 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -362,6 +362,49 @@ class WellFormedChecker : public relax::ExprVisitor,
<< err.what());
}
}
+
+ if (check_struct_info_ && call->struct_info_.defined()) {
+ // The `InferStructInfo` method isn't currently exposed by the
+ // Normalizer, and can only be called indirectly by normalizing
+ // an expression that does not yet have `StructInfo`.
+ auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
+ Call copied(call->op, call->args, call->attrs, call->sinfo_args);
+ Optional<Expr> normalized = NullOpt;
+ try {
+ normalized = dummy_builder->Normalize(copied);
+ } catch (std::exception& err) {
+ Malformed(Diagnostic::Error(call)
+ << "Each Relax expression must be able to have its
StructInfo inferred. "
+ << "However, inferring the struct info of expression " <<
GetRef<Call>(call)
+ << " resulted in the error: \n"
+ << err.what());
+ }
+ if (normalized.defined()) {
+ auto inferred_struct_info = GetStructInfo(normalized.value());
+ auto current_struct_info = Downcast<StructInfo>(call->struct_info_);
+
+ // An error should be raised if the annotated StructInfo is
+ // provably incorrect. This check is done using
+ // `StructInfoBaseCheck(...) < kFailL1`, because `kFailL1`
+ // represents cases that are neither provably correct nor
+ // provably incorrect. If this check were replaced with
+ // `!IsBaseOf(...)`, cases that are correct but not provably
+ // so would raise an exception.
+ //
+ // For example, if a dynamic size in the inferred StructInfo
+ // is equivalent to the expression used in the annotated
+ // StructInfo, but the TIR simplifications are not sufficient
+ // to prove that the two expressions are equivalent, we should
+ // not raise an error.
+ if (StructInfoBaseCheck(current_struct_info, inferred_struct_info) <
+ BaseCheckResult::kFailL1) {
+ Malformed(Diagnostic::Error(call)
+ << "All information in StructInfo annotations must be
correct. "
+ << "However, while the expression " << GetRef<Call>(call)
<< " is annotated as "
+ << current_struct_info << ", the expression outputs " <<
inferred_struct_info);
+ }
+ }
+ }
}
void VisitExpr_(const IfNode* op) final {
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 3e0f0eba31..a7d97a59a1 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -1021,14 +1021,19 @@ StructInfo ReturnTensorToShapeStructInfo(const Call&
call, const BlockBuilder& c
ICHECK(call->args.size() == 1);
ICHECK(call->args[0]->struct_info_.defined());
const auto* tsinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
- ICHECK(tsinfo && tsinfo->shape.defined());
- ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
- ICHECK(shape_expr->values.size() == 1) << "relax.tensor_to_shape expected
argument to be 1-d, "
- << "but " << call << " has argument "
<< call->args[0]
- << " with struct info " <<
call->args[0]->struct_info_;
- const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
- ICHECK(ndim);
- return ShapeStructInfo(ndim->value);
+ ICHECK(tsinfo);
+ ICHECK_EQ(tsinfo->ndim, 1) << "relax.tensor_to_shape expected argument to be
1-d, "
+ << "but " << call << " has argument " <<
call->args[0]
+ << " with struct info " <<
call->args[0]->struct_info_;
+
+ if (tsinfo->shape.defined()) {
+ ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
+ const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
+ if (ndim) {
+ return ShapeStructInfo(ndim->value);
+ }
+ }
+ return ShapeStructInfo(kUnknownNDim);
}
RELAY_REGISTER_OP("relax.tensor_to_shape")
diff --git a/tests/python/relax/test_analysis_well_formed.py
b/tests/python/relax/test_analysis_well_formed.py
index 3db3efee1a..d9eefcfd0e 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -1295,5 +1295,90 @@ def
test_var_binding_with_incomplete_struct_info_must_be_consistent():
assert not rx.analysis.well_formed(main)
+def test_incomplete_struct_info_must_be_consistent():
+ """StructInfo annotations must be accurate
+
+ Even though StructInfo annotation may be less specific, the
+ information that they do contain must be correct.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(
+ A: R.Tensor(shape=[128, 32], dtype="float32"),
+ B: R.Tensor(shape=[128, 32], dtype="float32"),
+ ):
+ C: R.Tensor(ndim=3) = R.add(A, B)
+ return C
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_struct_info_annotations_must_be_correct():
+ """StructInfo annotations must be correct
+
+ To be well-formed, the inferred struct info must not conflict with
+ the StructInfo annotations.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(
+ A: R.Tensor(shape=[128, 32], dtype="float32"),
+ B: R.Tensor(shape=[128, 32], dtype="float32"),
+ ):
+ C: R.Tensor(shape=[128, 32], dtype="int32") = R.add(A, B)
+ return C
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_struct_info_may_be_incomplete():
+ """StructInfo annotations may be less specific
+
+ The StructInfo annotations are not required to be an exact match
+ to the inferred StructInfo, and may provide less specific
+ information than the inference would provide.
+
+ """
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(
+ A: R.Tensor(shape=[128, 32], dtype="float32"),
+ B: R.Tensor(shape=[128, 32], dtype="float32"),
+ ):
+ C: R.Object = R.add(A, B)
+ return C
+
+ assert rx.analysis.well_formed(Module)
+
+
+def test_incomplete_struct_info_must_be_consistent():
+ """StructInfo annotations must be accurate
+
+ Even though StructInfo annotation may be less specific, the
+ information that they do contain must be correct.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(
+ A: R.Tensor(shape=[128, 32], dtype="float32"),
+ B: R.Tensor(shape=[128, 32], dtype="float32"),
+ ):
+ C: R.Tensor(ndim=3) = R.add(A, B)
+ return C
+
+ assert not rx.analysis.well_formed(Module)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_ast_printer.py
b/tests/python/relax/test_ast_printer.py
index 6005ecb0fa..1df7dcf36f 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -366,8 +366,8 @@ def test_call_packed():
) -> R.Object:
m = T.int64()
z: R.Tensor((32, m), "float32") = R.multiply(x, y)
- w: R.Tensor = R.multiply(z, z)
- q: R.Tensor(ndim=2) = R.add(w, w)
+ w: R.Tensor(ndim=2) = R.multiply(z, z)
+ q: R.Tensor = R.add(w, w)
t = R.add(w, z)
sh: R.Shape = R.shape_of(t)
o: R.Object = R.call_packed(
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 78fc7abdf7..191ea4da5e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -79,7 +79,7 @@ def test_conv1d():
out_layout="NCW",
out_dtype="float32",
)
- lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+ lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1,
6, 1])
lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2)
gv: R.Tensor((1, 6, 4), dtype="float32") = lv3
R.output(gv)
@@ -171,7 +171,7 @@ def test_conv1d_transpose():
out_layout="NCW",
out_dtype="float32",
)
- lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+ lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1,
6, 1])
lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2)
gv: R.Tensor((1, 6, 6), dtype="float32") = lv3
R.output(gv)
@@ -263,7 +263,7 @@ def test_conv2d():
out_layout="NCHW",
out_dtype="float32",
)
- lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1])
+ lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2,
[1, 6, 1, 1])
lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3
R.output(gv)
@@ -355,7 +355,7 @@ def test_conv2d_transpose():
out_layout="NCHW",
out_dtype="float32",
)
- lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1])
+ lv2: R.Tensor((1, 3, 1, 1), dtype="float32") = R.reshape(w2,
[1, 3, 1, 1])
lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1,
lv2)
gv: R.Tensor((1, 3, 16, 16), dtype="float32") = lv3
R.output(gv)
@@ -447,7 +447,7 @@ def test_conv3d():
out_layout="NCDHW",
out_dtype="float32",
)
- lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1])
+ lv2: R.Tensor((1, 6, 1, 1, 1), dtype="float32") =
R.reshape(w2, [1, 6, 1, 1, 1])
lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1,
lv2)
gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv3
R.output(gv)
diff --git a/tests/python/relax/test_transform_decompose_ops.py
b/tests/python/relax/test_transform_decompose_ops.py
index 4e5bcb82e9..2564913d79 100644
--- a/tests/python/relax/test_transform_decompose_ops.py
+++ b/tests/python/relax/test_transform_decompose_ops.py
@@ -360,14 +360,14 @@ def test_op_tensor_to_shape():
@I.ir_module
class Before:
@R.function
- def main(t: R.Tensor(ndim=1, dtype="int64")):
+ def main(t: R.Tensor([3], dtype="int64")):
gv: R.Shape(ndim=3) = R.tensor_to_shape(t)
return gv
@I.ir_module
class Expected:
@R.function
- def main(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3):
+ def main(t: R.Tensor([3], dtype="int64")) -> R.Shape(ndim=3):
x = T.int64()
x_1 = T.int64()
x_2 = T.int64()
diff --git a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
index da85423aaf..fa68c16e69 100644
--- a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
+++ b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py
@@ -83,7 +83,7 @@ def test_ipc_allreduce_spread_along_reshape():
alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor(
# type: ignore
R.shape([m, n]), R.dtype("float16"), R.prim_value(0),
R.str("global")
)
- lv1: R.Tensor((m, n), dtype="float16") = R.reshape(alloc, (m *
n,)) # type: ignore
+ lv1: R.Tensor((m * n,), dtype="float16") = R.reshape(alloc, (m *
n,)) # type: ignore
alloc1: R.Tensor((m * n,), dtype="float16") =
R.builtin.alloc_tensor( # type: ignore
R.shape([m * n]), R.dtype("float16"), R.prim_value(0),
R.str("global")
)
@@ -103,7 +103,7 @@ def test_ipc_allreduce_spread_along_reshape():
alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor(
# type: ignore
R.shape([m, n]), R.dtype("float16"), R.prim_value(0),
R.str("ipc_memory")
)
- lv1: R.Tensor((m, n), dtype="float16") = R.reshape( # type: ignore
+ lv1: R.Tensor((m * n,), dtype="float16") = R.reshape( # type:
ignore
alloc, R.shape([m * n])
)
alloc1: R.Tensor((m * n,), dtype="float16") =
R.builtin.alloc_tensor( # type: ignore
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py
b/tests/python/relax/test_transform_legalize_ops_ccl.py
index 9ea4d21d61..923a8e8d97 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -101,8 +101,8 @@ def test_scatter_from_worker0():
@tvm.script.ir_module
class ScatterFromWorker0:
@R.function
- def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((5, 10),
"float32"):
- gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x,
num_workers=2, axis=1)
+ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10,5),
"float32"):
+ gv0: R.Tensor((10,5), "float32") = R.ccl.scatter_from_worker0(x,
num_workers=2, axis=1)
return gv0
@I.ir_module
diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py
b/tests/python/relax/test_transform_legalize_ops_create_datatype.py
index 7b2b2d2e76..a8af295ac3 100644
--- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py
+++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py
@@ -160,19 +160,19 @@ def test_full_like():
@tvm.script.ir_module
class FullLike:
@R.function
- def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) ->
R.Tensor((2, 3), "float32"):
- gv: R.Tensor((2, 3), "float32") = R.full_like(x, v)
+ def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) ->
R.Tensor((2, 3), "int32"):
+ gv: R.Tensor((2, 3), "int32") = R.full_like(x, v)
return gv
@tvm.script.ir_module
class Expected:
@R.function
- def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) ->
R.Tensor((2, 3), "float32"):
- gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3),
dtype="float32"))
+ def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) ->
R.Tensor((2, 3), "int32"):
+ gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3),
dtype="int32"))
return gv
@T.prim_func(private=True)
- def full(rxplaceholder: T.Buffer((), "float32"), T_full:
T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ def full(rxplaceholder: T.Buffer((), "float32"), T_full:
T.Buffer((T.int64(2), T.int64(3)), "int32")):
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_full"):
@@ -191,26 +191,26 @@ def test_full_like_constant_scalar_fill_value():
@tvm.script.ir_module
class FullLike:
@R.function
- def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"):
- gv: R.Tensor((2, 3), "float32") = R.full_like(x, R.const(-5,
"float32"))
+ def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+ gv: R.Tensor((2, 3), "int32") = R.full_like(x, R.const(-5,
"float32"))
return gv
@tvm.script.ir_module
class Expected:
@R.function
- def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"):
- gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3),
dtype="float32"))
+ def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+ gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3),
dtype="int32"))
return gv
@T.prim_func(private=True)
- def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+ def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")):
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_full"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads()
T.writes(T_full[ax0, ax1])
- T_full[ax0, ax1] = T.float32(-5)
+ T_full[ax0, ax1] = T.int32(-5)
# fmt: on
mod = LegalizeOps()(FullLike)
@@ -253,19 +253,19 @@ def test_full_like_symbolic():
@tvm.script.ir_module
class FullLike:
@R.function
- def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32"))
-> R.Tensor(("m", "n"), "float32"):
+ def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32"))
-> R.Tensor(("m", "n"), "int32"):
m = T.int64()
n = T.int64()
- gv: R.Tensor((m, n), "float32") = R.full_like(x, v)
+ gv: R.Tensor((m, n), "int32") = R.full_like(x, v)
return gv
@tvm.script.ir_module
class Expected:
@R.function
- def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32"))
-> R.Tensor(("m", "n"), "float32"):
+ def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32"))
-> R.Tensor(("m", "n"), "int32"):
m = T.int64()
n = T.int64()
- gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n),
dtype="float32"))
+ gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n),
dtype="int32"))
return gv
@T.prim_func(private=True)
@@ -273,13 +273,13 @@ def test_full_like_symbolic():
T.func_attr({"tir.noalias": True})
m = T.int64()
n = T.int64()
- T_full = T.match_buffer(var_T_full, [m, n], dtype="float32")
+ T_full = T.match_buffer(var_T_full, [m, n], dtype="int32")
for i0, i1 in T.grid(m, n):
with T.block("T_full"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(rxplaceholder[()])
T.writes(T_full[ax0, ax1])
- T_full[ax0, ax1] = rxplaceholder[()]
+ T_full[ax0, ax1] = T.int32(rxplaceholder[()])
# fmt: on
mod = LegalizeOps()(FullLike)
diff --git
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index d0aaddb1ca..2f4da5cf06 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -230,7 +230,7 @@ def test_strided_slice_no_strides():
class StridedSlice:
@R.function
def main(x: R.Tensor((8, 9, 10, 10), "float32")) :
- gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x,
axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4])
+ gv: R.Tensor((7, 9, 10, 2), "float32") = R.strided_slice(x,
axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4])
return gv
@tvm.script.ir_module
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index ba5d4d7d12..a0ecd3c73d 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -691,9 +691,12 @@ def test_data_dependent_reshape():
@tvm.script.ir_module
class DDReshape:
@R.function
- def main(x: R.Tensor((3, ), dtype="int64")):
- lv: R.Shape([3,]) = R.tensor_to_shape(x)
- gv = R.reshape(x, lv)
+ def main(
+ x: R.Tensor([2], dtype="int64"),
+ y: R.Tensor([16],dtype='float32'),
+ ):
+ lv: R.Shape(ndim=2) = R.tensor_to_shape(x)
+ gv = R.reshape(y, lv)
return gv
# fmt: on
@@ -704,29 +707,35 @@ def test_data_dependent_reshape():
# fmt: off
@I.ir_module
class Expected:
+ @R.function
+ def main(
+ x: R.Tensor([2], dtype="int64"),
+ y: R.Tensor([16],dtype="float32"),
+ ) -> R.Tensor(ndim=2, dtype="float32"):
+ M = T.int64()
+ N = T.int64()
+ gv = R.call_pure_packed("vm.builtin.tensor_to_shape", x,
sinfo_args=(R.Shape(ndim=2),))
+ _ = R.match_cast(gv, R.Shape([M,N]))
+ _ = R.shape([M,N])
+ gv_1 = R.call_tir(Expected.reshape, (y,),
out_sinfo=R.Tensor([M,N], dtype="float32"))
+ return gv_1
+
@T.prim_func(private=True)
def reshape(
- rxplaceholder: T.Buffer((T.int64(3),), "int64"), var_T_reshape:
T.handle
+ rxplaceholder: T.Buffer(T.int64(16), "float32"),
+ var_T_reshape: T.handle,
):
T.func_attr({"tir.noalias": True})
- x = T.int64()
- T_reshape = T.match_buffer(var_T_reshape, (x,), "int64")
- # with T.block("root"):
- for ax0 in range(x):
+ M = T.int64()
+ N = T.int64()
+ T_reshape = T.match_buffer(var_T_reshape, [M,N], "float32")
+ for i,j in T.grid(M,N):
with T.block("T_reshape"):
- v_ax0 = T.axis.spatial(x, ax0)
- T.reads(rxplaceholder[v_ax0 % T.int64(3)])
- T.writes(T_reshape[v_ax0])
- T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)]
+ vi,vj = T.axis.remap('SS',[i,j])
+ T.reads(rxplaceholder[(vi*N + vj) % 16])
+ T.writes(T_reshape[vi,vj])
+ T_reshape[vi,vj] = rxplaceholder[(vi*N + vj) % 16]
- @R.function
- def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(ndim=1,
dtype="int64"):
- x_1 = T.int64()
- gv: R.Shape([3]) =
R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),))
- y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1]))
- lv: R.Shape([x_1]) = R.shape([x_1])
- gv_1 = R.call_tir(Expected.reshape, (x,),
out_sinfo=R.Tensor((x_1,), dtype="int64"))
- return gv_1
# fmt: on
tvm.ir.assert_structural_equal(out_mod, Expected)
@@ -914,7 +923,7 @@ def test_squeeze_no_axis():
class Squeeze:
@R.function
def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) :
- gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x)
+ gv: R.Tensor((2, 3, 4), "float32") = R.squeeze(x)
return gv
@tvm.script.ir_module
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 92d139d23b..d03d48968d 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -33,7 +33,7 @@ def test_conv1d():
class Conv1d:
@R.function
def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((64, 16,
3), "float32")) -> R.Tensor((2, 64, 13), "float32"):
- gv: R.Tensor((2, 4, 13), "float32") = R.nn.conv1d(x, w,
strides=(2,), padding=(1,), dilation=(2,), groups=8)
+ gv: R.Tensor((2, 64, 13), "float32") = R.nn.conv1d(x, w,
strides=(2,), padding=(1,), dilation=(2,), groups=8)
return gv
@tvm.script.ir_module
@@ -210,7 +210,7 @@ def test_conv2d():
class Conv2d:
@R.function
def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64,
16, 3, 3), "float32")) -> R.Tensor((2, 64, 13, 13), "float32"):
- gv: R.Tensor((2, 4, 13, 13), "float32") = R.nn.conv2d(x, w,
strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8)
+ gv: R.Tensor((2, 64, 13, 13), "float32") = R.nn.conv2d(x, w,
strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8)
return gv
@tvm.script.ir_module
@@ -3298,20 +3298,32 @@ def test_nll_loss():
@tvm.script.ir_module
class NLLLoss:
@R.function
- def main(predictions: R.Tensor((2, 3, 4, 5), "float32"), targets:
R.Tensor((2, 4, 5), "int64"), weights: R.Tensor((4,), "float32")) ->
R.Tensor((), "float32"):
- gv: R.Tensor((), "float32") = R.nn.nll_loss(predictions, targets,
weights, reduction="mean", ignore_index=-1)
+ def main(
+ predictions: R.Tensor((2, 3, 4, 5), "float32"),
+ targets: R.Tensor((2, 4, 5), "int64"),
+ weights: R.Tensor((3,), "float32"),
+ ) -> R.Tensor((), "float32"):
+ gv = R.nn.nll_loss(predictions, targets, weights,
reduction="mean", ignore_index=-1)
return gv
@tvm.script.ir_module
class Expected:
@R.function
- def main(predictions: R.Tensor((2, 3, 4, 5), dtype="float32"),
targets: R.Tensor((2, 4, 5), dtype="int64"), weights: R.Tensor((4,),
dtype="float32"),) -> R.Tensor((), dtype="float32"):
- # block 0
+ def main(
+ predictions: R.Tensor((2, 3, 4, 5), dtype="float32"),
+ targets: R.Tensor((2, 4, 5), dtype="int64"),
+ weights: R.Tensor((3,), dtype="float32"),
+ ) -> R.Tensor((), dtype="float32"):
gv = R.call_tir(Expected.nll_loss, (predictions, targets,
weights), R.Tensor((), dtype="float32"))
return gv
@T.prim_func(private=True)
- def nll_loss(rxplaceholder: T.Buffer((T.int64(2), T.int64(3),
T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2),
T.int64(4), T.int64(5)), "int64"), rxplaceholder_2: T.Buffer(T.int64(4),
"float32"), T_divide: T.Buffer((), "float32"),):
+ def nll_loss(
+ predictions: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "float32"),
+ targets: T.Buffer((T.int64(2), T.int64(4), T.int64(5)),
"int64"),
+ weights: T.Buffer(T.int64(3), "float32"),
+ output: T.Buffer((), "float32"),
+ ):
# function attr dict
T.func_attr({"tir.noalias": True})
# body
@@ -3323,9 +3335,9 @@ def test_nll_loss():
for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
with T.block("nll_loss"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2],
rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2],
rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]])
+ T.reads(targets[v_ax0, v_ax1, v_ax2], predictions[v_ax0,
targets[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2], weights[targets[v_ax0, v_ax1,
v_ax2]])
T.writes(nll_loss[v_ax0, v_ax1, v_ax2])
- nll_loss[v_ax0, v_ax1, v_ax2] =
T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1), (T.float32(0) -
rxplaceholder[v_ax0, rxplaceholder_1[v_ax0, v_ax1, v_ax2], v_ax1, v_ax2]) *
rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0))
+ nll_loss[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0,
v_ax1, v_ax2] != T.int64(-1), (T.float32(0) - predictions[v_ax0, targets[v_ax0,
v_ax1, v_ax2], v_ax1, v_ax2]) * weights[targets[v_ax0, v_ax1, v_ax2]],
T.float32(0))
for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
with T.block("nll_loss_red"):
v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
@@ -3337,9 +3349,9 @@ def test_nll_loss():
for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
with T.block("nll_loss_1"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2],
rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]])
+ T.reads(targets[v_ax0, v_ax1, v_ax2],
weights[targets[v_ax0, v_ax1, v_ax2]])
T.writes(nll_loss_1[v_ax0, v_ax1, v_ax2])
- nll_loss_1[v_ax0, v_ax1, v_ax2] =
T.Select(rxplaceholder_1[v_ax0, v_ax1, v_ax2] != T.int64(-1),
rxplaceholder_2[rxplaceholder_1[v_ax0, v_ax1, v_ax2]], T.float32(0))
+ nll_loss_1[v_ax0, v_ax1, v_ax2] = T.Select(targets[v_ax0,
v_ax1, v_ax2] != T.int64(-1), weights[targets[v_ax0, v_ax1, v_ax2]],
T.float32(0))
for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
with T.block("nll_loss_red_1"):
v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
@@ -3351,8 +3363,8 @@ def test_nll_loss():
with T.block("T_divide"):
vi = T.axis.spatial(1, T.int64(0))
T.reads(nll_loss_red[()], nll_loss_red_1[()])
- T.writes(T_divide[()])
- T_divide[()] = nll_loss_red[()] / nll_loss_red_1[()]
+ T.writes(output[()])
+ output[()] = nll_loss_red[()] / nll_loss_red_1[()]
# fmt: on
mod = LegalizeOps()(NLLLoss)
tvm.ir.assert_structural_equal(mod, Expected)
diff --git
a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 2a28151dbe..f8dab89815 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -999,8 +999,8 @@ def test_variance_no_keepdims():
@tvm.script.ir_module
class Variance:
@R.function
- def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 3, 4,
1), "float32"):
- gv: R.Tensor((1, 3, 4, 1), "float32") = R.variance(x, [0, 3],
keepdims=False)
+ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((3, 4),
"float32"):
+ gv: R.Tensor((3, 4), "float32") = R.variance(x, [0, 3],
keepdims=False)
return gv
@I.ir_module
diff --git a/tests/python/relax/test_transform_realize_vdevice.py
b/tests/python/relax/test_transform_realize_vdevice.py
index 4c530d5e49..fa64282184 100644
--- a/tests/python/relax/test_transform_realize_vdevice.py
+++ b/tests/python/relax/test_transform_realize_vdevice.py
@@ -61,8 +61,9 @@ def test_dataflow_binding():
y1 = y
x2 = x1
y2 = y1
- lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2)
- gv: R.Tensor((2, 3), "float32", "llvm") = R.multiply(lv0, z)
+ x2 = R.hint_on_device(x2, tvm.cpu())
+ lv0 = R.add(x2, y2)
+ gv = R.multiply(lv0, z)
R.output(gv)
return gv
@@ -91,6 +92,7 @@ def test_dataflow_binding():
y1: R.Tensor((2, 3), "float32", "llvm") = y
x2: R.Tensor((2, 3), "float32", "llvm") = x1
y2: R.Tensor((2, 3), "float32", "llvm") = y1
+ x2: R.Tensor((2, 3), "float32", "llvm") = x2
lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2)
gv: R.Tensor((2, 3), "float32", "llvm") = R.multiply(lv0, z)
R.output(gv)
@@ -121,7 +123,8 @@ def test_binding():
y1 = y
x2 = x1
y2 = y1
- s: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2)
+ x2 = R.hint_on_device(x2, tvm.cpu())
+ s = R.add(x2, y2)
m = R.multiply(s, z)
return m
@@ -146,6 +149,7 @@ def test_binding():
y1: R.Tensor((2, 3), "float32", "llvm") = y
x2: R.Tensor((2, 3), "float32", "llvm") = x1
y2: R.Tensor((2, 3), "float32", "llvm") = y1
+ x2: R.Tensor((2, 3), "float32", "llvm") = x2
s: R.Tensor((2, 3), "float32", "llvm") = R.add(x2, y2)
m: R.Tensor((2, 3), "float32", "llvm") = R.multiply(s, z)
return m
@@ -275,10 +279,11 @@ def test_multi_device():
z: R.Tensor((2, 3), "float32"),
) -> R.Tensor((2, 3), "float32", "cuda"):
with R.dataflow():
- lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x, y)
+ lv0 = R.add(x, y)
+ lv0 = R.hint_on_device(lv0, tvm.cpu())
lv1 = R.to_vdevice(lv0, "cuda")
lv2 = R.add(z, z)
- gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv1, lv2)
+ gv = R.multiply(lv1, lv2)
R.output(gv)
return gv
@@ -304,6 +309,7 @@ def test_multi_device():
) -> R.Tensor((2, 3), "float32", "cuda"):
with R.dataflow():
lv0: R.Tensor((2, 3), "float32", "llvm") = R.add(x, y)
+ lv0: R.Tensor((2, 3), "float32", "llvm") = lv0
lv1: R.Tensor((2, 3), "float32", "cuda") = R.to_vdevice(lv0,
"cuda")
lv2: R.Tensor((2, 3), "float32", "cuda") = R.add(z, z)
gv: R.Tensor((2, 3), "float32", "cuda") = R.multiply(lv1, lv2)
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py
b/tests/python/relax/test_transform_static_plan_block_memory.py
index f9e632d348..1150827b19 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1386,11 +1386,11 @@ def test_add():
)
cls.cumsum(probs, lv1, alloc1)
cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") =
alloc1
- lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") =
R.call_packed(
+ lv1_1: R.Tensor((batch_size, vocab_size), dtype="float32") =
R.call_packed(
"vm.builtin.reshape",
cumsum,
R.shape([batch_size, vocab_size]),
- sinfo_args=(R.Tensor((batch_size, vocab_size),
dtype="float"),),
+ sinfo_args=(R.Tensor((batch_size, vocab_size),
dtype="float32"),),
)
return lv1_1
@@ -1403,7 +1403,7 @@ def test_add():
@R.function
def main(
probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")
- ) -> R.Tensor(("batch_size", "vocab_size"), dtype="int32"):
+ ) -> R.Tensor(("batch_size", "vocab_size"), dtype="float32"):
batch_size = T.int64()
vocab_size = T.int64()
R.func_attr(
@@ -1437,7 +1437,7 @@ def test_add():
)
cls.cumsum(probs, lv1, alloc1)
cumsum: R.Tensor((batch_size, vocab_size), dtype="float32") =
alloc1
- lv1_1: R.Tensor((batch_size, vocab_size), dtype="int32") =
R.call_packed(
+ lv1_1: R.Tensor((batch_size, vocab_size), dtype="float32") =
R.call_packed(
"vm.builtin.reshape",
cumsum,
R.shape([batch_size, vocab_size]),
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py
b/tests/python/relax/test_transform_to_mixed_precision.py
index ed10fc95c7..658f80a06e 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -906,15 +906,15 @@ def test_conv2d_bias_fp32():
) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
# block 0
with R.dataflow():
- lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d(
+ lv142: R.Tensor((1, 512, 62, 62), dtype="float32") =
R.nn.conv2d(
x,
w,
strides=[1, 1],
padding=[0, 0, 0, 0],
out_dtype="float32",
)
- lv143: R.Tensor((1, 4, 1, 1), dtype="float32") =
R.reshape(bias, (1, 512, 1, 1))
- lv144: R.Tensor((1, 4, 64, 64), dtype="float32") =
R.add(lv142, lv143)
+ lv143: R.Tensor((1, 512, 1, 1), dtype="float32") =
R.reshape(bias, (1, 512, 1, 1))
+ lv144: R.Tensor((1, 512, 62, 62), dtype="float32") =
R.add(lv142, lv143)
R.output(lv144)
return lv144
@@ -1001,15 +1001,15 @@ def test_convert_sig():
) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
# block 0
with R.dataflow():
- lv142: R.Tensor((1, 4, 64, 64), dtype="float32") = R.nn.conv2d(
+ lv142: R.Tensor((1, 512, 62, 62), dtype="float32") =
R.nn.conv2d(
x,
w,
strides=[1, 1],
padding=[0, 0, 0, 0],
out_dtype="float32",
)
- lv143: R.Tensor((1, 4, 1, 1), dtype="float32") =
R.reshape(bias, (1, 512, 1, 1))
- lv144: R.Tensor((1, 4, 64, 64), dtype="float32") =
R.add(lv142, lv143)
+ lv143: R.Tensor((1, 512, 1, 1), dtype="float32") =
R.reshape(bias, (1, 512, 1, 1))
+ lv144: R.Tensor((1, 512, 62, 62), dtype="float32") =
R.add(lv142, lv143)
R.output(lv144)
return lv144
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index fa62d14848..3e64c928ae 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -882,8 +882,8 @@ def test_annotation():
) -> R.Object:
m = T.int64()
z: R.Tensor((32, m), "float32") = R.multiply(x, y)
- w: R.Tensor = R.multiply(z, z)
- q: R.Tensor(ndim=2) = R.add(w, w)
+ w: R.Tensor(ndim=2) = R.multiply(z, z)
+ q: R.Tensor = R.add(w, w)
t = R.add(w, z)
sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape)
lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh)
@@ -902,9 +902,9 @@ def test_annotation():
sh = bindings[4].var
_check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32"))
- _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1))
- _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2))
- _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1))
+ _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=2))
+ _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=-1))
+ _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=2))
_check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1))
_check_struct_info(bindings[5], relax.TensorStructInfo(sh))
_check_struct_info(bindings[6], relax.ObjectStructInfo())
diff --git a/tests/python/relax/test_vm_cuda_graph.py
b/tests/python/relax/test_vm_cuda_graph.py
index 49ebcc1d05..b6c8cdfdee 100644
--- a/tests/python/relax/test_vm_cuda_graph.py
+++ b/tests/python/relax/test_vm_cuda_graph.py
@@ -36,13 +36,13 @@ class Module:
R.func_attr({"global_symbol": "main"})
gv: R.Tuple(R.Object, R.Object) =
R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc",
(cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object,
R.Object),))
storage: R.Object = gv[0]
- alloc: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage,
R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+ alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)),
R.dtype("float32"))
_: R.Tuple = cls.add(x, alloc)
storage1: R.Object = gv[1]
gv1: R.Tuple(R.Tensor(dtype="float32"), R.Object, R.Object) = (alloc,
storage1, storage)
gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) =
R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture",
(cls.cuda_graph_capture, gv1, R.prim_value(0)),
sinfo_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),))
storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)),
R.prim_value(0), R.dtype("uint8"))
- alloc3: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage2,
R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+ alloc3 = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16,
16)), R.dtype("float32"))
lv4: R.Tensor((16, 16), dtype="float32") = gv2[0]
_3: R.Tuple = cls.add(lv4, alloc3)
lv5: R.Tensor(dtype="float32") = alloc3
@@ -71,12 +71,12 @@ class Module:
cls = Module
R.func_attr({"global_symbol": "cuda_graph_capture"})
lv0: R.Tensor((16, 16), dtype="float32") = alloc
- alloc1: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage1,
R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+ alloc1 = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16,
16)), R.dtype("float32"))
_1: R.Tuple = cls.add(lv0, alloc1)
lv1: R.Tensor(dtype="float32") = alloc1
lv2: R.Tuple(R.Tensor(dtype="float32")) = (lv1,)
lv3: R.Tensor(dtype="float32") = lv2[0]
- alloc2: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage,
R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+ alloc2 = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16,
16)), R.dtype("float32"))
_2: R.Tuple = cls.add(lv3, alloc2)
lv4: R.Tensor(dtype="float32") = alloc2
gv: R.Tuple(R.Tensor(dtype="float32")) = (lv4,)
diff --git a/tests/python/relax/test_vm_multi_device.py
b/tests/python/relax/test_vm_multi_device.py
index ec2fbd1cdf..73c78d70f0 100644
--- a/tests/python/relax/test_vm_multi_device.py
+++ b/tests/python/relax/test_vm_multi_device.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Test eliminate common subexpr pass"""
+
from typing import List
import tvm
from tvm import relax
@@ -61,11 +62,12 @@ def test_multi_cpu():
z: R.Tensor((4, 5), "float32"),
) -> R.Tensor((2, 5), "float32"):
with R.dataflow():
- lv0: R.Tensor((2, 4), "float32", "llvm:0") = R.matmul(x, y) #
noqa: F722
+ lv0 = R.matmul(x, y)
+ lv0 = R.hint_on_device(lv0, tvm.cpu(0))
lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice( #
noqa: F722
- lv0, "llvm:1" # noqa: F722
+ lv0, "llvm:1"
)
- gv = R.matmul(lv1, z) # noqa: F722
+ gv = R.matmul(lv1, z)
R.output(gv)
return gv
@@ -109,11 +111,13 @@ def test_multi_gpu():
with R.dataflow():
lv0: R.Tensor((2, 4), "float32", "cuda:0") = R.matmul(a, b) #
noqa: F722
lv1: R.Tensor((2, 4), "float32", "cuda:1") = R.to_vdevice( #
noqa: F722
- lv0, "cuda:1" # noqa: F722
+ lv0,
+ "cuda:1", # noqa: F722
)
lv2: R.Tensor((2, 5), "float32", "cuda:1") = R.matmul(lv1, c)
# noqa: F722
lv3: R.Tensor((2, 5), "float32", "cuda:2") = R.to_vdevice( #
noqa: F722
- lv2, "cuda:2" # noqa: F722
+ lv2,
+ "cuda:2", # noqa: F722
)
gv: R.Tensor((2, 6), "float32", "cuda:2") = R.matmul(lv3, d)
# noqa: F722
R.output(gv)