This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a9955e55a7 [Relax][PyTorch] Add decomposed operator support for normalization (#18460)
a9955e55a7 is described below
commit a9955e55a7345b764db621e9c35f65451824cbd5
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Nov 16 14:06:17 2025 +0800
[Relax][PyTorch] Add decomposed operator support for normalization (#18460)
## Related Issue
- https://github.com/apache/tvm/pull/18401
## How
This PR
- added `_batch_norm_legit_no_stats`
- added `_native_group_norm`
- added `any.dims`
- refactored `_reshape`
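A minimal end-to-end sketch of how the new handlers are exercised; the module and input shape below are illustrative (not taken from this PR's tests), and `run_ep_decomposition=True` is the flag the updated tests pass so that PyTorch's decompositions run before import:
```python
# Illustrative sketch (hypothetical module): export a GroupNorm model and
# import it through the decomposed-operator path added in this PR.
import torch
from torch.export import export

from tvm.relax.frontend.torch import from_exported_program


class GroupNormModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gn = torch.nn.GroupNorm(num_groups=3, num_channels=3)

    def forward(self, x):
        return self.gn(x)


example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
exported_program = export(GroupNormModel(), args=example_args)
# With decomposition enabled, group_norm is expected to lower to
# native_group_norm, which the importer now maps to relax.op.nn.group_norm.
mod = from_exported_program(exported_program, run_ep_decomposition=True)
mod.show()
```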
---
.../frontend/torch/base_fx_graph_translator.py | 6 +++
.../frontend/torch/exported_program_translator.py | 50 ++++++++++++++++++++++
.../relax/test_frontend_from_exported_program.py | 24 ++++++-----
3 files changed, 69 insertions(+), 11 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 7b8c51895c..b03723cb91 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1848,6 +1848,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
args = self.retrieve_args(node)
x = args[0]
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+
+ # Skip identity reshape
+ current_shape = self.shape_of(x)
+ if list(current_shape) == list(dims):
+ return x
+
return self.block_builder.emit(relax.op.reshape(x, dims))
def _reshape_as(self, node: fx.Node) -> relax.Var:
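The identity-reshape skip above matters because decomposed graphs can carry reshapes whose target shape equals the input's static shape; a hedged illustration with a hypothetical module (whether the exporter keeps the node as a `view`/`reshape` depends on the PyTorch version):
```python
# Hypothetical example: the reshape target equals the input's static shape,
# so the importer can return the input Var instead of emitting relax.op.reshape.
import torch


class SameShapeReshape(torch.nn.Module):
    def forward(self, x):
        return x.reshape(2, 3, 4) + 1.0  # reshapes to the shape x already has


ep = torch.export.export(SameShapeReshape(), (torch.randn(2, 3, 4),))
print(ep.graph)  # typically still contains the same-shape view/reshape node
```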
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2a119e111b..63aba55a78 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -113,6 +113,31 @@ class ExportedProgramImporter(BaseFXGraphImporter):
training = False
return self._batch_norm(node, training)
+ def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var:
+ import numpy as np
+
+ x = self.env[node.args[0]]
+ channel = int(self.shape_of(x)[1])
+ dtype = x.struct_info.dtype
+ weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype))
+ bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype))
+ eps = node.args[5] if len(node.args) > 5 else node.kwargs.get("eps", 1e-05)
+
+ # Determine axes for instance norm (all spatial dimensions after channel)
+ dim = len(self.shape_of(x))
+ axes = list(range(2, dim))
+
+ return self.block_builder.emit(
+ relax.op.nn.instance_norm(
+ x,
+ weight,
+ bias,
+ channel_axis=1,
+ axes=axes,
+ epsilon=eps,
+ )
+ )
+
def _cross_entropy_default(self, node: fx.Node) -> relax.Expr:
preds = self.env[node.args[0]]
targets = self.env[node.args[1]]
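In PyTorch's decompositions, `instance_norm` typically lowers to a stats-free batch norm, which is why the handler above maps `_native_batch_norm_legit.no_stats` onto `relax.op.nn.instance_norm` over the spatial axes. One hedged way to see where the op shows up (exact op names depend on the installed PyTorch decomposition table):
```python
# Inspect the decomposed graph of an InstanceNorm2d module; with no running
# stats tracked, the decomposition is expected to produce
# _native_batch_norm_legit.no_stats nodes.
import torch


class InstanceNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.InstanceNorm2d(3)

    def forward(self, x):
        return self.norm(x)


ep = torch.export.export(InstanceNorm(), (torch.randn(1, 3, 10, 10),))
print(ep.run_decompositions().graph_module.graph)
```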
@@ -141,6 +166,28 @@ class ExportedProgramImporter(BaseFXGraphImporter):
)
)
+ def _native_group_norm(self, node: fx.Node) -> relax.Var:
+ # native_group_norm signature: (input, weight, bias, N, C, HxW, group, eps)
+ x = self.env[node.args[0]]
+ gamma = self.env.get(node.args[1], None) if len(node.args) > 1 else None
+ beta = self.env.get(node.args[2], None) if len(node.args) > 2 else None
+ # args[3] = N (batch size), args[4] = C (channels), args[5] = HxW (spatial size)
+ num_groups = node.args[6] if len(node.args) > 6 else 1
+ eps = node.args[7] if len(node.args) > 7 else 1e-05
+
+ dim = len(self.shape_of(x))
+ return self.block_builder.emit(
+ relax.op.nn.group_norm(
+ x,
+ gamma,
+ beta,
+ num_groups=num_groups,
+ channel_axis=1,
+ axes=list(range(2, dim)),
+ epsilon=eps,
+ )
+ )
+
def _upsample_impl(
self,
x: relax.Expr,
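For reference, a NumPy sketch of the semantics this mapping assumes for `relax.op.nn.group_norm` with `channel_axis=1` and `axes=list(range(2, dim))` (an illustrative reference implementation, not code from this PR):
```python
# Reference semantics sketch for group normalization over NCHW-like layouts.
import numpy as np


def group_norm_ref(x, gamma, beta, num_groups, eps=1e-05):
    n, c = x.shape[:2]
    g = x.reshape(n, num_groups, c // num_groups, *x.shape[2:])
    axes = tuple(range(2, g.ndim))  # channels within a group + spatial dims
    mean = g.mean(axis=axes, keepdims=True)
    var = g.var(axis=axes, keepdims=True)
    out = ((g - mean) / np.sqrt(var + eps)).reshape(x.shape)
    affine_shape = (1, c) + (1,) * (x.ndim - 2)
    return out * gamma.reshape(affine_shape) + beta.reshape(affine_shape)
```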
@@ -963,6 +1010,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"_adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
"_native_batch_norm_legit_functional.default":
self._batch_norm_legit_functional,
"_native_batch_norm_legit_no_training.default":
self._batch_norm_legit_no_training,
+ "_native_batch_norm_legit.no_stats":
self._batch_norm_legit_no_stats,
"batch_norm.default": self._batch_norm_legit_no_training,
"adaptive_avg_pool1d.default": self._adaptive_avg_pool1d,
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
@@ -988,6 +1036,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
),
"group_norm.default": self._group_norm,
"instance_norm.default": self._instance_norm,
+ "native_group_norm.default": self._native_group_norm,
"layer_norm.default": self._layer_norm,
"linear.default": self._linear,
"lstm.input": self._lstm,
@@ -1004,6 +1053,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"upsample_bicubic2d.vec": self._upsample_bicubic2d,
# statistical
"any.dim": self._any,
+ "any.dims": self._any,
"mean.dim": self._mean,
"prod.default": self._prod,
"std.correction": self._std,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index f571ee1fd9..1b816432ce 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1514,12 +1514,10 @@ def test_isin():
x: R.Tensor((10, 10), dtype="float32"), test_elements: R.Tensor((8,), dtype="float32")
) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
with R.dataflow():
- lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(x, axis=[-1])
- lv1: R.Tensor((8,), dtype="float32") = R.reshape(test_elements, R.shape([8]))
- lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1)
- lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False)
- lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32"))
- gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,)
+ lv: R.Tensor((10, 10, 1), dtype="float32") = R.reshape(x, R.shape([10, 10, 1]))
+ lv1: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, test_elements)
+ lv2: R.Tensor((10, 10), dtype="bool") = R.max(lv1, axis=[-1], keepdims=False)
+ gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv2,)
R.output(gv)
return gv
@@ -1527,7 +1525,7 @@ def test_isin():
torch.randn(10, 10, dtype=torch.float32),
torch.randn(8, dtype=torch.float32),
)
- verify_model(IsInModel(), example_args, {}, expected)
+ verify_model(IsInModel(), example_args, {}, expected, run_ep_decomposition=True)
def test_div_mode():
@@ -3155,7 +3153,7 @@ def test_groupnorm():
"w1": model.gn.weight.detach().numpy(),
"w2": model.gn.bias.detach().numpy(),
}
- verify_model(model, example_args, binding, expected1)
+ verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
def test_instancenorm2d():
@@ -3200,7 +3198,7 @@ def test_instancenorm2d():
"w1": torch.ones(3).detach().numpy(),
"w2": torch.zeros(3).detach().numpy(),
}
- verify_model(model, example_args, binding, expected1)
+ verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
def test_layernorm():
@@ -5556,7 +5554,9 @@ def test_unwrap_unit_return_tuple():
example_args = (torch.randn(256, 256, dtype=torch.float32),)
exported_program = export(Identity(), args=example_args)
- mod = from_exported_program(exported_program, unwrap_unit_return_tuple=True)
+ mod = from_exported_program(
+ exported_program, unwrap_unit_return_tuple=True, run_ep_decomposition=True
+ )
tvm.ir.assert_structural_equal(mod, Expected)
@@ -5586,7 +5586,9 @@ def test_no_bind_return_tuple():
torch.randn(256, 256, dtype=torch.float32),
)
exported_program = export(Identity(), args=example_args)
- mod = from_exported_program(exported_program, no_bind_return_tuple=True)
+ mod = from_exported_program(
+ exported_program, no_bind_return_tuple=True, run_ep_decomposition=True
+ )
tvm.ir.assert_structural_equal(mod, Expected)