This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5a7da7a32a [BugFix][Relax][Torch] Honor `correction` in std/var converter (#19512)
5a7da7a32a is described below

commit 5a7da7a32aab0000400e746d93c918e09f80502c
Author: Soowon Jeong <[email protected]>
AuthorDate: Wed May 6 19:51:09 2026 +0900

    [BugFix][Relax][Torch] Honor `correction` in std/var converter (#19512)
    
    ## Motivation
    
    The PyTorch frontend's `_var` ignored the `correction` kwarg of
    `aten.var.correction`. `torch.var` already lowers to that overload,
    and `torch.export.run_decompositions()` rewrites both
    `aten.std.correction` and `aten.std.dim` into
    `aten.var.correction(..., correction=<value>) → sqrt`, so every
    `torch.std`/`torch.var` call lands in `_var` — but the correction
    value was silently discarded. The variance was therefore always
    divided by `n` regardless of what the user requested.
    
    Minimal repro (vs PyTorch eager):
    
    ```
    x = [[1, 2, 3, 4, 5], [2, 2, 2, 2, 2]]
    torch.std(x, dim=1, unbiased=True)
      ref: [1.5811, 0.0]   # sqrt(2.5)
      tvm: [1.4142, 0.0]   # sqrt(2.0) -- correction silently set to 0
    ```
    
    The same omission shows up for explicit `torch.var(x, correction=k)`
    and any model that relies on the documented Bessel default.
    
    ## Fix
    
    Route `aten.var.correction` (identified by `OpOverload._overloadname`,
    not a substring match) to a new `_var_correction` helper. It reads
    `correction` from `node.kwargs`, treats `None` as 1 to match the
    overload's `Scalar? correction = None` schema, and scales the
    existing `relax.op.variance` output by `n / (n - correction)` when
    `correction != 0`.
    
    When `n - correction <= 0`, the multiplier is set to NaN rather than
    raising. This approximates PyTorch's documented `max(0, N - correction)`
    divisor: eager divides by zero and warns instead of erroring, yielding
    NaN (or inf when the squared deviations are nonzero).
    
    Reduction-axis sizes are read from `x.struct_info.shape`. When a
    non-zero correction is requested and any reduced axis is dynamic, the
    converter raises `NotImplementedError`; static shapes cover the common
    `torch.export` flow.
    
    The legacy fx path through `_var` is intentionally left alone — it has
    a separate preexisting bug (it reads `args[2]` as `keepdim` even when
    that slot is `unbiased`), but fixing that here would expand the scope
    of this PR beyond the `correction` semantics.
    
    ## Notes
    
    - `_std` is also registered for `"std.correction"` but is unreachable on
      the default exported-program path because `aten.std.*` always
      decomposes to `var.correction + sqrt` before dispatch. Sparse-tensor
      exports that skip `run_decompositions` still hit the old `_std`; that
      path is out of scope for this fix.
    - Existing `test_std`/`test_var` encoded the buggy `correction=0` IR
      for `torch.std(x)`/`torch.var(x)` (both default to Bessel) and have
      been updated to expect the correct `R.multiply(var, R.const(15/14))`.
      New `test_var_correction` covers explicit `correction=2` and
      `correction=0`.
---
 .../frontend/torch/base_fx_graph_translator.py     | 60 ++++++++++++++++++++++
 .../relax/test_frontend_from_exported_program.py   | 49 ++++++++++++++++--
 2 files changed, 106 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 0d92576c59..138176155a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1645,12 +1645,72 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
 
     def _var(self, node: fx.Node) -> relax.Var:
+        # `aten.var.correction` (and decomposed `aten.std.*`) carries an
+        # optional `correction` kwarg whose `None` default means 1 (Bessel).
+        # Legacy fx `tensor.var(...)` calls go through the original path
+        # below to keep this fix narrowly scoped.
+        target = node.target
+        if getattr(target, "_overloadname", None) == "correction" or getattr(
+            target, "overload_name", None
+        ) == "correction":
+            return self._var_correction(node)
         args = self.retrieve_args(node)
         x = args[0]
         dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
         keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
 
+    def _var_correction(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = node.kwargs.get("keepdim", False)
+        correction = node.kwargs.get("correction", None)
+        if correction is None:
+            correction = 1
+        var = self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
+        if correction == 0:
+            return var
+        n = self._reduction_size(x, dim)
+        if n is None:
+            raise NotImplementedError(
+                "var/std with non-zero correction requires statically known "
+                "reduction-axis sizes."
+            )
+        # PyTorch returns NaN (with a warning) when `n - correction <= 0`;
+        # mirror that semantics rather than failing the import.
+        if n - correction <= 0:
+            scale = float("nan")
+        else:
+            scale = float(n) / float(n - correction)
+        return self.block_builder.emit(
+            relax.op.multiply(var, relax.const(scale, x.struct_info.dtype))
+        )
+
+    @staticmethod
+    def _reduction_size(x: relax.Expr, dim) -> int | None:
+        """Static product of reduced-axis sizes; None if any axis is dynamic."""
+        shape = x.struct_info.shape
+        if shape is None:
+            return None
+        rank = len(shape)
+        if dim is None:
+            axes = list(range(rank))
+        elif isinstance(dim, int):
+            axes = [dim]
+        elif isinstance(dim, (list, tuple)) and all(isinstance(a, int) for a in dim):
+            axes = list(dim)
+        else:
+            return None
+        n = 1
+        for ax in axes:
+            ax = ax + rank if ax < 0 else ax
+            s = shape[ax]
+            if not isinstance(s, tirx.IntImm):
+                return None
+            n *= int(s.value)
+        return n
+
     def _any(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index d5ed2aca7c..e2f9751c15 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7533,6 +7533,7 @@ def test_any():
 
 
 def test_std():
+    # torch.std(x) defaults to correction=1 (Bessel); decomposes to var.correction + sqrt.
     class Std(Module):
         def forward(self, x):
             return torch.std(x)
@@ -7545,8 +7546,9 @@ def test_std():
         ) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False)
-                lv1: R.Tensor((), dtype="float32") = R.sqrt(lv)
-                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+                lv1: R.Tensor((), dtype="float32") = R.multiply(lv, R.const(15.0 / 14.0, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sqrt(lv1)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv2,)
                 R.output(gv)
             return gv
 
@@ -7555,6 +7557,7 @@ def test_std():
 
 
 def test_var():
+    # torch.var(x) defaults to correction=1 (Bessel).
     class Var(Module):
         def forward(self, x):
             return torch.var(x)
@@ -7567,7 +7570,8 @@ def test_var():
         ) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False)
-                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                lv1: R.Tensor((), dtype="float32") = R.multiply(lv, R.const(15.0 / 14.0, "float32"))
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
 
@@ -7575,6 +7579,45 @@ def test_var():
     verify_model(Var(), example_args, {}, Expected)
 
 
+def test_var_correction():
+    class VarCorrection2(Module):
+        def forward(self, x):
+            return torch.var(x, dim=-1, correction=2)
+
+    class VarCorrection0(Module):
+        def forward(self, x):
+            return torch.var(x, dim=1, correction=0)
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            x: R.Tensor((2, 5), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="float32") = R.variance(x, axis=[-1], keepdims=False)
+                lv1: R.Tensor((2,), dtype="float32") = R.multiply(lv, R.const(5.0 / 3.0, "float32"))
+                gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            x: R.Tensor((2, 5), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="float32") = R.variance(x, axis=[1], keepdims=False)
+                gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 5, dtype=torch.float32),)
+    verify_model(VarCorrection2(), example_args, {}, Expected2)
+    verify_model(VarCorrection0(), example_args, {}, Expected0)
+
+
 def test_prod():
     class Prod(Module):
         def forward(self, x):

Reply via email to