This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 91c1921210 [Relax][PyTorch] Add binary operation dtype promotion following PyTorch rules in ExportedProgram frontend (#18497)
91c1921210 is described below
commit 91c1921210adb5a911ee133ca35b46cdea472843
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Mon Nov 24 17:18:19 2025 +0900
[Relax][PyTorch] Add binary operation dtype promotion following PyTorch rules in ExportedProgram frontend (#18497)
Adds dtype promotion for binary operations in the ExportedProgram frontend, following PyTorch's type promotion rules.
ref:
https://docs.pytorch.org/docs/stable/generated/torch.promote_types.html
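For readers unfamiliar with the rules, a minimal sketch of the promotion behavior the importer now mirrors (pairs chosen for illustration; see the link above for the full table):

    import torch

    # Floating-point dtypes win over integer dtypes.
    assert torch.promote_types(torch.float32, torch.int64) == torch.float32
    # Within a category, the wider dtype wins.
    assert torch.promote_types(torch.int32, torch.int64) == torch.int64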
---
.../frontend/torch/base_fx_graph_translator.py | 41 +++++++++++++++
.../relax/test_frontend_from_exported_program.py | 61 ++++++++++++++++++++++
2 files changed, 102 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index fb8790322e..2b97f22c92 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -88,6 +88,36 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return tensor.shape
raise ValueError("Unsupported type: {}".format(type(tensor)))
+    @staticmethod
+    def _promote_common_dtype(lhs_dtype: Optional[str], rhs_dtype: Optional[str]) -> Optional[str]:
+        """Return the promoted dtype following PyTorch rules, or None if unsupported."""
+        import torch  # type: ignore
+
+        if lhs_dtype is None or rhs_dtype is None or lhs_dtype == rhs_dtype:
+            return None
+
+        tvm_to_torch = {
+            "float64": torch.float64,
+            "float32": torch.float32,
+            "float16": torch.float16,
+            "bfloat16": torch.bfloat16,
+            "int64": torch.int64,
+            "int32": torch.int32,
+            "int16": torch.int16,
+            "int8": torch.int8,
+            "uint8": torch.uint8,
+            "bool": torch.bool,
+        }
+        torch_to_tvm = {v: k for k, v in tvm_to_torch.items()}
+
+        lhs_torch = tvm_to_torch.get(lhs_dtype)
+        rhs_torch = tvm_to_torch.get(rhs_dtype)
+        if lhs_torch is None or rhs_torch is None:
+            return None
+
+        promoted = torch.promote_types(lhs_torch, rhs_torch)
+        return torch_to_tvm.get(promoted, None)
+
@staticmethod
def _is_no_bias(bias):
"""Check if bias represents 'no bias' condition.
@@ -408,6 +438,17 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
def convert(node: fx.Node) -> relax.Var:
def promote_binary_op_args(lhs, rhs):
if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
+                    lhs_si = getattr(lhs, "struct_info", None)
+                    rhs_si = getattr(rhs, "struct_info", None)
+                    if isinstance(lhs_si, relax.TensorStructInfo) and isinstance(
+                        rhs_si, relax.TensorStructInfo
+                    ):
+                        target_dtype = self._promote_common_dtype(lhs_si.dtype, rhs_si.dtype)
+                        if target_dtype is not None:
+                            if lhs_si.dtype != target_dtype:
+                                lhs = self.block_builder.emit(relax.op.astype(lhs, target_dtype))
+                            if rhs_si.dtype != target_dtype:
+                                rhs = self.block_builder.emit(relax.op.astype(rhs, target_dtype))
return lhs, rhs
elif isinstance(lhs, relax.Expr):
assert isinstance(lhs.struct_info, relax.TensorStructInfo)
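Note that _promote_common_dtype returns None both when the dtypes already match and when either dtype falls outside the mapping, so the caller only emits relax.op.astype when a cast is actually required. A rough standalone sketch of that contract (illustrative only, not part of the commit; the mapping is trimmed to two dtypes):

    import torch

    def promote_common_dtype(lhs_dtype, rhs_dtype):
        # Trimmed mapping for illustration; the commit covers ten dtypes.
        mapping = {"float32": torch.float32, "int64": torch.int64}
        if lhs_dtype is None or rhs_dtype is None or lhs_dtype == rhs_dtype:
            return None  # no cast needed, or dtype unknown
        lhs, rhs = mapping.get(lhs_dtype), mapping.get(rhs_dtype)
        if lhs is None or rhs is None:
            return None
        inv = {v: k for k, v in mapping.items()}
        return inv.get(torch.promote_types(lhs, rhs))

    print(promote_common_dtype("float32", "int64"))    # float32 -> cast the int64 side
    print(promote_common_dtype("float32", "float32"))  # None -> no cast emitted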
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 89017e30a7..78a8a09a3c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1383,6 +1383,67 @@ def test_binary1(op, relax_op):
verify_model(Binary2(op), example_args2, {}, expected2)
+operator_binary_promote = [
+    (operator.add, R.add),
+    (operator.sub, R.subtract),
+    (operator.mul, R.multiply),
+    (operator.truediv, R.divide),
+    (operator.pow, R.power),
+    (operator.mod, R.floor_mod),
+]
+
+
[email protected]("op, relax_op", operator_binary_promote)
+def test_binary_dtype_promotion(op, relax_op):
+    """Ensure binary ops promote differing dtypes following PyTorch rules."""
+
+    class BinaryPromoteLHS(Module):
+        def forward(self, x):
+            arange_val = torch.arange(x.shape[1])  # int64 by default
+            return op(x, arange_val)
+
+    @tvm.script.ir_module
+    class expected_promote_lhs:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((3,), dtype="float32") = R.astype(lv, dtype="float32")
+                lv2: R.Tensor((2, 3), dtype="float32") = relax_op(x, lv1)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    class BinaryPromoteRHS(Module):
+        def forward(self, x):
+            arange_val = torch.arange(x.shape[1])  # int64 by default
+            return op(arange_val, x)
+
+    @tvm.script.ir_module
+    class expected_promote_rhs:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((3,), dtype="float32") = R.astype(lv, dtype="float32")
+                lv2: R.Tensor((2, 3), dtype="float32") = relax_op(lv1, x)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 3, dtype=torch.float32),)
+    verify_model(BinaryPromoteLHS(), example_args, {}, expected_promote_lhs)
+    verify_model(BinaryPromoteRHS(), example_args, {}, expected_promote_rhs)
+
+
operator_binary_2 = [
(operator.eq, R.equal),
(operator.ne, R.not_equal),