This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ecb689e51f [Unity] [Bugfix] Fix TypeError in interpolate caused by scale_factor as tuple (#15935)
ecb689e51f is described below

commit ecb689e51f798a3de080510d47663a21a8c0ebda
Author: Thrsu <[email protected]>
AuthorDate: Tue Oct 17 09:23:36 2023 +0800

    [Unity] [Bugfix] Fix TypeError in interpolate caused by scale_factor as tuple (#15935)
    
    * Fix interpolate type error
    
    * Add regression test case
    
    * Reformat fx_translator.py
---
 python/tvm/relax/frontend/torch/fx_translator.py |  8 ++++-
 tests/python/relax/test_frontend_from_fx.py      | 37 ++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 7fa0358dc6..012c8328bb 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1122,7 +1122,13 @@ class TorchFXImporter:
         if size is None:
             shape = self.shape_of(data)
             assert isinstance(shape, relax.ShapeExpr)
-            size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape)))
+            if isinstance(scale_factor, tuple):
+                assert len(scale_factor) == len(shape) - 2
+                size = tuple(
+                    int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape))
+                )
+            else:
+                size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape)))
 
         if method.startswith("nearest"):
             method = "nearest_neighbor"
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index d7ad0d83dd..937214dca6 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2413,6 +2413,43 @@ def test_interpolate():
 
     verify_model(Interpolate2(), input_info, {}, expected2)
 
+    class Interpolate3(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input,
+                size=None,
+                scale_factor=(2.0, 1.0),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 20, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 20, 10), dtype="float32") = R.image.resize2d(
+                    input_1,
+                    (20, 10),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000],
+                    layout="NCHW",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.5,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 3, 20, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Interpolate3(), input_info, {}, expected3)
+
 
 def test_addmm():
     input_info = [

Reply via email to