This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9545b3c1a4 [Relax][PyTorch] Support specifying decimals for _round (#18507)
9545b3c1a4 is described below
commit 9545b3c1a47d38ab2aab5d8c2aa9bc833672ed85
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Nov 27 13:04:48 2025 +0800
[Relax][PyTorch] Support specifying decimals for _round (#18507)
## Why
- The PyTorch frontend's `_round` converter raised a `ValueError` whenever a
nonzero `decimals` argument was passed to `torch.round`.
## How
- Allow rounding to a specified number of decimal places
- Add tests for `_round`
---
.../frontend/torch/base_fx_graph_translator.py | 15 ++++-
tests/python/relax/test_frontend_from_fx.py | 75 ++++++++++++++++++++++
2 files changed, 87 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 2b97f22c92..f70032bc7f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -364,10 +364,19 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis))
 
     def _round(self, node: fx.Node) -> relax.Expr:
-        if node.kwargs.get("decimals", 0) != 0:
-            raise ValueError("specifying decimals for round is not supported yet")
         arg = self.env[node.args[0]]
-        return self.block_builder.emit(relax.op.round(arg))
+        decimals = node.kwargs.get("decimals", 0)
+        if decimals == 0:
+            return self.block_builder.emit(relax.op.round(arg))
+
+        # As PyTorch does, scale by the exact power 10**|decimals| and divide
+        # first when decimals < 0, since 10**-k has no exact binary form.
+        scale = relax.const(10 ** abs(decimals), arg.struct_info.dtype)
+        if decimals > 0:
+            result = relax.op.divide(relax.op.round(relax.op.multiply(arg, scale)), scale)
+        else:
+            result = relax.op.multiply(relax.op.round(relax.op.divide(arg, scale)), scale)
+        return self.block_builder.emit(result)
 
     def _softmax(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 9840665251..b1571ef388 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -6273,5 +6273,80 @@ def test_linspace():
)
+def test_round():
+    input_info = [([3, 4], "float32")]
+
+    class Round(Module):
+        def __init__(self, decimals=0):
+            super().__init__()
+            self.decimals = decimals
+
+        def forward(self, x):
+            if self.decimals == 0:
+                return torch.round(x)
+            else:
+                return torch.round(x, decimals=self.decimals)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.round(inp_0)
+                gv: R.Tensor((3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.multiply(inp_0, R.const(100.0, "float32"))
+                lv1: R.Tensor((3, 4), dtype="float32") = R.round(lv)
+                lv2: R.Tensor((3, 4), dtype="float32") = R.divide(lv1, R.const(100.0, "float32"))
+                gv: R.Tensor((3, 4), dtype="float32") = lv2
+                R.output(gv)
+            return gv
+
+    rounds = [
+        (0, Expected1),
+        (2, Expected2),
+    ]
+
+    for decimals, expected in rounds:
+        verify_model(Round(decimals), input_info, {}, expected)
+
+    # Check numerics against PyTorch, including negative decimals
+    test_data = torch.tensor(
+        [
+            [1.2345, 2.3456, 3.4567, 4.5678],
+            [5.6789, 6.7890, 7.8901, 8.9012],
+            [9.1234, 10.2345, 11.3456, 12.4567],
+        ]
+    )
+
+    for decimals in [-1, 0, 1, 2, 3]:
+        torch_model = Round(decimals)
+        graph_model = fx.symbolic_trace(torch_model)
+        with torch.no_grad():
+            mod = from_fx(graph_model, input_info)
+
+        target = tvm.target.Target("llvm")
+        ex = relax.build(mod, target)
+        vm = relax.VirtualMachine(ex, tvm.cpu())
+
+        torch_result = torch_model(test_data).numpy()
+        tvm_input = tvm.runtime.tensor(test_data.numpy())
+        tvm_result = vm["main"](tvm_input).numpy()
+
+        # Use relaxed tolerance due to floating-point precision in decimal operations
+        tvm.testing.assert_allclose(tvm_result, torch_result, rtol=1e-3, atol=1e-3)
+
+
 if __name__ == "__main__":
     tvm.testing.main()