This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b78ead8d05 [Unity][Transform] Fix AMP tests (#14360)
b78ead8d05 is described below
commit b78ead8d055cac8b990bd74de783e34001232931
Author: Bohan Hou <[email protected]>
AuthorDate: Tue Mar 21 16:51:54 2023 -0400
[Unity][Transform] Fix AMP tests (#14360)
tests
---
.../relax/test_transform_to_mixed_precision.py | 43 +++++++++++++++++++++-
1 file changed, 42 insertions(+), 1 deletion(-)
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py
index 7b542243dc..6b699b5165 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -293,7 +293,48 @@ def test_conv2d_relu_conv2d():
R.output(gv3)
return gv3
- _assert_test(Input, Expected)
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ w2: R.Tensor((4, 4, 3, 3), dtype="float32"),
+ ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x,
dtype="float16")
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w,
dtype="float16")
+ lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w2,
dtype="float16")
+ gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv)
+ lv3: R.Tensor((2, 4, 24, 24), dtype="float16") = R.nn.conv2d(
+ gv2,
+ lv2,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ gv3: R.Tensor((2, 4, 24, 24), dtype="float32") = R.astype(lv3,
dtype="float32")
+ R.output(gv3)
+ return gv3
+
+ _assert_test(Input, Expected, Expected2)
def test_gemm_add_silu():