This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f500a0dd1a [Unity] Update relax.Function.ret_struct_info when mutated
(#15510)
f500a0dd1a is described below
commit f500a0dd1ab2dd833b5220d44a3147430edcb7d2
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Aug 9 22:19:09 2023 -0500
[Unity] Update relax.Function.ret_struct_info when mutated (#15510)
* [Unity] Update relax.Function.ret_struct_info when mutated
Prior to this commit, `relax::ExprMutator` forwarded the original
`ret_struct_info` unchanged when visiting a `relax::Function`. This
allows the resulting function to keep a more specific return type when
the mutation no longer lets shape information propagate across the
entire body (e.g. preserving `R.Tensor(shape=[16,16])` even when shape
propagation only produced `R.Tensor(ndim=2)`), but it also preserves
return-type information that is no longer correct after the mutation.
For example, a mutator that implements `VisitVarDef_` to replace the
shape of a function parameter would expect the updated shape to
propagate through to the function's return type. By retaining the
original `ret_struct_info`, the mutator instead produces a function
whose declared return type is incorrect.
This commit updates `relax::ExprMutator` to only forward the original
`ret_struct_info` if the mutated body's struct info is compatible with
it.
* Test updates with ret struct info
---------
Co-authored-by: Farshid <[email protected]>
---
src/relax/ir/expr_functor.cc | 20 ++-
tests/python/relax/test_dataflow_pattern.py | 31 ++---
tests/python/relax/test_expr_functor.py | 31 +++++
.../test_transform_combine_parallel_matmul.py | 155 ++++++++++-----------
.../python/relax/test_transform_convert_layout.py | 6 +-
.../relax/test_transform_dead_code_elimination.py | 16 +--
..._transform_legalize_ops_index_linear_algebra.py | 4 +-
.../test_transform_legalize_ops_manipulate.py | 4 +-
.../python/relax/test_transform_legalize_ops_nn.py | 8 +-
...st_transform_legalize_ops_search_statistical.py | 2 +-
.../relax/test_transform_rewrite_cuda_graph.py | 8 +-
.../relax/test_transform_to_mixed_precision.py | 12 +-
12 files changed, 166 insertions(+), 131 deletions(-)
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index cb74400d7a..f0f0d29b51 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -585,11 +585,27 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
Expr body = this->VisitWithNewScope(op->body, params);
- // FuncStructInfo does not depend on Expr
if (all_params_unchanged && body.same_as(op->body)) {
+ // No changes to the function, return the original object
return GetRef<Expr>(op);
- } else {
+ } else if (IsBaseOf(GetStructInfo(body), op->ret_struct_info)) {
+ // If the function was mutated into a form that can no longer
+ // propagate shape information all the way to the return value, we
+ // may keep the return struct info. This is only allowed when the
+ // body produces a return value that is the same as, or more
+ // specific than, the pre-mutation struct info. For example, if
+ // the previous return value was `TensorStructInfo(shape=[16,16])`
+ // but the body only produced `TensorStructInfo(ndim=2)`, we can
+ // keep the more specific information.
return Function(params, body, op->ret_struct_info, op->is_pure, op->attrs);
+ } else {
+ // If the function was mutated such that the body produces an
+ // output that is incompatible with the original return struct
+ // info, the original return struct info should not be used. For
+ // example, if the previous return value was
+ // `TensorStructInfo(shape=[16,16])`, but the new return value is
+ // `TensorStructInfo(shape=[8,8])`.
+ return Function(params, body, NullOpt, op->is_pure, op->attrs);
}
}
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index 202db9b5b3..d3e9952d5b 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1053,7 +1053,7 @@ def get_qkv_proj_rewriter(
def test_combine_matmul_twice():
- @R.function
+ @R.function(private=True)
def qkv_x2(
x1: R.Tensor((2, 1024, 640), "float32"),
x2: R.Tensor((2, 1024, 640), "float32"),
@@ -1063,7 +1063,7 @@ def test_combine_matmul_twice():
w3: R.Tensor((640, 640), "float32"),
w4: R.Tensor((640, 640), "float32"),
w5: R.Tensor((640, 640), "float32"),
- ) -> R.Tensor:
+ ):
with R.dataflow():
lv0 = R.matmul(x1, w0)
lv1 = R.matmul(x1, w1)
@@ -1075,7 +1075,7 @@ def test_combine_matmul_twice():
R.output(out)
return out
- @R.function
+ @R.function(private=True)
def expected(
x1: R.Tensor((2, 1024, 640), "float32"),
x2: R.Tensor((2, 1024, 640), "float32"),
@@ -1085,7 +1085,7 @@ def test_combine_matmul_twice():
w3: R.Tensor((640, 640), "float32"),
w4: R.Tensor((640, 640), "float32"),
w5: R.Tensor((640, 640), "float32"),
- ) -> R.Tensor:
+ ):
with R.dataflow():
lv = R.concat((w0, w1, w2), axis=1)
lv1 = R.matmul(x1, lv)
@@ -1115,17 +1115,17 @@ def test_combine_matmul_twice():
inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1,
matmul2, matmul3
)
rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
- tvm.ir.assert_structural_equal(rewritten,
expected.with_attr("global_symbol", "qkv_x2"))
+ tvm.ir.assert_structural_equal(rewritten, expected)
def test_combine_matmul_emit_order():
- @R.function
+ @R.function(private=True)
def main(
x1: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
- ) -> R.Tensor:
+ ):
with R.dataflow():
w0_t = R.permute_dims(w0, axes=None)
lv0 = R.matmul(x1, w0_t)
@@ -1138,13 +1138,13 @@ def test_combine_matmul_emit_order():
R.output(out)
return out
- @R.function
+ @R.function(private=True)
def expected(
x1: R.Tensor((2, 1024, 640), dtype="float32"),
w0: R.Tensor((640, 640), dtype="float32"),
w1: R.Tensor((640, 640), dtype="float32"),
w2: R.Tensor((640, 640), dtype="float32"),
- ) -> R.Tensor:
+ ):
with R.dataflow():
w0_t = R.permute_dims(w0, axes=None)
w1_t = R.permute_dims(w1, axes=None)
@@ -1173,7 +1173,7 @@ def test_combine_matmul_emit_order():
inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1,
matmul2, matmul3
)
rewritten = rewrite_bindings(ctx, rewriter, main)
- tvm.ir.assert_structural_equal(rewritten,
expected.with_attr("global_symbol", "main"))
+ tvm.ir.assert_structural_equal(rewritten, expected)
# make sure it builds
mod = tvm.IRModule()
@@ -1184,7 +1184,7 @@ def test_combine_matmul_emit_order():
def test_combine_transposed_matmul_twice():
- @R.function
+ @R.function(private=True)
def main(
x1: R.Tensor((2, 1024, 640), "float32"),
x2: R.Tensor((2, 1024, 640), "float32"),
@@ -1192,7 +1192,7 @@ def test_combine_transposed_matmul_twice():
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
w3: R.Tensor((640, 640), "float32"),
- ) -> R.Tensor:
+ ):
with R.dataflow():
w0_t = R.permute_dims(w0, axes=None)
lv0 = R.matmul(x1, w0_t)
@@ -1206,7 +1206,7 @@ def test_combine_transposed_matmul_twice():
R.output(out)
return out
- @R.function
+ @R.function(private=True)
def expected(
x1: R.Tensor((2, 1024, 640), dtype="float32"),
x2: R.Tensor((2, 1024, 640), dtype="float32"),
@@ -1214,7 +1214,7 @@ def test_combine_transposed_matmul_twice():
w1: R.Tensor((640, 640), dtype="float32"),
w2: R.Tensor((640, 640), dtype="float32"),
w3: R.Tensor((640, 640), dtype="float32"),
- ) -> R.Tensor:
+ ):
with R.dataflow():
lv: R.Tensor((1280, 640), dtype="float32") = R.concat((w0, w1),
axis=0)
lv1: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv,
axes=None)
@@ -1271,8 +1271,7 @@ def test_combine_transposed_matmul_twice():
}
rewritten = rewrite_bindings(ctx, rewriter, main)
- print(rewritten.script())
- tvm.ir.assert_structural_equal(rewritten,
expected.with_attr("global_symbol", "main"))
+ tvm.ir.assert_structural_equal(rewritten, expected)
# make sure it builds
mod = tvm.IRModule()
diff --git a/tests/python/relax/test_expr_functor.py
b/tests/python/relax/test_expr_functor.py
index c18ab3e6f6..0daf9d4a1f 100644
--- a/tests/python/relax/test_expr_functor.py
+++ b/tests/python/relax/test_expr_functor.py
@@ -832,5 +832,36 @@ def test_call_mutator_super():
)
+def test_function_parameter_mutation():
+ @relax.expr_functor.mutator
+ class ParamMutator(PyExprMutator):
+ def __init__(self, shape_replacements):
+ super().__init__()
+ self.shape_replacements = shape_replacements
+
+ def visit_var_def_(self, var):
+ if var.name_hint in self.shape_replacements:
+ new_shape = self.shape_replacements[var.name_hint]
+ new_sinfo = relax.TensorStructInfo(new_shape,
dtype=var.struct_info.dtype)
+ return relax.Var(f"{var.name_hint}_with_new_shape", new_sinfo)
+ else:
+ return var
+
+ @R.function(private=True)
+ def before(
+ A: R.Tensor((16, 32), "float32"), B: R.Tensor((32, 64), "float32")
+ ) -> R.Tensor((16, 64), "float32"):
+ return R.matmul(A, B)
+
+ @R.function(private=True)
+ def expected(
+ A: R.Tensor((1, 32), "float32"), B: R.Tensor((32, 64), "float32")
+ ) -> R.Tensor((1, 64), "float32"):
+ return R.matmul(A, B)
+
+ after = ParamMutator({"A": (1, 32)}).visit_expr(before)
+ tvm.ir.assert_structural_equal(expected, after)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py
b/tests/python/relax/test_transform_combine_parallel_matmul.py
index 97211f0dd0..b06eddd2bb 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -349,36 +349,33 @@ def test_bias_activation():
def test_rhs_batched():
- @tvm.script.ir_module
- class four_matmul:
- @R.function
- def main(
- x: R.Tensor((1024, 640), "float32"),
- w0: R.Tensor((2, 640, 640), "float32"),
- w1: R.Tensor((640, 640), "float32"),
- w2: R.Tensor((2, 640, 640), "float32"),
- w3: R.Tensor((3, 4, 640, 640), "float32"),
- ) -> R.Tensor:
- with R.dataflow():
- lv0 = R.matmul(x, w0)
- lv1 = R.matmul(x, w1)
- lv2 = R.matmul(x, w2)
- lv3 = R.matmul(x, w3)
- out = (lv0, lv1, lv2, lv3)
- R.output(out)
- return out
+ @R.function(private=True)
+ def before(
+ x: R.Tensor((1024, 640), "float32"),
+ w0: R.Tensor((2, 640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((2, 640, 640), "float32"),
+ w3: R.Tensor((3, 4, 640, 640), "float32"),
+ ):
+ with R.dataflow():
+ lv0 = R.matmul(x, w0)
+ lv1 = R.matmul(x, w1)
+ lv2 = R.matmul(x, w2)
+ lv3 = R.matmul(x, w3)
+ out = (lv0, lv1, lv2, lv3)
+ R.output(out)
+ return out
- mod = CombineParallelMatmul()(four_matmul)
+ after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
- @R.function
- def expected1(
+ @R.function(private=True)
+ def expected(
x: R.Tensor((1024, 640), dtype="float32"),
w0: R.Tensor((2, 640, 640), dtype="float32"),
w1: R.Tensor((640, 640), dtype="float32"),
w2: R.Tensor((2, 640, 640), dtype="float32"),
w3: R.Tensor((3, 4, 640, 640), dtype="float32"),
- ) -> R.Tensor:
- R.func_attr({"global_symbol": "main"})
+ ):
with R.dataflow():
lv = R.concat((w0, w2), axis=2)
lv1 = R.matmul(x, lv, out_dtype="float32")
@@ -391,7 +388,7 @@ def test_rhs_batched():
R.output(out)
return out
- tvm.ir.assert_structural_equal(mod["main"], expected1)
+ tvm.ir.assert_structural_equal(after, expected)
@tvm.script.ir_module
class four_matmul_incompatible_batches:
@@ -402,7 +399,7 @@ def test_rhs_batched():
w1: R.Tensor((3, 640, 640), "float32"),
w2: R.Tensor((2, 640, 640), "float32"),
w3: R.Tensor((2, 640, 640), "float32"),
- ) -> R.Tensor:
+ ):
with R.dataflow():
lv0 = R.matmul(x, w0)
lv1 = R.matmul(x, w1)
@@ -419,36 +416,34 @@ def test_rhs_batched():
def test_multiple_combine():
- @tvm.script.ir_module
- class multiple_combine:
- @R.function
- def main(
- x1: R.Tensor((2, 1024, 640), "float32"),
- x2: R.Tensor((2, 1024, 640), "float32"),
- w0: R.Tensor((640, 640), "float32"),
- w1: R.Tensor((640, 640), "float32"),
- w2: R.Tensor((640, 640), "float32"),
- w3: R.Tensor((640, 640), "float32"),
- w4: R.Tensor((640, 640), "float32"),
- b0: R.Tensor((640,), "float32"),
- b1: R.Tensor((640,), "float32"),
- ) -> R.Tensor:
- with R.dataflow():
- lv0 = R.matmul(x1, w0)
- lv3 = R.matmul(x2, w3)
- lv1 = R.matmul(x1, w1)
- lv5 = R.add(lv3, b0)
- lv2 = R.matmul(x1, w2)
- lv4 = R.matmul(x2, w4)
- lv6 = R.add(lv4, b1)
- out = (lv0, lv1, lv2, lv5, lv6)
- R.output(out)
- return out
+ @R.function(private=True)
+ def before(
+ x1: R.Tensor((2, 1024, 640), "float32"),
+ x2: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((640, 640), "float32"),
+ w3: R.Tensor((640, 640), "float32"),
+ w4: R.Tensor((640, 640), "float32"),
+ b0: R.Tensor((640,), "float32"),
+ b1: R.Tensor((640,), "float32"),
+ ):
+ with R.dataflow():
+ lv0 = R.matmul(x1, w0)
+ lv3 = R.matmul(x2, w3)
+ lv1 = R.matmul(x1, w1)
+ lv5 = R.add(lv3, b0)
+ lv2 = R.matmul(x1, w2)
+ lv4 = R.matmul(x2, w4)
+ lv6 = R.add(lv4, b1)
+ out = (lv0, lv1, lv2, lv5, lv6)
+ R.output(out)
+ return out
- mod = CombineParallelMatmul()(multiple_combine)
+ after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
- @R.function
- def expected1(
+ @R.function(private=True)
+ def expected(
x1: R.Tensor((2, 1024, 640), dtype="float32"),
x2: R.Tensor((2, 1024, 640), dtype="float32"),
w0: R.Tensor((640, 640), dtype="float32"),
@@ -458,8 +453,7 @@ def test_multiple_combine():
w4: R.Tensor((640, 640), dtype="float32"),
b0: R.Tensor((640,), dtype="float32"),
b1: R.Tensor((640,), dtype="float32"),
- ) -> R.Tensor:
- R.func_attr({"global_symbol": "main"})
+ ):
with R.dataflow():
lv = R.concat((w0, w1, w2), axis=1)
lv1 = R.matmul(x1, lv, out_dtype="float32")
@@ -478,36 +472,34 @@ def test_multiple_combine():
R.output(out)
return out
- tvm.ir.assert_structural_equal(mod["main"], expected1)
+ tvm.ir.assert_structural_equal(after, expected)
def test_check():
- @tvm.script.ir_module
- class multiple_combine:
- @R.function
- def main(
- x1: R.Tensor((2, 1024, 640), "float32"),
- x2: R.Tensor((2, 1024, 640), "float32"),
- w0: R.Tensor((640, 640), "float32"),
- w1: R.Tensor((640, 640), "float32"),
- w2: R.Tensor((640, 640), "float32"),
- w3: R.Tensor((640, 640), "float32"),
- w4: R.Tensor((640, 640), "float32"),
- ) -> R.Tensor:
- with R.dataflow():
- lv0 = R.matmul(x1, w0)
- lv1 = R.matmul(x1, w1)
- lv2 = R.matmul(x1, w2)
- lv3 = R.matmul(x2, w3)
- lv4 = R.matmul(x2, w4)
- out = (lv0, lv1, lv2, lv3, lv4)
- R.output(out)
- return out
+ @R.function(private=True)
+ def before(
+ x1: R.Tensor((2, 1024, 640), "float32"),
+ x2: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((640, 640), "float32"),
+ w3: R.Tensor((640, 640), "float32"),
+ w4: R.Tensor((640, 640), "float32"),
+ ):
+ with R.dataflow():
+ lv0 = R.matmul(x1, w0)
+ lv1 = R.matmul(x1, w1)
+ lv2 = R.matmul(x1, w2)
+ lv3 = R.matmul(x2, w3)
+ lv4 = R.matmul(x2, w4)
+ out = (lv0, lv1, lv2, lv3, lv4)
+ R.output(out)
+ return out
check = lambda *inp: len(inp[1]) > 2 # Ignore branches with two matmuls
- mod = CombineParallelMatmul(check)(multiple_combine)
+ after =
CombineParallelMatmul(check)(tvm.IRModule.from_expr(before))["main"]
- @R.function
+ @R.function(private=True)
def expected(
x1: R.Tensor((2, 1024, 640), dtype="float32"),
x2: R.Tensor((2, 1024, 640), dtype="float32"),
@@ -516,8 +508,7 @@ def test_check():
w2: R.Tensor((640, 640), dtype="float32"),
w3: R.Tensor((640, 640), dtype="float32"),
w4: R.Tensor((640, 640), dtype="float32"),
- ) -> R.Tensor:
- R.func_attr({"global_symbol": "main"})
+ ):
with R.dataflow():
lv = R.concat((w0, w1, w2), axis=1)
lv1 = R.matmul(x1, lv, out_dtype="float32")
@@ -531,7 +522,7 @@ def test_check():
R.output(out)
return out
- tvm.ir.assert_structural_equal(mod["main"], expected)
+ tvm.ir.assert_structural_equal(after, expected)
if __name__ == "__main__":
diff --git a/tests/python/relax/test_transform_convert_layout.py
b/tests/python/relax/test_transform_convert_layout.py
index 570a53b48f..417a5519e0 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -627,9 +627,7 @@ def test_conv2d_sum_negative_dims():
@I.ir_module
class Input:
@R.function
- def main(
- x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3),
"float32")
- ) -> R.Tensor(None, "float32", ndim=4):
+ def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3,
3), "float32")):
with R.dataflow():
gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w,
out_dtype="float32")
gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[-2, -1])
@@ -641,7 +639,7 @@ def test_conv2d_sum_negative_dims():
@R.function
def main(
x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3,
3, 3), dtype="float32")
- ) -> R.Tensor(None, dtype="float32", ndim=4):
+ ):
with R.dataflow():
lv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
lv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py
b/tests/python/relax/test_transform_dead_code_elimination.py
index 12a3de6acb..2559eed34e 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -33,7 +33,7 @@ def test_simple():
x: R.Tensor((2, 3, 28, 28), dtype="float32"),
w: R.Tensor((4, 3, 3, 3), dtype="float32"),
bias: R.Tensor((26, 26), dtype="float32"),
- ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+ ):
# block 0
with R.dataflow():
gv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
@@ -64,7 +64,7 @@ def test_simple():
x: R.Tensor((2, 3, 28, 28), dtype="float32"),
w: R.Tensor((4, 3, 3, 3), dtype="float32"),
bias: R.Tensor((26, 26), dtype="float32"),
- ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+ ):
with R.dataflow():
gv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
gv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
@@ -94,7 +94,7 @@ def test_2block():
x: R.Tensor((2, 3, 28, 28), dtype="float32"),
w: R.Tensor((4, 3, 3, 3), dtype="float32"),
bias: R.Tensor((26, 26), dtype="float32"),
- ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ ):
# block 0
with R.dataflow():
gv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
@@ -126,7 +126,7 @@ def test_2block():
x: R.Tensor((2, 3, 28, 28), dtype="float32"),
w: R.Tensor((4, 3, 3, 3), dtype="float32"),
bias: R.Tensor((26, 26), dtype="float32"),
- ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ ):
with R.dataflow():
gv: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(x, axes=[0, 2, 3, 1])
gv1: R.Tensor((4, 3, 3, 3), dtype="float32") =
R.permute_dims(w, axes=[0, 2, 3, 1])
@@ -340,7 +340,7 @@ def test_unused_dfb():
def main(
x: R.Tensor((2, 3, 28, 28), dtype="float32"),
w: R.Tensor((4, 3, 3, 3), dtype="float32"),
- ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ ):
# block 0
with R.dataflow():
lv0: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(
@@ -369,7 +369,7 @@ def test_unused_dfb():
def main(
x: R.Tensor((2, 3, 28, 28), dtype="float32"),
w: R.Tensor((4, 3, 3, 3), dtype="float32"),
- ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ ):
# block 0
with R.dataflow():
lv0: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(
@@ -394,7 +394,7 @@ def test_unused_dfb2():
def main(
x: R.Tensor((2, 3, 28, 28), dtype="float32"),
w: R.Tensor((4, 3, 3, 3), dtype="float32"),
- ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ ):
# dead block
with R.dataflow():
lv0: R.Tensor((2, 28, 28, 3), dtype="float32") =
R.permute_dims(
@@ -428,7 +428,7 @@ def test_unused_dfb2():
def main(
x: R.Tensor((2, 3, 28, 28), dtype="float32"),
w: R.Tensor((4, 3, 3, 3), dtype="float32"),
- ) -> R.Tensor((2, 4, 26, 26), dtype="float16"):
+ ):
gv_x = R.astype(x, dtype="float16")
gv_w = R.astype(x, dtype="float16")
diff --git
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 2f1f8bb53b..0c84daa572 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -132,14 +132,14 @@ def test_strided_slice_no_strides():
@tvm.script.ir_module
class StridedSlice:
@R.function
- def main(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor((4, 9,
10, 3), "float32"):
+ def main(x: R.Tensor((8, 9, 10, 10), "float32")) :
gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x,
axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4])
return gv
@tvm.script.ir_module
class Expected:
@R.function
- def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")) -> R.Tensor((4,
9, 10, 3), dtype="float32"):
+ def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")):
gv = R.call_tir(Expected.strided_slice, (x,),
out_sinfo=R.Tensor((7, 9, 10, 2), dtype="float32"))
return gv
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 09cad024df..6c9ca9d980 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -913,14 +913,14 @@ def test_squeeze_no_axis():
@tvm.script.ir_module
class Squeeze:
@R.function
- def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2,
3, 1, 4), "float32"):
+ def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) :
gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x)
return gv
@tvm.script.ir_module
class Expected:
@R.function
- def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2,
3, 1, 4), "float32"):
+ def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) :
gv = R.call_tir(Expected.squeeze, (x,), R.Tensor((2, 3, 4),
dtype="float32"))
return gv
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index c2acd52105..b713712761 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1725,7 +1725,7 @@ def test_cross_entropy_with_logits():
@tvm.script.ir_module
class Expected:
@R.function
- def main(x: R.Tensor((3,), dtype="float32"), y: R.Tensor((3,),
dtype="float32")) -> R.Tensor(dtype="float32", ndim=2):
+ def main(x: R.Tensor((3,), dtype="float32"), y: R.Tensor((3,),
dtype="float32")):
gv = R.call_tir(Expected.cross_entropy_with_logits, (x, y),
R.Tensor((), dtype="float32"))
return gv
@@ -1771,7 +1771,7 @@ def test_cross_entropy_with_logits_batch():
@tvm.script.ir_module
class Expected:
@R.function
- def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3),
dtype="float32")) -> R.Tensor(dtype="float32", ndim=2):
+ def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3),
dtype="float32")):
gv = R.call_tir(Expected.cross_entropy_with_logits, (x, y),
R.Tensor((), dtype="float32"))
return gv
@@ -1825,7 +1825,7 @@ def test_cross_entropy_with_logits_batch_symbolic():
@tvm.script.ir_module
class Expected:
@R.function
- def main(x: R.Tensor(("n", "m"), dtype="float32"), y: R.Tensor(("n",
"m"), dtype="float32")) -> R.Tensor(dtype="float32", ndim=2):
+ def main(x: R.Tensor(("n", "m"), dtype="float32"), y: R.Tensor(("n",
"m"), dtype="float32")):
gv = R.call_tir(Expected.cross_entropy_with_logits, (x, y),
R.Tensor((), dtype="float32"))
return gv
@@ -2326,7 +2326,7 @@ def test_batch_norm_symbolic():
T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0]
@R.function
- def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma:
R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"),
moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",),
dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"),
R.Tensor(("c",), dtype="float32"), R.Tensor(("c",), dtype="float32")):
+ def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma:
R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"),
moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",),
dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"),
R.Tensor(("T.max(c,h)",), dtype="float32"), R.Tensor(("T.max(c,h)",),
dtype="float32")):
n = T.int64()
h = T.int64()
w = T.int64()
diff --git
a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 5b81885790..979f3f113b 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -1056,7 +1056,7 @@ def test_variance_no_keepdims():
T_divide[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] *
T.float32(0.10000000000000001)
@R.function
- def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((1,
3, 4, 1), dtype="float32"):
+ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3,
4), dtype="float32"):
cls = Expected
gv = R.call_tir(cls.variance, (x,), out_sinfo=R.Tensor((3, 4),
dtype="float32"))
return gv
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py
b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 106147ef9a..f69d1c1390 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -38,7 +38,7 @@ def test_rewrite_cuda_graph():
@R.function
- def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2,4),
dtype="float32"):
# force_pure is expected because purity checking should be
disabled before this pass
R.func_attr({"relax.force_pure": True})
cls = Before
@@ -107,7 +107,7 @@ def test_rewrite_cuda_graph():
return gv
@R.function
- def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2,4),
dtype="float32"):
# this comes after RemovePurityChecking, so we expect purity to be
forced
R.func_attr({"relax.force_pure": True})
cls = Expected
@@ -258,7 +258,7 @@ def test_vm_builtin():
@R.function
- def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2,4),
dtype="float32"):
# force_pure is expected because purity checking should be
disabled before this pass
R.func_attr({"relax.force_pure": True})
cls = Before
@@ -314,7 +314,7 @@ def test_vm_builtin():
return gv
@R.function
- def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2,4),
dtype="float32"):
R.func_attr({"relax.force_pure": True})
cls = Expected
gv: R.Tuple(R.Object, R.Object) =
R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc",
(cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object,
R.Object),))
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py
b/tests/python/relax/test_transform_to_mixed_precision.py
index cb179a8c25..4ddf47b462 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -625,7 +625,7 @@ def test_conv2d_softmax():
@R.function
def main(
x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 3,
3, 3), dtype="float32")
- ) -> R.Tensor((2, 3, 26, 26), dtype="float32"):
+ ) -> R.Tensor((2, 3, 28, 28), dtype="float32"):
with R.dataflow():
lv: R.Tensor((3, 3, 3, 3), dtype="float16") = R.astype(w,
dtype="float16")
lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x,
dtype="float16")
@@ -653,7 +653,7 @@ def test_conv2d_softmax():
@R.function
def main(
x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 3,
3, 3), dtype="float32")
- ) -> R.Tensor((2, 3, 26, 26), dtype="float32"):
+ ) -> R.Tensor((2, 3, 28, 28), dtype="float32"):
with R.dataflow():
lv: R.Tensor((3, 3, 3, 3), dtype="float16") = R.astype(w,
dtype="float16")
lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x,
dtype="float16")
@@ -872,7 +872,7 @@ def test_tuple_get():
x: R.Tensor((1, 4, 64, 64), dtype="float32"),
w: R.Tensor((512, 4, 3, 3), dtype="float32"),
bias: R.Tensor((512, 1, 1), dtype="float32"),
- ) -> R.Tensor((1, 256, 64, 64), dtype="float32"):
+ ):
with R.dataflow():
lv = R.astype(x, dtype="float16")
lv1 = R.astype(w, dtype="float16")
@@ -925,7 +925,7 @@ def test_conv2d_bias_fp32():
x: R.Tensor((1, 4, 64, 64), dtype="float32"),
w: R.Tensor((512, 4, 3, 3), dtype="float32"),
bias: R.Tensor((512,), dtype="float32"),
- ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+ ) -> R.Tensor((1, 512, 62, 62), dtype="float32"):
with R.dataflow():
lv: R.Tensor((1, 4, 64, 64), dtype="float16") = R.astype(x,
dtype="float16")
lv1: R.Tensor((512, 4, 3, 3), dtype="float16") = R.astype(w,
dtype="float16")
@@ -952,7 +952,7 @@ def test_conv2d_bias_fp32():
x: R.Tensor((1, 4, 64, 64), dtype="float32"),
w: R.Tensor((512, 4, 3, 3), dtype="float32"),
bias: R.Tensor((512,), dtype="float32"),
- ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+ ) -> R.Tensor((1, 512, 62, 62), dtype="float32"):
with R.dataflow():
lv: R.Tensor((1, 4, 64, 64), dtype="float16") = R.astype(x,
dtype="float16")
lv1: R.Tensor((512, 4, 3, 3), dtype="float16") = R.astype(w,
dtype="float16")
@@ -1020,7 +1020,7 @@ def test_convert_sig():
x: R.Tensor((1, 4, 64, 64), dtype="float32"),
w: R.Tensor((512, 4, 3, 3), dtype="float16"),
bias: R.Tensor((512,), dtype="float16"),
- ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+ ) -> R.Tensor((1, 512, 62, 62), dtype="float32"):
with R.dataflow():
lv = R.astype(x, dtype="float16")
lv142 = R.nn.conv2d(