This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6cf49e6ee3 [Relax][PyTorch] Enhance scale_factor handling in interpolation (#18550)
6cf49e6ee3 is described below

commit 6cf49e6ee3ba5209766a7aeff4000c00e7c4f58c
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sat Dec 6 17:35:27 2025 +0800

    [Relax][PyTorch] Enhance scale_factor handling in interpolation (#18550)
    
    ## Why
    
    Fixes interpolation to support different scaling factors for height and
    width (e.g., scale_factor=[2.0, 3.0])
    
    ## How
    
    - Bug fix: stop extracting only the first element ([0]) of scale_factor
      lists, which applied the height factor to both spatial dimensions.
    - Pass the full value: forward the entire scale_factor (scalar or list)
      to the underlying implementation (_upsample_impl), which already
      handles both forms correctly.
---
 .../frontend/torch/exported_program_translator.py  | 18 ++++----
 .../relax/test_frontend_from_exported_program.py   | 51 ++++++++++++++++++++++
 2 files changed, 60 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2ec61796c3..641e16f599 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -337,11 +337,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             )
 
         else:
-            # TODO figure out why pytorch export passes a list such as
-            # [scale_factor,scale_factor] instead of just an int for
-            # scale_factor. Using first element for now
+            # PyTorch export passes scale_factor as either a scalar or a 
list/tuple
+            # (e.g., [2.0, 3.0] for different H and W scaling).
+            # Pass it as-is to _upsample_impl which handles both cases 
correctly.
             scale_factor = (
-                node.args[2][0] if len(node.args) > 2 else 
node.kwargs.get("scale_factor", 1)
+                node.args[2] if len(node.args) > 2 else 
node.kwargs.get("scale_factor", 1)
             )
             align_corners = (
                 node.args[3] if len(node.args) > 3 else 
node.kwargs.get("align_corners", None)
@@ -364,11 +364,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         if size is not None:
             scale_factor = None
         else:
-            scale_arg = node.args[3] if len(node.args) > 3 else 
node.kwargs.get("scale_factor", 1)
-            if isinstance(scale_arg, (list, tuple)):
-                scale_factor = scale_arg[0]
-            else:
-                scale_factor = scale_arg
+            # PyTorch export passes scale_factor as either a scalar or a 
list/tuple.
+            # Pass it as-is to _upsample_impl which handles both cases 
correctly.
+            scale_factor = (
+                node.args[3] if len(node.args) > 3 else 
node.kwargs.get("scale_factor", 1)
+            )
 
         return self._upsample_impl(
             x,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 010bd026a8..68567e1fc8 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -8542,5 +8542,56 @@ def test_grid_sample():
     verify_model(GridSample(), example_args, {}, expected)
 
 
+def test_upsample_nearest2d():
+    class UpsampleNearest2dScale(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, scale_factor=2.0, 
mode="nearest")
+
+    class UpsampleNearest2dSize(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, size=(20, 20), 
mode="nearest")
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    @tvm.script.ir_module
+    class expected_scale:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 20, 20), dtype="float32") = 
R.image.resize2d(
+                    input_1,
+                    size=(20, 20),
+                    layout="NCHW",
+                    method="nearest_neighbor",
+                    coordinate_transformation_mode="half_pixel",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_size:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 20, 20), dtype="float32") = 
R.image.resize2d(
+                    input_1,
+                    size=(20, 20),
+                    layout="NCHW",
+                    method="nearest_neighbor",
+                    coordinate_transformation_mode="half_pixel",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(UpsampleNearest2dScale(), example_args, {}, expected_scale)
+    verify_model(UpsampleNearest2dSize(), example_args, {}, expected_size)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to