This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dcb5a3a8c9 [Relax][PyTorch] Add UpSample Bicubic Op Support for Exported Program and FX graph (#17932)
dcb5a3a8c9 is described below
commit dcb5a3a8c91dcd4e081f59da112937f438c28733
Author: Deivanayaki S <[email protected]>
AuthorDate: Sat May 10 10:43:49 2025 +0530
[Relax][PyTorch] Add UpSample Bicubic Op Support for Exported Program and FX graph (#17932)
* add upsample bicubic op support to the torch frontend
* change the default cubic alpha from -0.5 to -0.75 (the bicubic coefficient PyTorch uses) in all resize functions
* update the cubic alpha values in all test scripts accordingly
* update the op mapping code in the frontend
* fix lint issue
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
.../frontend/torch/exported_program_translator.py | 24 ++++++++++++
python/tvm/relax/op/image/image.py | 2 +-
python/tvm/topi/image/resize.py | 6 +--
.../relax/test_frontend_from_exported_program.py | 34 ++++++++++++++++-
tests/python/relax/test_frontend_from_fx.py | 43 ++++++++++++++++++++--
tests/python/relax/test_frontend_nn_op.py | 2 +-
.../python/relax/test_transform_convert_layout.py | 4 +-
7 files changed, 103 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index fc37fd3fb9..87508c9fea 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -208,6 +208,29 @@ class ExportedProgramImporter(BaseFXGraphImporter):
align_corners=align_corners,
)
+ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:  # Handler for "upsample_bicubic2d.vec" (see the convert map): resizes x with method="cubic" via self._upsample_impl. NOTE(review): `fx.node` appears to be the torch.fx.node module, not the Node class that _narrow annotates with (fx.Node) — likely a typo.
+ x = self.env[node.args[0]]
+ size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size",
None)  # NOTE(review): the ATen .vec overload presumably names this argument "output_size", so the "size" kwarg lookup may never match — confirm against the op schema.
+ align_corners = (
+ node.args[2] if len(node.args) > 2 else
node.kwargs.get("align_corners", None)  # NOTE(review): when absent, None is forwarded unchanged to _upsample_impl — confirm that helper handles it.
+ )
+ if size is not None:
+ scale_factor = None  # an explicit output size wins; any scale factor is ignored
+ else:
+ scale_arg = node.args[3] if len(node.args) > 3 else
node.kwargs.get("scale_factor", 1)  # NOTE(review): ATen presumably names this "scale_factors"; with the "scale_factor" key the fallback of 1 (no scaling) may be used silently — confirm.
+ if isinstance(scale_arg, (list, tuple)):
+ scale_factor = scale_arg[0]  # NOTE(review): only the first factor is kept and any second (width) factor is dropped; a non-uniform pair such as (2.0, 1.0) would not be honoured — confirm whether _upsample_impl can take per-axis factors.
+ else:
+ scale_factor = scale_arg
+
+ return self._upsample_impl(
+ x,
+ size=size,
+ scale_factor=scale_factor,
+ method="cubic",
+ align_corners=align_corners,
+ )
+
########## Manipulation ##########
def _narrow(self, node: fx.Node) -> relax.Var:
@@ -428,6 +451,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"unbind.int": self._unbind,
"upsample_bilinear2d.vec": self._upsample_bilinear2d,
"upsample_nearest2d.vec": self._upsample_nearest2d,
+ "upsample_bicubic2d.vec": self._upsample_bicubic2d,
# statistical
"mean.dim": self._mean,
"prod.default": self._prod,
diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py
index e314e9b49a..6bec22161d 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -35,7 +35,7 @@ def resize2d(
method: str = "linear",
coordinate_transformation_mode: str = "half_pixel",
rounding_method: str = "round",
- cubic_alpha: float = -0.5,
+ cubic_alpha: float = -0.75,
cubic_exclude: int = 0,
extrapolation_value: float = 0.0,
out_dtype: Optional[Union[str, DataType]] = None,
diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py
index 5cbc292adb..ad2c99fa3a 100644
--- a/python/tvm/topi/image/resize.py
+++ b/python/tvm/topi/image/resize.py
@@ -376,7 +376,7 @@ def resize1d(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
- bicubic_alpha=-0.5,
+ bicubic_alpha=-0.75,
bicubic_exclude=0,
extrapolation_value=0.0,
out_dtype=None,
@@ -748,7 +748,7 @@ def resize2d(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
- bicubic_alpha=-0.5,
+ bicubic_alpha=-0.75,
bicubic_exclude=0,
extrapolation_value=0.0,
out_dtype=None,
@@ -1217,7 +1217,7 @@ def resize3d(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
- bicubic_alpha=-0.5,
+ bicubic_alpha=-0.75,
bicubic_exclude=0,
extrapolation_value=0.0,
out_dtype=None,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index b0aebff704..75c745a213 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3127,7 +3127,7 @@ def test_interpolate():
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
- cubic_alpha=-0.5,
+ cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0.0,
out_dtype="void",
@@ -3156,7 +3156,36 @@ def test_interpolate():
method="nearest_neighbor",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
- cubic_alpha=-0.5,
+ cubic_alpha=-0.75,
+ cubic_exclude=0,
+ extrapolation_value=0.0,
+ out_dtype="void",
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) =
(lv,)
+ R.output(gv)
+ return gv
+
+ class InterpolateBicubic(Module):  # resizes the input to an explicit (224, 224) size with mode="bicubic"; compared against expected_bicubic
+ def forward(self, input):
+ return torch.nn.functional.interpolate(input, (224, 224),
mode="bicubic")  # NOTE(review): only the explicit-size path of _upsample_bicubic2d is exercised here; consider also a scale_factor case (including a non-uniform one).
+
+ @tvm.script.ir_module
+ class expected_bicubic:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 112, 112), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 224, 224), dtype="float32") =
R.image.resize2d(
+ input,
+ R.shape([224, 224]),
+ roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0),
T.float32(0.0)],
+ layout="NCHW",
+ method="cubic",
+ coordinate_transformation_mode="half_pixel",
+ rounding_method="round",
+ cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0.0,
out_dtype="void",
@@ -3168,6 +3197,7 @@ def test_interpolate():
example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),)
verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear)
verify_model(InterpolateNearest(), example_args, {}, expected_nearest)
+ verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic)
def test_mean():
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 53efab4e80..48e12dfe49 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3446,7 +3446,7 @@ def test_interpolate():
method="nearest_neighbor",
coordinate_transformation_mode="asymmetric",
rounding_method="round",
- cubic_alpha=-0.5,
+ cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
@@ -3483,7 +3483,7 @@ def test_interpolate():
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
- cubic_alpha=-0.5,
+ cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
@@ -3520,7 +3520,7 @@ def test_interpolate():
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
- cubic_alpha=-0.5,
+ cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
@@ -3531,6 +3531,43 @@ def test_interpolate():
verify_model(Interpolate3(), input_info, {}, expected3)
+ class Interpolate4(Module):  # bicubic upsampling by a non-uniform scale factor through the FX importer; checked against expected4
+ def forward(self, input):
+ return torch.nn.functional.interpolate(
+ input,
+ size=None,
+ scale_factor=(2.0, 1.0),  # first spatial axis x2, second x1: expected4 maps (1, 3, 10, 10) to (1, 3, 20, 10)
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ @tvm.script.ir_module
+ class expected4:  # Relax IR expected for Interpolate4: a single cubic resize2d to (20, 10)
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 20, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 20, 10), dtype="float32") =
R.image.resize2d(
+ input_1,
+ (20, 10),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000],
+ layout="NCHW",
+ method="cubic",
+ coordinate_transformation_mode="half_pixel",  # pairs with align_corners=False in Interpolate4
+ rounding_method="round",
+ cubic_alpha=-0.75,  # the resize2d default after this commit (previously -0.5)
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 3, 20, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Interpolate4(), input_info, {}, expected4)
+
def test_addmm():
input_info = [
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 483e48217d..1af13f0487 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -305,7 +305,7 @@ def test_image():
method="nearest_neighbor",
coordinate_transformation_mode="asymmetric",
rounding_method="round",
- cubic_alpha=-0.5,
+ cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="void",
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index db4130f947..262e37b91b 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -1434,7 +1434,7 @@ def test_conv2d_resize2d():
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
- cubic_alpha=-0.5,
+ cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="void",
@@ -1477,7 +1477,7 @@ def test_resize2d_conv2d():
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
- cubic_alpha=-0.5,
+ cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="void",