This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 914590a22d [Relax][PyTorch] Add tests for all the dtypes supported in the PyTorch frontend (#17926)
914590a22d is described below

commit 914590a22d4d43a7c23fb5d6447dd55ca4895ee7
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Thu May 8 11:45:49 2025 +0900

    [Relax][PyTorch] Add tests for all the dtypes supported in the PyTorch frontend (#17926)
    
    refactor dtype tests to support multiple data types in pytorch frontends
---
 .../relax/test_frontend_from_exported_program.py   | 34 ++++++++++++++--------
 tests/python/relax/test_frontend_from_fx.py        | 28 ++++++++++++------
 2 files changed, 41 insertions(+), 21 deletions(-)

diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ab3826b935..c375992dca 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5116,31 +5116,41 @@ def test_linspace():
     verify_model(Linspace(), example_args, {}, Expected)
 
 
-def test_bfloat16():
-    # TODO(mshr-h): Add tests for all the dtypes supported in fx frontend
+@pytest.mark.parametrize(
+    "torch_dtype, relax_dtype",
+    [
+        (torch.float32, "float32"),
+        (torch.float16, "float16"),
+        (torch.bfloat16, "bfloat16"),
+        (torch.int64, "int64"),
+        (torch.int32, "int32"),
+        (torch.bool, "bool"),
+    ],
+)
+def test_dtypes(torch_dtype, relax_dtype):
     example_args = (
-        torch.randn(10, 10, dtype=torch.bfloat16),
-        torch.randn(10, 10, dtype=torch.bfloat16),
+        torch.randint(0, 10, (10, 10)).to(torch_dtype),
+        torch.randint(0, 10, (10, 10)).to(torch_dtype),
     )
 
-    class BFloat16Model(Module):
+    class Model(Module):
         def forward(self, lhs: torch.Tensor, rhs: torch.Tensor):
             return torch.ops.aten.add(lhs, rhs)
 
     @tvm.script.ir_module
-    class expected:
+    class Expected:
         @R.function
         def main(
-            lhs: R.Tensor((10, 10), dtype="bfloat16"),
-            rhs: R.Tensor((10, 10), dtype="bfloat16"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="bfloat16")):
+            lhs: R.Tensor((10, 10), dtype=relax_dtype),
+            rhs: R.Tensor((10, 10), dtype=relax_dtype),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)):
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="bfloat16")) = (lv,)
+                lv: R.Tensor((10, 10), dtype=relax_dtype) = relax.op.add(lhs, rhs)
+                gv: R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)) = (lv,)
                 R.output(gv)
             return gv
 
-    verify_model(BFloat16Model(), example_args, {}, expected)
+    verify_model(Model(), example_args, {}, Expected)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 705181a024..643372750b 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5495,9 +5495,19 @@ def test_norm():
         verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)
 
 
-def test_bfloat16():
-    # TODO(mshr-h): Add tests for all the dtypes supported in EP frontend
-    class BFloat16Model(Module):
+@pytest.mark.parametrize(
+    "torch_dtype, relax_dtype",
+    [
+        (torch.float32, "float32"),
+        (torch.float16, "float16"),
+        (torch.bfloat16, "bfloat16"),
+        (torch.int64, "int64"),
+        (torch.int32, "int32"),
+        (torch.bool, "bool"),
+    ],
+)
+def test_dtypes(torch_dtype, relax_dtype):
+    class Model(Module):
         def forward(self, lhs: torch.Tensor, rhs: torch.Tensor):
             return torch.ops.aten.add(lhs, rhs)
 
@@ -5505,16 +5515,16 @@ def test_bfloat16():
     class Expected:
         @R.function
         def main(
-            lhs: R.Tensor((10, 10), dtype="bfloat16"),
-            rhs: R.Tensor((10, 10), dtype="bfloat16"),
-        ) -> R.Tensor((10, 10), dtype="bfloat16"):
+            lhs: R.Tensor((10, 10), dtype=relax_dtype),
+            rhs: R.Tensor((10, 10), dtype=relax_dtype),
+        ) -> R.Tensor((10, 10), dtype=relax_dtype):
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs)
-                gv: R.Tensor((10, 10), dtype="bfloat16") = lv
+                lv: R.Tensor((10, 10), dtype=relax_dtype) = relax.op.add(lhs, rhs)
+                gv: R.Tensor((10, 10), dtype=relax_dtype) = lv
                 R.output(gv)
             return gv
 
-    verify_model(BFloat16Model(), [([10, 10], "bfloat16"), ([10, 10], "bfloat16")], {}, Expected)
+    verify_model(Model(), [([10, 10], torch_dtype), ([10, 10], torch_dtype)], {}, Expected)
 
 
 def test_eye():

Reply via email to