This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ca19be8be8 [Relax][PyTorch] Add support for binary scalar operations in ExportedProgram frontend and corresponding tests (#18529)
ca19be8be8 is described below

commit ca19be8be860f796baf70468ccfa378dff681df0
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Nov 30 01:19:47 2025 +0900

    [Relax][PyTorch] Add support for binary scalar operations in ExportedProgram frontend and corresponding tests (#18529)
    
    Added `add.Scalar` and `sub.Scalar` converters and tests for binary
    scalar ops.
---
 .../frontend/torch/exported_program_translator.py  |  2 ++
 .../relax/test_frontend_from_exported_program.py   | 39 ++++++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index fc0ca18209..3a33a58f8c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1253,6 +1253,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "trunc.default": self._unary_op(relax.op.trunc),
             # binary
             "add.Tensor": self._binary_op(relax.op.add, operator.add),
+            "add.Scalar": self._binary_op(relax.op.add, operator.add),
             "add_.Tensor": self._binary_op(relax.op.add, operator.add),
             "bitwise_and.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_),
             "bitwise_and.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_),
@@ -1306,6 +1307,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
             "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
             "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
+            "sub.Scalar": self._binary_op(relax.op.subtract, operator.sub),
             "__and__.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_),
             "__and__.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_),
             "__or__.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 091f0a4a29..48ca5f3209 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1429,6 +1429,45 @@ def test_binary1(op, relax_op):
     verify_model(Binary2(op), example_args2, {}, expected2)
 
 
+operator_binary_scalar = [
+    (torch.ops.aten.add.Scalar, R.add),
+    (torch.ops.aten.bitwise_and.Scalar, R.bitwise_and),
+    (torch.ops.aten.bitwise_or.Scalar, R.bitwise_or),
+    (torch.ops.aten.bitwise_xor.Scalar, R.bitwise_xor),
+    (torch.ops.aten.div.Scalar, R.divide),
+    (torch.ops.aten.sub.Scalar, R.subtract),
+    (torch.ops.aten.mul.Scalar, R.multiply),
+    (torch.ops.aten.remainder.Scalar, R.floor_mod),
+]
+
+
+@pytest.mark.parametrize("op, relax_op", operator_binary_scalar)
+def test_binary_scalar(op, relax_op):
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    class BinaryScalar(Module):
+        def __init__(self, op):
+            super().__init__()
+            self.op = op
+
+        def forward(self, lhs):
+            return self.op(lhs, 1.0)
+
+    @tvm.script.ir_module
+    class expected_binary_scalar:
+        @R.function
+        def main(
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(lhs, R.const(1.0))
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(BinaryScalar(op), example_args, {}, expected_binary_scalar)
+
+
 operator_binary_promote = [
     (operator.add, R.add),
     (operator.sub, R.subtract),

Reply via email to