This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a747614a83 [Relax][PyTorch] Add NHWC layout support (#18548)
a747614a83 is described below
commit a747614a83ee665a4b0765953b0e5ff098063d5b
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sat Dec 6 02:28:44 2025 +0800
[Relax][PyTorch] Add NHWC layout support (#18548)
## Why
- The interpolate operation was hardcoded to only support NCHW layout
- Users need flexibility to choose the appropriate layout for their
target platform
## How
- Added default_image_layout parameter
- Exposed default_image_layout parameter in the public from_fx()
---
python/tvm/relax/frontend/torch/fx_translator.py | 36 +++++--
tests/python/relax/test_frontend_from_fx.py | 115 +++++++++++++++++++++++
2 files changed, 144 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 9c2d53a685..8b1f5de36b 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -33,11 +33,12 @@ class TorchFXImporter(BaseFXGraphImporter):
import torch # type: ignore
from torch import fx
- def __init__(self) -> None:
+ def __init__(self, default_image_layout: str = "NCHW") -> None:
import torch # type: ignore
super().__init__()
self.named_modules: Dict[str, torch.Module] = None
+ self.default_image_layout = default_image_layout
########## Utilities ##########
@@ -480,7 +481,6 @@ class TorchFXImporter(BaseFXGraphImporter):
# torch.nn.functional.interpolate(
# input, size=None, scale_factor=None, mode='nearest',
align_corners=None,
# recompute_scale_factor=None, antialias=False)
- # (TODO) this is a temporary implementation for interpolate that only
considers NCHW layout
data = self.env[node.args[0]]
size = (
node.args[1]
@@ -523,13 +523,26 @@ class TorchFXImporter(BaseFXGraphImporter):
if size is None:
shape = self.shape_of(data)
assert isinstance(shape, relax.ShapeExpr)
+ # Determine spatial dimension indices based on layout
+ # NCHW: spatial dims are [2, 3, ...] (skip batch and channel)
+ # NHWC: spatial dims are [1, 2, ...] (skip batch, before channel)
+ if self.default_image_layout == "NHWC":
+ spatial_start = 1
+ spatial_end = len(shape) - 1
+ else: # NCHW or other layouts
+ spatial_start = 2
+ spatial_end = len(shape)
+
if isinstance(scale_factor, tuple):
- assert len(scale_factor) == len(shape) - 2
+ assert len(scale_factor) == spatial_end - spatial_start
size = tuple(
- int(shape[i].value * scale_factor[i - 2]) for i in
range(2, len(shape))
+ int(shape[i].value * scale_factor[i - spatial_start])
+ for i in range(spatial_start, spatial_end)
)
else:
- size = tuple(int(shape[i].value * scale_factor) for i in
range(2, len(shape)))
+ size = tuple(
+ int(shape[i].value * scale_factor) for i in
range(spatial_start, spatial_end)
+ )
if method.startswith("nearest"):
method = "nearest_neighbor"
@@ -545,7 +558,11 @@ class TorchFXImporter(BaseFXGraphImporter):
return self.block_builder.emit(
relax.op.image.resize2d(
- data, size, layout="NCHW", method=method,
coordinate_transformation_mode=coord_trans
+ data,
+ size,
+ layout=self.default_image_layout,
+ method=method,
+ coordinate_transformation_mode=coord_trans,
)
)
@@ -1150,6 +1167,7 @@ def from_fx(
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
custom_convert_map: dict = None,
+ default_image_layout: str = "NCHW",
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program
@@ -1175,6 +1193,10 @@ def from_fx(
custom_convert_map : Dictionary of str to Relax op
A custom op conversion map in the same format as
TorchFXImporter.convert_map
+ default_image_layout : str
+ The default layout for image operations (e.g., "NCHW" or "NHWC").
+ Default is "NCHW" which is the standard PyTorch layout.
+
Returns
-------
output : tvm.IRModule
@@ -1242,7 +1264,7 @@ def from_fx(
to print out the tabular representation of the PyTorch module, and then
check the placeholder rows in the beginning of the tabular.
"""
- return TorchFXImporter().from_fx(
+ return TorchFXImporter(default_image_layout=default_image_layout).from_fx(
model,
input_info,
keep_params_as_input,
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index de30af01ee..b7aeea6687 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3670,6 +3670,121 @@ def test_interpolate():
verify_model(Interpolate4(), input_info, {}, expected4)
+def test_interpolate_nhwc_layout():
+ # First verify backward compatibility - default should still be NCHW
+ input_info_nchw = [([1, 3, 10, 10], "float32")]
+
+ class InterpolateDefault(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(input, (5, 5))
+
+ @tvm.script.ir_module
+ class expected_default_nchw:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 5, 5), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d(
+ input_1,
+ (5, 5),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000],
+ layout="NCHW",
+ method="nearest_neighbor",
+ coordinate_transformation_mode="asymmetric",
+ rounding_method="round",
+ cubic_alpha=-0.75,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ # Verify default behavior (no default_image_layout parameter) uses NCHW
+ graph_model_default = fx.symbolic_trace(InterpolateDefault())
+ with torch.no_grad():
+ mod_default = from_fx(graph_model_default, input_info_nchw)
+ tvm.ir.assert_structural_equal(mod_default, expected_default_nchw)
+
+ # Now test NHWC layout
+ input_info = [([1, 10, 10, 3], "float32")]
+
+ class InterpolateNHWC(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(input, (5, 5))
+
+ @tvm.script.ir_module
+ class expected_nhwc:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 10, 10, 3), dtype="float32")
+ ) -> R.Tensor((1, 5, 5, 3), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 5, 5, 3), dtype="float32") = R.image.resize2d(
+ input_1,
+ (5, 5),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000],
+ layout="NHWC",
+ method="nearest_neighbor",
+ coordinate_transformation_mode="asymmetric",
+ rounding_method="round",
+ cubic_alpha=-0.75,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 5, 5, 3), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ # Test with NHWC layout
+ graph_model = fx.symbolic_trace(InterpolateNHWC())
+ with torch.no_grad():
+ mod = from_fx(graph_model, input_info, default_image_layout="NHWC")
+ tvm.ir.assert_structural_equal(mod, expected_nhwc)
+
+ # Test with bilinear interpolation and NHWC layout
+ class InterpolateNHWC2(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(
+ input, size=None, scale_factor=2.0, mode="bilinear",
align_corners=False
+ )
+
+ @tvm.script.ir_module
+ class expected_nhwc2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 10, 10, 3), dtype="float32")
+ ) -> R.Tensor((1, 20, 20, 3), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 20, 20, 3), dtype="float32") =
R.image.resize2d(
+ input_1,
+ (20, 20),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000],
+ layout="NHWC",
+ method="linear",
+ coordinate_transformation_mode="half_pixel",
+ rounding_method="round",
+ cubic_alpha=-0.75,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 20, 20, 3), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ graph_model2 = fx.symbolic_trace(InterpolateNHWC2())
+ with torch.no_grad():
+ mod2 = from_fx(graph_model2, input_info, default_image_layout="NHWC")
+ tvm.ir.assert_structural_equal(mod2, expected_nhwc2)
+
+
def test_addmm():
input_info = [
([10, 10], "float32"),