This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f08165680e [Relax][PyTorch] Add support for norm op (#17841)
f08165680e is described below
commit f08165680e8b8696679d331e1f547de8839145cd
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Apr 17 11:23:33 2025 +0800
[Relax][PyTorch] Add support for norm op (#17841)
* Update fx_translator.py
* Update base_fx_graph_translator.py
* Update test_frontend_from_fx.py
* Update base_fx_graph_translator.py
* Update test_frontend_from_fx.py
* Update base_fx_graph_translator.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
---
.../frontend/torch/base_fx_graph_translator.py | 34 ++++++
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
tests/python/relax/test_frontend_from_fx.py | 129 +++++++++++++++++++++
3 files changed, 164 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 6d880ab90d..4c9480b587 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -927,6 +927,40 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
keepdim = args[2] if len(node.args) > 2 else
node.kwargs.get("keepdim", False)
return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
+ def _norm(self, node: fx.Node) -> relax.Var:
+ data = self.env[node.args[0]]
+ dtype = data.struct_info.dtype
+ order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2)
+ axis = node.args[2] if len(node.args) > 2 else None
+ keepdims = node.args[3] if len(node.args) > 3 else False
+
+ if order == float("inf"):
+ return self.block_builder.emit(
+ relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims)
+ )
+ elif order == float("-inf"):
+ return self.block_builder.emit(
+ relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims)
+ )
+ # frobenius_norm
+ elif order == "fro":
+ return self.block_builder.emit(
+ relax.op.sqrt(
+ relax.op.sum(relax.op.multiply(data, data), axis=axis,
keepdims=keepdims),
+ )
+ )
+ else:
+ reci_order = relax.const(1 / order, dtype=dtype)
+ order = relax.const(order, dtype=dtype)
+ return self.block_builder.emit(
+ relax.op.power(
+ relax.op.sum(
+ relax.op.power(relax.op.abs(data), order), axis=axis,
keepdims=keepdims
+ ),
+ reci_order,
+ )
+ )
+
def _prod(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 594344fef8..297529e8bf 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -728,6 +728,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"lerp": self._lerp,
# statistical
"mean": self._mean,
+ "norm": self._norm,
"prod": self._prod,
"std": self._std,
"sum": self._sum,
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index ee5a5c78c7..a962de8a32 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4513,5 +4513,134 @@ def test_narrow():
verify_model(Narrow(), [([5, 3], "float32")], {}, Expected)
def test_norm():
    """Check torch.norm lowering for inf/-inf, "fro", and general p-norms.

    Each (p, dim, keepdim) triple in ``norms`` is paired with the Expected
    module of the same index.  All cases use keepdim=False because every
    Expected module below asserts keepdims=False with a scalar () output;
    the earlier keepdim=True entries contradicted their Expected modules
    (torch would return shape (1, 1, 1, 1) for keepdim=True).
    """

    input_info = [([1, 3, 5, 3], "float32")]

    class Norm(Module):
        def __init__(self, p, dim=None, keepdim=False):
            super().__init__()
            self.p = p
            self.dim = dim
            self.keepdim = keepdim

        def forward(self, x):
            return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)

    # p = +inf -> max(|x|)
    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tensor((), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.max(R.abs(inp_0), axis=None, keepdims=False)
                gv: R.Tensor((), dtype="float32") = lv
                R.output(gv)
            return gv

    # p = -inf -> min(|x|)
    @tvm.script.ir_module
    class Expected2:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tensor((), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.min(R.abs(inp_0), axis=None, keepdims=False)
                gv: R.Tensor((), dtype="float32") = lv
                R.output(gv)
            return gv

    # p = 2 -> (sum(|x|**2)) ** 0.5
    @tvm.script.ir_module
    class Expected3:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tensor((), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(2, "float32"))
                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(0.5, "float32"))
                gv: R.Tensor((), dtype="float32") = lv3
                R.output(gv)
            return gv

    # p = 1.0 -> (sum(|x|**1)) ** 1
    @tvm.script.ir_module
    class Expected4:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tensor((), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(1.0, "float32"))
                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(1.0, "float32"))
                gv: R.Tensor((), dtype="float32") = lv3
                R.output(gv)
            return gv

    # p = -4 -> (sum(|x|**-4)) ** -0.25
    @tvm.script.ir_module
    class Expected5:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tensor((), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(-4, "float32"))
                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(-0.25, "float32"))
                gv: R.Tensor((), dtype="float32") = lv3
                R.output(gv)
            return gv

    # p = 0.5 -> (sum(|x|**0.5)) ** 2
    @tvm.script.ir_module
    class Expected6:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tensor((), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(0.5, "float32"))
                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(2, "float32"))
                gv: R.Tensor((), dtype="float32") = lv3
                R.output(gv)
            return gv

    # p = "fro" -> sqrt(sum(x * x))
    @tvm.script.ir_module
    class Expected7:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tensor((), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.multiply(inp_0, inp_0)
                lv1: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False)
                lv2: R.Tensor((), dtype="float32") = R.sqrt(lv1)
                gv: R.Tensor((), dtype="float32") = lv2
                R.output(gv)
            return gv

    norms = [
        (float("inf"), None, False),
        (float("-inf"), None, False),
        (float(2), None, False),
        (float(1.0), None, False),
        (float(-4), None, False),
        (float(0.5), None, False),
        ("fro", None, False),
    ]

    expected_mods = [Expected1, Expected2, Expected3, Expected4, Expected5, Expected6, Expected7]
    for (p, dim, keepdim), expected in zip(norms, expected_mods):
        verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()