This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6041e9f455 [Relax][PyTorch] Add support for antialiased bilinear upsampling (#18500)
6041e9f455 is described below

commit 6041e9f455ebd694ea6f2dcb755897b2a9aec9fe
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Nov 26 13:17:32 2025 +0800

    [Relax][PyTorch] Add support for antialiased bilinear upsampling (#18500)
    
    ## Related Issue
    
    closes https://github.com/apache/tvm/issues/18365
    
    ## How
    
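    - Add a converter for `_upsample_bilinear2d_aa.default` that lowers it to
      `R.image.resize2d` with `method="linear"`.
    - Note: `resize2d` has no antialias option, so the lowering falls back to
      plain bilinear interpolation. This is adequate for upsampling, but
      antialiased downsampling will not match PyTorch's output.
    - Add `test_interpolate_antialiased` (1x3x32x32 -> 1x3x64x64 upsampling).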
---
 .../frontend/torch/exported_program_translator.py  | 17 ++++++++++
 .../relax/test_frontend_from_exported_program.py   | 37 ++++++++++++++++++++++
 2 files changed, 54 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 95b0e05361..7af8774ee3 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -298,6 +298,22 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             x, size=size, scale_factor=scale_factor, method="linear", 
align_corners=align_corners
         )
 
+    def _upsample_bilinear2d_aa(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        size = node.args[1] if len(node.args) > 1 else 
node.kwargs.get("output_size", None)
+        align_corners = (
+            node.args[2] if len(node.args) > 2 else 
node.kwargs.get("align_corners", False)
+        )
+        scale_factor = (
+            node.args[3] if len(node.args) > 3 else 
node.kwargs.get("scale_factors", None)
+        )
+
+        # Note: TVM's resize2d doesn't have explicit antialias support.
+        # For upsampling, antialiasing has minimal effect, so we use regular 
bilinear.
+        return self._upsample_impl(
+            x, size=size, scale_factor=scale_factor, method="linear", 
align_corners=align_corners
+        )
+
     def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
         x = self.env[node.args[0]]
         size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", 
None)
@@ -1218,6 +1234,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "scaled_dot_product_attention.default": 
self._scaled_dot_product_attention,
             "unbind.int": self._unbind,
             "upsample_bilinear2d.vec": self._upsample_bilinear2d,
+            "_upsample_bilinear2d_aa.default": self._upsample_bilinear2d_aa,
             "upsample_nearest2d.vec": self._upsample_nearest2d,
             "upsample_bicubic2d.vec": self._upsample_bicubic2d,
             # statistical
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index d4c23bfdd5..98c6c6d014 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4703,6 +4703,43 @@ def test_interpolate():
     verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic)
 
 
+def test_interpolate_antialiased():
+    """Test bilinear interpolation with antialiasing enabled."""
+
+    class InterpolateBilinearAA(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, size=(64, 64), mode="bilinear", align_corners=False, 
antialias=True
+            )
+
+    @tvm.script.ir_module
+    class expected_bilinear_aa:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 32, 32), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 64, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 64, 64), dtype="float32") = 
R.image.resize2d(
+                    input,
+                    R.shape([64, 64]),
+                    roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), 
T.float32(0.0)],
+                    layout="NCHW",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0.0,
+                    out_dtype="void",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 64, 64), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)
+    verify_model(InterpolateBilinearAA(), example_args, {}, 
expected_bilinear_aa)
+
+
 def test_mean():
     class Mean(Module):
         def forward(self, input):

Reply via email to