This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b78ead8d05 [Unity][Transform] Fix AMP tests (#14360)
b78ead8d05 is described below

commit b78ead8d055cac8b990bd74de783e34001232931
Author: Bohan Hou <[email protected]>
AuthorDate: Tue Mar 21 16:51:54 2023 -0400

    [Unity][Transform] Fix AMP tests (#14360)
    
    tests
---
 .../relax/test_transform_to_mixed_precision.py     | 43 +++++++++++++++++++++-
 1 file changed, 42 insertions(+), 1 deletion(-)

diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py
index 7b542243dc..6b699b5165 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -293,7 +293,48 @@ def test_conv2d_relu_conv2d():
                 R.output(gv3)
             return gv3
 
-    _assert_test(Input, Expected)
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            w2: R.Tensor((4, 4, 3, 3), dtype="float32"),
+        ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16")
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16")
+                lv2: R.Tensor((4, 4, 3, 3), dtype="float16") = R.astype(w2, dtype="float16")
+                gv: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float16",
+                )
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float16") = R.nn.relu(gv)
+                lv3: R.Tensor((2, 4, 24, 24), dtype="float16") = R.nn.conv2d(
+                    gv2,
+                    lv2,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float16",
+                )
+                gv3: R.Tensor((2, 4, 24, 24), dtype="float32") = R.astype(lv3, dtype="float32")
+                R.output(gv3)
+            return gv3
+
+    _assert_test(Input, Expected, Expected2)
 
 
 def test_gemm_add_silu():

Reply via email to the mailing list.