This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 914590a22d [Relax][PyTorch] Add tests for all the dtypes supported in the PyTorch frontend (#17926)
914590a22d is described below
commit 914590a22d4d43a7c23fb5d6447dd55ca4895ee7
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Thu May 8 11:45:49 2025 +0900
[Relax][PyTorch] Add tests for all the dtypes supported in the PyTorch frontend (#17926)
refactor dtype tests to support multiple data types in pytorch frontends
---
.../relax/test_frontend_from_exported_program.py | 34 ++++++++++++++--------
tests/python/relax/test_frontend_from_fx.py | 28 ++++++++++++------
2 files changed, 41 insertions(+), 21 deletions(-)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ab3826b935..c375992dca 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5116,31 +5116,41 @@ def test_linspace():
verify_model(Linspace(), example_args, {}, Expected)
-def test_bfloat16():
- # TODO(mshr-h): Add tests for all the dtypes supported in fx frontend
[email protected](
+ "torch_dtype, relax_dtype",
+ [
+ (torch.float32, "float32"),
+ (torch.float16, "float16"),
+ (torch.bfloat16, "bfloat16"),
+ (torch.int64, "int64"),
+ (torch.int32, "int32"),
+ (torch.bool, "bool"),
+ ],
+)
+def test_dtypes(torch_dtype, relax_dtype):
example_args = (
- torch.randn(10, 10, dtype=torch.bfloat16),
- torch.randn(10, 10, dtype=torch.bfloat16),
+ torch.randint(0, 10, (10, 10)).to(torch_dtype),
+ torch.randint(0, 10, (10, 10)).to(torch_dtype),
)
- class BFloat16Model(Module):
+ class Model(Module):
def forward(self, lhs: torch.Tensor, rhs: torch.Tensor):
return torch.ops.aten.add(lhs, rhs)
@tvm.script.ir_module
- class expected:
+ class Expected:
@R.function
def main(
- lhs: R.Tensor((10, 10), dtype="bfloat16"),
- rhs: R.Tensor((10, 10), dtype="bfloat16"),
- ) -> R.Tuple(R.Tensor((10, 10), dtype="bfloat16")):
+ lhs: R.Tensor((10, 10), dtype=relax_dtype),
+ rhs: R.Tensor((10, 10), dtype=relax_dtype),
+ ) -> R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)):
with R.dataflow():
- lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs)
- gv: R.Tuple(R.Tensor((10, 10), dtype="bfloat16")) = (lv,)
+ lv: R.Tensor((10, 10), dtype=relax_dtype) = relax.op.add(lhs, rhs)
+ gv: R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)) = (lv,)
R.output(gv)
return gv
- verify_model(BFloat16Model(), example_args, {}, expected)
+ verify_model(Model(), example_args, {}, Expected)
if __name__ == "__main__":
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 705181a024..643372750b 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5495,9 +5495,19 @@ def test_norm():
verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)
-def test_bfloat16():
- # TODO(mshr-h): Add tests for all the dtypes supported in EP frontend
- class BFloat16Model(Module):
[email protected](
+ "torch_dtype, relax_dtype",
+ [
+ (torch.float32, "float32"),
+ (torch.float16, "float16"),
+ (torch.bfloat16, "bfloat16"),
+ (torch.int64, "int64"),
+ (torch.int32, "int32"),
+ (torch.bool, "bool"),
+ ],
+)
+def test_dtypes(torch_dtype, relax_dtype):
+ class Model(Module):
def forward(self, lhs: torch.Tensor, rhs: torch.Tensor):
return torch.ops.aten.add(lhs, rhs)
@@ -5505,16 +5515,16 @@ def test_bfloat16():
class Expected:
@R.function
def main(
- lhs: R.Tensor((10, 10), dtype="bfloat16"),
- rhs: R.Tensor((10, 10), dtype="bfloat16"),
- ) -> R.Tensor((10, 10), dtype="bfloat16"):
+ lhs: R.Tensor((10, 10), dtype=relax_dtype),
+ rhs: R.Tensor((10, 10), dtype=relax_dtype),
+ ) -> R.Tensor((10, 10), dtype=relax_dtype):
with R.dataflow():
- lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs)
- gv: R.Tensor((10, 10), dtype="bfloat16") = lv
+ lv: R.Tensor((10, 10), dtype=relax_dtype) = relax.op.add(lhs, rhs)
+ gv: R.Tensor((10, 10), dtype=relax_dtype) = lv
R.output(gv)
return gv
- verify_model(BFloat16Model(), [([10, 10], "bfloat16"), ([10, 10], "bfloat16")], {}, Expected)
+ verify_model(Model(), [([10, 10], torch_dtype), ([10, 10], torch_dtype)], {}, Expected)
def test_eye():