This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new aafb0db251 [Relax][PyTorch] Add RSub Op Support for Exported Program
and FX graph (#17849)
aafb0db251 is described below
commit aafb0db251786145b6592a4cc8ca2ca47007c44d
Author: Deivanayaki S <[email protected]>
AuthorDate: Fri Apr 18 07:27:55 2025 +0530
[Relax][PyTorch] Add RSub Op Support for Exported Program and FX graph
(#17849)
* add rsub op support into exported and fx graph frontend
* fix trailing whitespace issue
* fix lint issues in test scripts
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
.../frontend/torch/base_fx_graph_translator.py | 20 +++++++++++
.../frontend/torch/exported_program_translator.py | 2 ++
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
.../relax/test_frontend_from_exported_program.py | 37 ++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 39 ++++++++++++++++++++++
5 files changed, 99 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 7b380f9876..a9bee11fc8 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -407,6 +407,26 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return convert
 
+    def _rsub(self, node: fx.Node) -> relax.Var:
+        """Convert ``torch.rsub(input, other, *, alpha=1)``, i.e. ``other - alpha * input``.
+
+        A Python scalar ``other`` is turned into a constant of ``input``'s dtype.
+        """
+        args = self.retrieve_args(node)
+        lhs = args[0]
+        rhs = args[1]
+        # ``alpha`` may be positional (aten ``rsub.Scalar``) or keyword-only (``rsub.Tensor``).
+        alpha = args[2] if len(args) > 2 else node.kwargs.get("alpha", 1)
+        dtype = lhs.struct_info.dtype
+
+        if isinstance(rhs, (int, float)):
+            # Without an explicit dtype, relax.const infers one from the Python value.
+            rhs = relax.const(rhs, dtype)
+        if alpha != 1:
+            lhs = self.block_builder.emit(relax.op.multiply(lhs, relax.const(alpha, dtype)))
+
+        return self.block_builder.emit(relax.op.subtract(rhs, lhs))
+
     ########## Linear Algebra ##########
 
     def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a6f9cafa65..4084e35de5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -305,6 +305,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"relu_.default": self._unary_op(relax.op.nn.relu),
"round.default": self._round,
"rsqrt.default": self._unary_op(relax.op.rsqrt),
+ "rsub.Tensor": self._rsub,
+ "rsub.Scalar": self._rsub,
"selu.default": self._unary_op(relax.op.nn.selu),
"sigmoid.default": self._unary_op(relax.op.sigmoid),
"sign.default": self._unary_op(relax.op.sign),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 534d398bea..4ef0b05aca 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -692,6 +692,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"pow": self._binary_op(relax.op.power, operator.pow),
"or_": self._binary_op(relax.op.bitwise_or, operator.or_),
"rshift": self._binary_op(relax.op.right_shift, operator.rshift),
+ "rsub": self._rsub,
"sub": self._binary_op(relax.op.subtract, operator.sub),
"truediv": self._binary_op(relax.op.divide, operator.truediv),
"xor": self._binary_op(relax.op.bitwise_xor, operator.xor),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e78bd339d2..7c47832ea9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -935,6 +935,7 @@ def test_binary3():
torch.randn(10, 10, dtype=torch.float32),
torch.randn(10, 10, dtype=torch.float32),
)
+ example_args2 = (torch.randn(10, 10, dtype=torch.float32),)
# Max
class Max1(Module):
@@ -976,6 +977,42 @@ def test_binary3():
verify_model(Min1(), example_args1, {}, expected_min1)
+ # RSub
+ class RSub1(Module):
+ def forward(self, x, y):
+ return torch.rsub(x, y)
+
+ class RSub2(Module):
+ def forward(self, x):
+ return torch.rsub(x, 5.0)
+
+ @tvm.script.ir_module
+ class expected_rsub1:
+ @R.function
+ def main(
+ x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x)
+ gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected_rsub2:
+ @R.function
+ def main(
+ x: R.Tensor((10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x)
+ gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(RSub1(), example_args1, {}, expected_rsub1)
+ verify_model(RSub2(), example_args2, {}, expected_rsub2)
+
def test_batchnorm2d():
class BatchNorm2d(Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index caecce4979..a2169afd0f 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1738,6 +1738,45 @@ def test_binary3(op, relax_op):
verify_model(Binary2(op), input_info2, {}, expected_binary2)
+# RSub
+def test_rsub():
+ input_info1 = [([10, 10], "float32"), ([10, 10], "float32")]
+ input_info2 = [([10, 10], "float32")]
+
+ class RSub1(Module):
+ def forward(self, x, y):
+ return torch.rsub(x, y)
+
+ class RSub2(Module):
+ def forward(self, x):
+ return torch.rsub(x, 5.0)
+
+ @tvm.script.ir_module
+ class expected_rsub1:
+ @R.function
+ def main(
+ x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x)
+ gv: R.Tensor((10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected_rsub2:
+ @R.function
+ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x)
+ gv: R.Tensor((10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(RSub1(), input_info1, {}, expected_rsub1)
+ verify_model(RSub2(), input_info2, {}, expected_rsub2)
+
+
def test_size():
input_info = [([1, 3, 10, 10], "float32")]