This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 354c5f1008 [Unity] [Bugfix] Fix bug in interpolate operator's default
mode parameter in PyTorch frontend (#15933)
354c5f1008 is described below
commit 354c5f100832733d809a37daec3fec2a4c115d06
Author: Thrsu <[email protected]>
AuthorDate: Mon Oct 16 20:29:52 2023 +0800
[Unity] [Bugfix] Fix bug in interpolate operator's default mode parameter
in PyTorch frontend (#15933)
* Fix wrong attribute name of interpolate
* Add regression test case.
* Reformat test_frontend_from_fx.py
* Update test_frontend_from_fx.py
---
python/tvm/relax/frontend/torch/fx_translator.py | 2 +-
tests/python/relax/test_frontend_from_fx.py | 37 ++++++++++++++++++++++++
2 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 6062280b9d..7fa0358dc6 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1094,7 +1094,7 @@ class TorchFXImporter:
method = (
node.args[3]
if len(node.args) > 3
- else (node.kwargs["method"] if "method" in node.kwargs else "nearest")
+ else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest")
)
align_corners = (
node.args[4]
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index a1acff4974..d7ad0d83dd 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2376,6 +2376,43 @@ def test_interpolate():
verify_model(Interpolate(), input_info, {}, expected1)
+ class Interpolate2(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(
+ input,
+ size=None,
+ scale_factor=2.0,
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 20, 20), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 20, 20), dtype="float32") = R.image.resize2d(
+ input_1,
+ (20, 20),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000],
+ layout="NCHW",
+ method="linear",
+ coordinate_transformation_mode="half_pixel",
+ rounding_method="round",
+ cubic_alpha=-0.5,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 3, 20, 20), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Interpolate2(), input_info, {}, expected2)
+
def test_addmm():
input_info = [