This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9545b3c1a4 [Relax][PyTorch] Support specifying decimals for _round (#18507)
9545b3c1a4 is described below

commit 9545b3c1a47d38ab2aab5d8c2aa9bc833672ed85
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Nov 27 13:04:48 2025 +0800

    [Relax][PyTorch] Support specifying decimals for _round (#18507)
    
    ## Why
    
    - The current `round` function does not support specifying the number of
    decimal places.
    
    ## How
    
    - Allow rounding to a specified number of decimals by scaling, rounding, and rescaling (see the sketch below)
    - Add tests for `_round`
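
    A minimal sketch of the scale/round/rescale identity used here, written in
    plain NumPy for illustration; `round_decimals` is a hypothetical helper,
    not part of the TVM or PyTorch API:

    ```python
    import numpy as np

    def round_decimals(x: np.ndarray, decimals: int = 0) -> np.ndarray:
        # Identity from the diff below:
        #   round(x, decimals) == round(x * 10**decimals) / 10**decimals
        if decimals == 0:
            return np.round(x)
        scale = 10.0**decimals  # e.g. decimals=2 -> scale=100.0
        # np.round rounds half to even, matching torch.round's behavior
        return np.round(x * scale) / scale

    round_decimals(np.array([1.2345, 2.3456]), 2)  # -> array([1.23, 2.35])
    ```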
---
 .../frontend/torch/base_fx_graph_translator.py     | 15 ++++-
 tests/python/relax/test_frontend_from_fx.py        | 75 ++++++++++++++++++++++
 2 files changed, 87 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 2b97f22c92..f70032bc7f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -364,10 +364,19 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis))
 
     def _round(self, node: fx.Node) -> relax.Expr:
-        if node.kwargs.get("decimals", 0) != 0:
-            raise ValueError("specifying decimals for round is not supported yet")
         arg = self.env[node.args[0]]
-        return self.block_builder.emit(relax.op.round(arg))
+        decimals = node.kwargs.get("decimals", 0)
+
+        if decimals == 0:
+            return self.block_builder.emit(relax.op.round(arg))
+
+        # For decimals != 0, use: round(x * 10^decimals) / 10^decimals
+        dtype = arg.struct_info.dtype
+        scale = relax.const(10**decimals, dtype)
+        scaled = relax.op.multiply(arg, scale)
+        rounded = relax.op.round(scaled)
+        result = relax.op.divide(rounded, scale)
+        return self.block_builder.emit(result)
 
     def _softmax(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 9840665251..b1571ef388 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -6273,5 +6273,80 @@ def test_linspace():
     )
 
 
+def test_round():
+    input_info = [([3, 4], "float32")]
+
+    class Round(Module):
+        def __init__(self, decimals=0):
+            super().__init__()
+            self.decimals = decimals
+
+        def forward(self, x):
+            if self.decimals == 0:
+                return torch.round(x)
+            else:
+                return torch.round(x, decimals=self.decimals)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.round(inp_0)
+                gv: R.Tensor((3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.multiply(inp_0, 
R.const(100.0, "float32"))
+                lv1: R.Tensor((3, 4), dtype="float32") = R.round(lv)
+                lv2: R.Tensor((3, 4), dtype="float32") = R.divide(lv1, 
R.const(100.0, "float32"))
+                gv: R.Tensor((3, 4), dtype="float32") = lv2
+                R.output(gv)
+            return gv
+
+    rounds = [
+        (0, Expected1),
+        (2, Expected2),
+    ]
+
+    for decimals, expected in rounds:
+        verify_model(Round(decimals), input_info, {}, expected)
+
+    # Test numerical accuracy with decimals
+    test_data = torch.tensor(
+        [
+            [1.2345, 2.3456, 3.4567, 4.5678],
+            [5.6789, 6.7890, 7.8901, 8.9012],
+            [9.1234, 10.2345, 11.3456, 12.4567],
+        ]
+    )
+
+    for decimals in [0, 1, 2, 3]:
+        torch_model = Round(decimals)
+        graph_model = fx.symbolic_trace(torch_model)
+        with torch.no_grad():
+            mod = from_fx(graph_model, input_info)
+
+        target = tvm.target.Target("llvm")
+        ex = relax.build(mod, target)
+        vm = relax.VirtualMachine(ex, tvm.cpu())
+
+        torch_result = torch_model(test_data).numpy()
+        tvm_input = tvm.runtime.tensor(test_data.numpy())
+        tvm_result = vm["main"](tvm_input).numpy()
+
+        # Use relaxed tolerance due to floating-point precision in decimal operations
+        tvm.testing.assert_allclose(tvm_result, torch_result, rtol=1e-3, atol=1e-3)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
