This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 91c1921210 [Relax][PyTorch] Add binary operation dtype promotion following PyTorch rules in ExportedProgram frontend (#18497)
91c1921210 is described below

commit 91c1921210adb5a911ee133ca35b46cdea472843
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Mon Nov 24 17:18:19 2025 +0900

    [Relax][PyTorch] Add binary operation dtype promotion following PyTorch rules in ExportedProgram frontend (#18497)
    
    As per title.
    ref:
    https://docs.pytorch.org/docs/stable/generated/torch.promote_types.html
---
 .../frontend/torch/base_fx_graph_translator.py     | 41 +++++++++++++++
 .../relax/test_frontend_from_exported_program.py   | 61 ++++++++++++++++++++++
 2 files changed, 102 insertions(+)
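
For reference, the promotion rule this commit follows is PyTorch's
torch.promote_types: floating-point dtypes win over integer dtypes, and the
wider dtype wins within a category. A minimal sketch (illustrative, not part
of the patch):

    import torch

    # float beats integer: float32 + int64 -> float32
    assert torch.promote_types(torch.float32, torch.int64) == torch.float32
    # wider wins within a category: int32 + int64 -> int64
    assert torch.promote_types(torch.int32, torch.int64) == torch.int64
    # bool promotes to the other operand's dtype
    assert torch.promote_types(torch.bool, torch.float16) == torch.float16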

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index fb8790322e..2b97f22c92 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -88,6 +88,36 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             return tensor.shape
         raise ValueError("Unsupported type: {}".format(type(tensor)))
 
+    @staticmethod
+    def _promote_common_dtype(lhs_dtype: Optional[str], rhs_dtype: Optional[str]) -> Optional[str]:
+        """Return the promoted dtype following PyTorch rules, or None if 
unsupported."""
+        import torch  # type: ignore
+
+        if lhs_dtype is None or rhs_dtype is None or lhs_dtype == rhs_dtype:
+            return None
+
+        tvm_to_torch = {
+            "float64": torch.float64,
+            "float32": torch.float32,
+            "float16": torch.float16,
+            "bfloat16": torch.bfloat16,
+            "int64": torch.int64,
+            "int32": torch.int32,
+            "int16": torch.int16,
+            "int8": torch.int8,
+            "uint8": torch.uint8,
+            "bool": torch.bool,
+        }
+        torch_to_tvm = {v: k for k, v in tvm_to_torch.items()}
+
+        lhs_torch = tvm_to_torch.get(lhs_dtype)
+        rhs_torch = tvm_to_torch.get(rhs_dtype)
+        if lhs_torch is None or rhs_torch is None:
+            return None
+
+        promoted = torch.promote_types(lhs_torch, rhs_torch)
+        return torch_to_tvm.get(promoted, None)
+
     @staticmethod
     def _is_no_bias(bias):
         """Check if bias represents 'no bias' condition.
@@ -408,6 +438,17 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         def convert(node: fx.Node) -> relax.Var:
             def promote_binary_op_args(lhs, rhs):
                 if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
+                    lhs_si = getattr(lhs, "struct_info", None)
+                    rhs_si = getattr(rhs, "struct_info", None)
+                    if isinstance(lhs_si, relax.TensorStructInfo) and isinstance(
+                        rhs_si, relax.TensorStructInfo
+                    ):
+                        target_dtype = self._promote_common_dtype(lhs_si.dtype, rhs_si.dtype)
+                        if target_dtype is not None:
+                            if lhs_si.dtype != target_dtype:
+                                lhs = self.block_builder.emit(relax.op.astype(lhs, target_dtype))
+                            if rhs_si.dtype != target_dtype:
+                                rhs = self.block_builder.emit(relax.op.astype(rhs, target_dtype))
                     return lhs, rhs
                 elif isinstance(lhs, relax.Expr):
                     assert isinstance(lhs.struct_info, relax.TensorStructInfo)
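
The helper above returns None when the dtypes already match or fall outside
the mapping, so callers simply keep the operands unchanged in that case. A
quick sketch of its behavior, calling the internal static method directly for
illustration only:

    from tvm.relax.frontend.torch.base_fx_graph_translator import BaseFXGraphImporter

    # mixed dtypes promote per torch.promote_types
    assert BaseFXGraphImporter._promote_common_dtype("float32", "int64") == "float32"
    # matching dtypes return None: operands are left as-is
    assert BaseFXGraphImporter._promote_common_dtype("float32", "float32") is None
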
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 89017e30a7..78a8a09a3c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1383,6 +1383,67 @@ def test_binary1(op, relax_op):
     verify_model(Binary2(op), example_args2, {}, expected2)
 
 
+operator_binary_promote = [
+    (operator.add, R.add),
+    (operator.sub, R.subtract),
+    (operator.mul, R.multiply),
+    (operator.truediv, R.divide),
+    (operator.pow, R.power),
+    (operator.mod, R.floor_mod),
+]
+
+
[email protected]("op, relax_op", operator_binary_promote)
+def test_binary_dtype_promotion(op, relax_op):
+    """Ensure binary ops promote differing dtypes following PyTorch rules."""
+
+    class BinaryPromoteLHS(Module):
+        def forward(self, x):
+            arange_val = torch.arange(x.shape[1])  # int64 by default
+            return op(x, arange_val)
+
+    @tvm.script.ir_module
+    class expected_promote_lhs:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((3,), dtype="float32") = R.astype(lv, dtype="float32")
+                lv2: R.Tensor((2, 3), dtype="float32") = relax_op(x, lv1)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    class BinaryPromoteRHS(Module):
+        def forward(self, x):
+            arange_val = torch.arange(x.shape[1])  # int64 by default
+            return op(arange_val, x)
+
+    @tvm.script.ir_module
+    class expected_promote_rhs:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((3,), dtype="float32") = R.astype(lv, dtype="float32")
+                lv2: R.Tensor((2, 3), dtype="float32") = relax_op(lv1, x)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 3, dtype=torch.float32),)
+    verify_model(BinaryPromoteLHS(), example_args, {}, expected_promote_lhs)
+    verify_model(BinaryPromoteRHS(), example_args, {}, expected_promote_rhs)
+
+
 operator_binary_2 = [
     (operator.eq, R.equal),
     (operator.ne, R.not_equal),

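An end-to-end sketch of the new behavior (the module and shapes are
illustrative, not taken from the patch):

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class MixedAdd(torch.nn.Module):
        def forward(self, x):
            # int64 arange added to a float32 input; the importer now
            # inserts an astype so the add sees matching float32 operands
            return x + torch.arange(x.shape[1])

    ep = export(MixedAdd(), (torch.randn(2, 3, dtype=torch.float32),))
    mod = from_exported_program(ep)
    mod.show()  # R.astype(..., "float32") appears before the add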