This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ecb689e51f [Unity] [Bugfix] Fix TypeError in interpolate caused by scale_factor as tuple (#15935)
ecb689e51f is described below
commit ecb689e51f798a3de080510d47663a21a8c0ebda
Author: Thrsu <[email protected]>
AuthorDate: Tue Oct 17 09:23:36 2023 +0800
[Unity] [Bugfix] Fix TypeError in interpolate caused by scale_factor as tuple (#15935)
* Fix interpolate type error
* Add regression test case
* Reformat fx_translator.py
---
python/tvm/relax/frontend/torch/fx_translator.py | 8 ++++-
tests/python/relax/test_frontend_from_fx.py | 37 ++++++++++++++++++++++++
2 files changed, 44 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 7fa0358dc6..012c8328bb 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1122,7 +1122,13 @@ class TorchFXImporter:
if size is None:
shape = self.shape_of(data)
assert isinstance(shape, relax.ShapeExpr)
- size = tuple(int(shape[i].value * scale_factor) for i in range(2,
len(shape)))
+ if isinstance(scale_factor, tuple):
+ assert len(scale_factor) == len(shape) - 2
+ size = tuple(
+ int(shape[i].value * scale_factor[i - 2]) for i in
range(2, len(shape))
+ )
+ else:
+ size = tuple(int(shape[i].value * scale_factor) for i in
range(2, len(shape)))
if method.startswith("nearest"):
method = "nearest_neighbor"
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index d7ad0d83dd..937214dca6 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2413,6 +2413,43 @@ def test_interpolate():
verify_model(Interpolate2(), input_info, {}, expected2)
+ class Interpolate3(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(
+ input,
+ size=None,
+ scale_factor=(2.0, 1.0),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 20, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 20, 10), dtype="float32") =
R.image.resize2d(
+ input_1,
+ (20, 10),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000],
+ layout="NCHW",
+ method="linear",
+ coordinate_transformation_mode="half_pixel",
+ rounding_method="round",
+ cubic_alpha=-0.5,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 3, 20, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Interpolate3(), input_info, {}, expected3)
+
def test_addmm():
input_info = [