This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1ef153fd1c [Relax][PyTorch] Refactor norm op for ExportedProgram importer (#17857)
1ef153fd1c is described below
commit 1ef153fd1c2346d7fc46cc45483b6582927aefcd
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Apr 23 10:59:28 2025 +0800
[Relax][PyTorch] Refactor norm op for ExportedProgram importer (#17857)
* Update base_fx_graph_translator.py
* Update fx_translator.py
* Update exported_program_translator.py
* Update exported_program_translator.py
* Update fx_translator.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_exported_program.py
* Update fx_translator.py
* Update exported_program_translator.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_fx.py
* Update fx_translator.py
* Update exported_program_translator.py
* Update base_fx_graph_translator.py
* Update exported_program_translator.py
* Update exported_program_translator.py
* Update fx_translator.py
* Update fx_translator.py
* Update base_fx_graph_translator.py
* Update exported_program_translator.py
---
.../frontend/torch/base_fx_graph_translator.py | 50 ++-------
.../frontend/torch/exported_program_translator.py | 2 +-
.../relax/test_frontend_from_exported_program.py | 112 +++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 19 ++--
4 files changed, 131 insertions(+), 52 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 20556167c1..a89726495e 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -417,42 +417,6 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self.block_builder.emit(relax.op.subtract(rhs, lhs))
- ########## Linear Algebra ##########
-
- def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
-
- args = self.retrieve_args(node)
-
- data = args[0]
- # Default ord=2 if not supplied
- ord_val = args[1] if len(args) > 1 else 2.0
- dim = args[2] if len(args) > 2 else None
- keepdim = args[3] if len(args) > 3 else False
-
- # If ord_val is a Python float/int, wrap it in a Relax const
- # so that it matches data's dtype.
- dtype = data.struct_info.dtype
- ord_expr = (
- ord_val if isinstance(ord_val, relax.Expr) else relax.const(float(ord_val), dtype)
- )
- # Reciprocal
- reci_expr = (
- relax.op.divide(relax.const(1.0, dtype), ord_expr)
- if isinstance(ord_val, relax.Expr)
- else relax.const(1.0 / float(ord_val), dtype)
- )
-
- # abs(data)
- abs_data = self.block_builder.emit(relax.op.abs(data))
- # abs_data^ord
- abs_data_pow = self.block_builder.emit(relax.op.power(abs_data, ord_expr))
- # sum over dim
- reduced = self.block_builder.emit(relax.op.sum(abs_data_pow, dim, keepdims=keepdim))
- # (sum(...))^(1/ord)
- norm_val = self.block_builder.emit(relax.op.power(reduced, reci_expr))
-
- return norm_val
-
########## Neural Network ##########
def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
@@ -980,16 +944,22 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
elif order == "fro":
return self.block_builder.emit(
relax.op.sqrt(
- relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims),
+ relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims)
)
)
else:
- reci_order = relax.const(1 / order, dtype=dtype)
- order = relax.const(order, dtype=dtype)
+ ord_expr = (
+ order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype)
+ )
+ reci_order = (
+ relax.op.divide(relax.const(1.0, dtype), ord_expr)
+ if isinstance(order, relax.Expr)
+ else relax.const(1.0 / order, dtype=dtype)
+ )
return self.block_builder.emit(
relax.op.power(
relax.op.sum(
- relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims
+ relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims
),
reci_order,
)
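For reference, the shared _norm helper above lowers every vector norm through the same abs/power/sum/power pattern, with max/min handling ord=+inf/-inf and sqrt(sum(x*x)) handling "fro". A minimal NumPy sketch of that lowering (illustrative only; vector_norm_reference is a made-up name here, and the importer actually emits the equivalent relax.op calls through the block builder):

import numpy as np

def vector_norm_reference(x, order=2.0, axis=None, keepdims=False):
    # ord=+inf / -inf reduce |x| with max / min.
    if order == float("inf"):
        return np.max(np.abs(x), axis=axis, keepdims=keepdims)
    if order == float("-inf"):
        return np.min(np.abs(x), axis=axis, keepdims=keepdims)
    # Frobenius norm: sqrt(sum(x * x)).
    if order == "fro":
        return np.sqrt(np.sum(x * x, axis=axis, keepdims=keepdims))
    # General case: (sum(|x| ** p)) ** (1 / p).
    reduced = np.sum(np.power(np.abs(x), order), axis=axis, keepdims=keepdims)
    return np.power(reduced, 1.0 / order)

x = np.random.randn(1, 3, 5, 3).astype("float32")
assert np.isclose(vector_norm_reference(x, 2.0), np.linalg.norm(x))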
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index f38f353a9e..cdf0c46bb5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -369,7 +369,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"__xor__.Tensor": self._binary_op(relax.op.bitwise_xor,
operator.xor),
"__xor__.Scalar": self._binary_op(relax.op.bitwise_xor,
operator.xor),
# linear algebra
- "linalg_vector_norm.default": self._linalg_vector_norm,
+ "linalg_vector_norm.default": self._norm,
# neural network
"_native_batch_norm_legit_functional.default":
self._batch_norm_legit_functional,
"_native_batch_norm_legit_no_training.default":
self._batch_norm_legit_no_training,
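A hedged usage sketch of the path this remapping enables, assuming the public from_exported_program entry point in tvm.relax.frontend.torch (VectorNormModel and the example shapes are illustrative, not part of this change):

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class VectorNormModel(torch.nn.Module):
    def forward(self, x):
        # Lowers to aten linalg_vector_norm in the exported graph, which the
        # importer now routes to the shared _norm converter.
        return torch.linalg.vector_norm(x, ord=2.0)

example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
exported_program = export(VectorNormModel(), example_args)
mod = from_exported_program(exported_program)
mod.show()  # Relax module with the abs/power/sum/power pattern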
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index a386a989f0..c6ead5aacc 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4379,6 +4379,118 @@ def test_narrow():
verify_model(Narrow(), example_args, {}, Expected)
+def test_norm():
+ class Norm(Module):
+ def __init__(self, p, dim=None, keepdim=False):
+ super().__init__()
+ self.p = p
+ self.dim = dim
+ self.keepdim = keepdim
+
+ def forward(self, x):
+ return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)
+
+ @tvm.script.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="float32") = R.max(R.abs(inp_0),
axis=None, keepdims=False)
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="float32") = R.min(R.abs(inp_0),
axis=None, keepdims=False)
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected3:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+ lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(2, "float32"))
+ lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+ lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(0.5, "float32"))
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected4:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+ lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(1.0, "float32"))
+ lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+ lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(1.0, "float32"))
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected5:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+ lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(-4.0, "float32"))
+ lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True)
+ lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(
+ lv2, R.const(-0.25, "float32")
+ )
+ gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected6:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+ lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(0.5, "float32"))
+ lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True)
+ lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(lv2, R.const(2.0, "float32"))
+ gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ norms = [
+ ((float("inf"), None, False), Expected1),
+ ((float("-inf"), None, False), Expected2),
+ ((float(2), None, False), Expected3),
+ ((float(1.0), None, False), Expected4),
+ ((float(-4), None, True), Expected5),
+ ((float(0.5), None, True), Expected6),
+ ]
+
+ example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
+
+ for (p, dim, keepdim), expected in norms:
+ verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected)
+
+
def test_eye():
class Eye1(Module):
def forward(self, input):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 2d27fa1f59..f21cde6df2 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4947,19 +4947,16 @@ def test_norm():
return gv
norms = [
- (float("inf"), None, False),
- (float("-inf"), None, False),
- (float(2), None, False),
- (float(1.0), None, False),
- (float(-4), None, True),
- (float(0.5), None, True),
- ("fro", None, False),
+ ((float("inf"), None, False), Expected1),
+ ((float("-inf"), None, False), Expected2),
+ ((float(2), None, False), Expected3),
+ ((float(1.0), None, False), Expected4),
+ ((float(-4), None, True), Expected5),
+ ((float(0.5), None, True), Expected6),
+ (("fro", None, False), Expected7),
]
- for norm, expected in zip(
- norms, [Expected1, Expected2, Expected3, Expected4, Expected5, Expected6, Expected7]
- ):
- p, dim, keepdim = norm
+ for (p, dim, keepdim), expected in norms:
verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)