This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c6cf403590 [Unity][Transform] AMP out_dtype=float16 testcases (#14358)
c6cf403590 is described below
commit c6cf403590cac886a8caf732d9ac35ebaa20f26f
Author: Bohan Hou <[email protected]>
AuthorDate: Tue Mar 21 12:56:59 2023 -0400
[Unity][Transform] AMP out_dtype=float16 testcases (#14358)
Add test cases covering out_dtype="float16".
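
For context, a minimal sketch of what the updated _assert_test helper below
does (Input, Expected, and Expected2 stand for the modules defined in each
test; the local variable names here are illustrative):

    import tvm
    from tvm.relax.transform import ToMixedPrecision

    # Default pass configuration: compared against Expected.
    mod = ToMixedPrecision()(Input)
    tvm.ir.assert_structural_equal(mod, Expected)

    # With out_dtype="float16", intermediate op outputs stay in float16 and
    # the result is cast back to float32 only at the graph output: compared
    # against Expected2.
    mod_fp16 = ToMixedPrecision(out_dtype="float16")(Input)
    tvm.ir.assert_structural_equal(mod_fp16, Expected2)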
---
.../relax/test_transform_to_mixed_precision.py | 282 ++++++++++++++++++++-
1 file changed, 273 insertions(+), 9 deletions(-)
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py
index b9409bff52..7b542243dc 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -23,9 +23,12 @@ from tvm.relax.transform import ToMixedPrecision
from tvm.script.parser import ir as I, relax as R
-def _assert_test(input, expected):
+def _assert_test(input, expected, expected2):
mod = ToMixedPrecision()(input)
tvm.ir.assert_structural_equal(mod, expected)
+ mod = ToMixedPrecision(out_dtype="float16")(input)
+ print(mod.script())
+ tvm.ir.assert_structural_equal(mod, expected2)
def test_conv2d():
@@ -64,7 +67,32 @@ def test_conv2d():
R.output(gv)
return gv
- _assert_test(Input, Expected)
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16")
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16")
+ lv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv2, dtype="float32")
+ R.output(gv)
+ return gv
+
+ _assert_test(Input, Expected, Expected2)
def test_conv2d_relu():
@@ -107,7 +135,33 @@ def test_conv2d_relu():
R.output(gv)
return gv
- _assert_test(Input, Expected)
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16")
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16")
+ lv_1: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ lv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(lv_1)
+ gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv2, dtype="float32")
+ R.output(gv)
+ return gv
+
+ _assert_test(Input, Expected, Expected2)
def test_relu_conv2d_relu():
@@ -152,7 +206,34 @@ def test_relu_conv2d_relu():
R.output(gv2)
return gv2
- _assert_test(Input, Expected)
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+ ) -> R.Tensor((2, 4, 26, 26), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16")
+ x0: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.relu(x)
+ lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x0, dtype="float16")
+ gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d(
+ lv1,
+ lv,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ lv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv)
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv2, dtype="float32")
+ R.output(gv2)
+ return gv2
+
+ _assert_test(Input, Expected, Expected2)
def test_conv2d_relu_conv2d():
@@ -250,7 +331,25 @@ def test_gemm_add_silu():
R.output(gv2)
return gv2
- _assert_test(Input, Expected)
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ x: R.Tensor((2, 320), dtype="float32"),
+ w1: R.Tensor((320, 1280), dtype="float32"),
+ w2: R.Tensor((2, 1280), dtype="float32"),
+ ) -> R.Tensor((2, 1280), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 320), dtype="float16") = R.astype(x, dtype="float16")
+ lv1: R.Tensor((320, 1280), dtype="float16") = R.astype(w1, dtype="float16")
+ gv0: R.Tensor((2, 1280), dtype="float16") = R.matmul(lv, lv1, out_dtype="float16")
+ lv2: R.Tensor((2, 1280), dtype="float32") = R.astype(gv0, dtype="float32")
+ gv1: R.Tensor((2, 1280), dtype="float32") = R.add(lv2, w2)
+ gv2: R.Tensor((2, 1280), dtype="float32") = R.nn.silu(gv1)
+ R.output(gv2)
+ return gv2
+
+ _assert_test(Input, Expected, Expected2)
def test_tuple():
@@ -342,7 +441,75 @@ def test_tuple():
R.output(gv7)
return gv7
- _assert_test(Input, Expected)
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ w_2: R.Tensor((4, 4, 3, 3), dtype="float32"),
+ ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16")
+ lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16")
+ lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w_2, dtype="float16")
+ gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ gv3: R.Tuple(
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ ) = (gv, gv2)
+ gv4: R.Tuple(
+ R.Tuple(
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ ),
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ ) = (gv3, gv2)
+ gv5: R.Tuple(
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ R.Tensor((2, 4, 26, 26), dtype="float16"),
+ ) = gv4[0]
+ gv6: R.Tensor((2, 4, 26, 26), dtype="float16") = gv5[0]
+ lv3: R.Tensor((2, 4, 24, 24), dtype="float16") = R.nn.conv2d(
+ gv6,
+ lv2,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ gv7: R.Tensor((2, 4, 24, 24), dtype="float32") = R.astype(lv3, dtype="float32")
+ R.output(gv7)
+ return gv7
+
+ _assert_test(Input, Expected, Expected2)
def test_concat_matmul():
@@ -376,7 +543,24 @@ def test_concat_matmul():
R.output(lv14)
return lv14
- _assert_test(Input, Expected)
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ lv10: R.Tensor((2, 160), dtype="float32"),
+ lv12: R.Tensor((2, 160), dtype="float32"),
+ w: R.Tensor((320, 1280), dtype="float32"),
+ ) -> R.Tensor((2, 1280), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((320, 1280), dtype="float16") = R.astype(w, dtype="float16")
+ lv13: R.Tensor((2, 320), dtype="float32") = R.concat((lv10, lv12), axis=-1)
+ lv1: R.Tensor((2, 320), dtype="float16") = R.astype(lv13, dtype="float16")
+ lv2: R.Tensor((2, 1280), dtype="float16") = R.matmul(lv1, lv, out_dtype="float16")
+ lv14: R.Tensor((2, 1280), dtype="float32") = R.astype(lv2, dtype="float32")
+ R.output(lv14)
+ return lv14
+
+ _assert_test(Input, Expected, Expected2)
def test_conv2d_softmax():
@@ -421,7 +605,34 @@ def test_conv2d_softmax():
R.output(gv2)
return gv2
- _assert_test(Input, Expected)
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 3, 3, 3), dtype="float32")
+ ) -> R.Tensor((2, 3, 26, 26), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((3, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16")
+ lv1: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16")
+ gv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.nn.conv2d(
+ lv1,
+ lv,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ gv1: R.Tensor((2, 3, 28, 28), dtype="float32") = R.nn.softmax(x, axis=1)
+ lv2: R.Tensor((2, 3, 28, 28), dtype="float32") = R.astype(gv, dtype="float32")
+ gv2: R.Tensor((2, 3, 28, 28), dtype="float32") = R.add(lv2, gv1)
+ R.output(gv2)
+ return gv2
+
+ _assert_test(Input, Expected, Expected2)
def test_conv2d_bias_conv2d():
@@ -524,6 +735,58 @@ def test_conv2d_bias_conv2d():
R.output(gv)
return gv
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ z: R.Tensor((1, 4, 64, 64), dtype="float32"),
+ w0: R.Tensor((512, 4, 3, 3), dtype="float16"),
+ w1: R.Tensor((512,), dtype="float16"),
+ w2: R.Tensor((4, 4, 1, 1), dtype="float16"),
+ w3: R.Tensor((4,), dtype="float16"),
+ ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 4, 64, 64), dtype="float16") = R.astype(z, dtype="float16")
+ lv_1: R.Tensor((512, 4, 3, 3), dtype="float16") = w0
+ lv1: R.Tensor((512,), dtype="float16") = w1
+ lv140: R.Tensor((4, 4, 1, 1), dtype="float16") = w2
+ lv141: R.Tensor((4,), dtype="float16") = w3
+ lv142: R.Tensor((1, 4, 64, 64), dtype="float16") = R.nn.conv2d(
+ lv,
+ lv140,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ lv143: R.Tensor((1, 4, 1, 1), dtype="float16") = R.reshape(
+ lv141, R.shape([1, 4, 1, 1])
+ )
+ lv144: R.Tensor((1, 4, 64, 64), dtype="float16") = R.add(lv142, lv143)
+ lv145: R.Tensor((1, 512, 64, 64), dtype="float16") = R.nn.conv2d(
+ lv144,
+ lv_1,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float16",
+ )
+ lv146: R.Tensor((1, 512, 1, 1), dtype="float16") = R.reshape(
+ lv1, R.shape([1, 512, 1, 1])
+ )
+ lv147: R.Tensor((1, 512, 64, 64), dtype="float16") = R.add(lv145, lv146)
+ gv: R.Tensor((1, 512, 64, 64), dtype="float32") = R.astype(lv147, dtype="float32")
+ R.output(gv)
+ return gv
+
binding = {
"w0": np.random.uniform(size=(512, 4, 3, 3)).astype("float16"),
"w1": np.random.uniform(size=(512,)).astype("float16"),
@@ -533,7 +796,8 @@ def test_conv2d_bias_conv2d():
binding = {k: tvm.nd.array(v) for k, v in binding.items()}
Input = relax.transform.BindParams("main", binding)(Input)
Expected = relax.transform.BindParams("main", binding)(Expected)
- _assert_test(Input, Expected)
+ Expected2 = relax.transform.BindParams("main", binding)(Expected2)
+ _assert_test(Input, Expected, Expected2)
if __name__ == "__main__":