This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new aafb0db251 [Relax][PyTorch] Add RSub Op Support for Exported Program and FX graph (#17849)
aafb0db251 is described below

commit aafb0db251786145b6592a4cc8ca2ca47007c44d
Author: Deivanayaki S <[email protected]>
AuthorDate: Fri Apr 18 07:27:55 2025 +0530

    [Relax][PyTorch] Add RSub Op Support for Exported Program and FX graph (#17849)
    
    * add rsub op support into exported and fx graph frontend
    
    * fix trailing whitespace issue
    
    * fix lint issues in test scripts
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
 .../frontend/torch/base_fx_graph_translator.py     | 10 ++++++
 .../frontend/torch/exported_program_translator.py  |  2 ++
 python/tvm/relax/frontend/torch/fx_translator.py   |  1 +
 .../relax/test_frontend_from_exported_program.py   | 37 ++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 39 ++++++++++++++++++++++
 5 files changed, 89 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 7b380f9876..a9bee11fc8 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -407,6 +407,16 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return convert
 
+    def _rsub(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        lhs = args[0]
+        rhs = args[1]
+
+        if isinstance(rhs, (int, float)):
+            rhs = relax.const(rhs)
+
+        return self.block_builder.emit(relax.op.subtract(rhs, lhs))
+
     ########## Linear Algebra ##########
 
     def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a6f9cafa65..4084e35de5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -305,6 +305,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "relu_.default": self._unary_op(relax.op.nn.relu),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
+            "rsub.Tensor": self._rsub,
+            "rsub.Scalar": self._rsub,
             "selu.default": self._unary_op(relax.op.nn.selu),
             "sigmoid.default": self._unary_op(relax.op.sigmoid),
             "sign.default": self._unary_op(relax.op.sign),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 534d398bea..4ef0b05aca 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -692,6 +692,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "pow": self._binary_op(relax.op.power, operator.pow),
             "or_": self._binary_op(relax.op.bitwise_or, operator.or_),
             "rshift": self._binary_op(relax.op.right_shift, operator.rshift),
+            "rsub": self._rsub,
             "sub": self._binary_op(relax.op.subtract, operator.sub),
             "truediv": self._binary_op(relax.op.divide, operator.truediv),
             "xor": self._binary_op(relax.op.bitwise_xor, operator.xor),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e78bd339d2..7c47832ea9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -935,6 +935,7 @@ def test_binary3():
         torch.randn(10, 10, dtype=torch.float32),
         torch.randn(10, 10, dtype=torch.float32),
     )
+    example_args2 = (torch.randn(10, 10, dtype=torch.float32),)
 
     # Max
     class Max1(Module):
@@ -976,6 +977,42 @@ def test_binary3():
 
     verify_model(Min1(), example_args1, {}, expected_min1)
 
+    # RSub
+    class RSub1(Module):
+        def forward(self, x, y):
+            return torch.rsub(x, y)
+
+    class RSub2(Module):
+        def forward(self, x):
+            return torch.rsub(x, 5.0)
+
+    @tvm.script.ir_module
+    class expected_rsub1:
+        @R.function
+        def main(
+            x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_rsub2:
+        @R.function
+        def main(
+            x: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(RSub1(), example_args1, {}, expected_rsub1)
+    verify_model(RSub2(), example_args2, {}, expected_rsub2)
+
 
 def test_batchnorm2d():
     class BatchNorm2d(Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index caecce4979..a2169afd0f 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1738,6 +1738,45 @@ def test_binary3(op, relax_op):
     verify_model(Binary2(op), input_info2, {}, expected_binary2)
 
 
+# RSub
+def test_rsub():
+    input_info1 = [([10, 10], "float32"), ([10, 10], "float32")]
+    input_info2 = [([10, 10], "float32")]
+
+    class RSub1(Module):
+        def forward(self, x, y):
+            return torch.rsub(x, y)
+
+    class RSub2(Module):
+        def forward(self, x):
+            return torch.rsub(x, 5.0)
+
+    @tvm.script.ir_module
+    class expected_rsub1:
+        @R.function
+        def main(
+            x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x)
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_rsub2:
+        @R.function
+        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x)
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(RSub1(), input_info1, {}, expected_rsub1)
+    verify_model(RSub2(), input_info2, {}, expected_rsub2)
+
+
 def test_size():
     input_info = [([1, 3, 10, 10], "float32")]
 

Reply via email to