This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5a7da7a32a [BugFix][Relax][Torch] Honor `correction` in std/var
converter (#19512)
5a7da7a32a is described below
commit 5a7da7a32aab0000400e746d93c918e09f80502c
Author: Soowon Jeong <[email protected]>
AuthorDate: Wed May 6 19:51:09 2026 +0900
[BugFix][Relax][Torch] Honor `correction` in std/var converter (#19512)
## Motivation
The PyTorch frontend's `_var` ignored the `correction` kwarg of
`aten.var.correction`. `ExportedProgram.run_decompositions()` rewrites
both `aten.std.correction` and `aten.std.dim` into
`aten.var.correction(..., correction=<value>) → sqrt`, so every
`torch.std`/`torch.var` call lands in `_var` — but the correction
value was silently discarded. The variance was therefore always
divided by `n` regardless of what the user requested.
Minimal repro (vs PyTorch eager):
```
x = [[1, 2, 3, 4, 5], [2, 2, 2, 2, 2]]
torch.std(x, dim=1, unbiased=True)
ref: [1.5811, 0.0] # sqrt(2.5)
tvm: [1.4142, 0.0] # sqrt(2.0) -- correction silently set to 0
```
The same omission shows up for explicit `torch.var(x, correction=k)`
and any model that relies on the documented Bessel default.
## Fix
Route `aten.var.correction` (identified by `OpOverload._overloadname`,
not a substring match) to a new `_var_correction` helper. It reads
`correction` from `node.kwargs`, treats `None` as 1 to match the
overload's `Scalar? correction = None` schema, and scales the
existing `relax.op.variance` output by `n / (n - correction)` when
`correction != 0`.
When `n - correction <= 0`, the multiplier is set to NaN rather than
raising — this mirrors PyTorch's documented
`max(0, N - correction)` semantics (eager produces NaN with a warning,
not an error).
Reduction-axis sizes are read from `x.struct_info.shape`. Dynamic
sizes raise `NotImplementedError` when a non-zero correction is
requested; static-shape models cover the common `torch.export` flow.
The legacy fx path through `_var` is intentionally left alone — it has
a separate preexisting bug (it reads `args[2]` as `keepdim` even when
that slot is `unbiased`), but fixing that here would expand the scope
of this PR beyond the `correction` semantics.
## Notes
- `_std` is also registered for `"std.correction"` but is unreachable on
the default exported-program path because `aten.std.*` always
decomposes to `var.correction + sqrt` before dispatch. Sparse-tensor
exports that skip `run_decompositions` still hit the old `_std`; that
path is out of scope for this fix.
- Existing `test_std`/`test_var` encoded the buggy `correction=0` IR
for `torch.std(x)`/`torch.var(x)` (both default to Bessel's correction)
and have been updated to expect the correct
`R.multiply(var, R.const(15/14))`. New `test_var_correction` covers
explicit `correction=2` and `correction=0`.
---
.../frontend/torch/base_fx_graph_translator.py | 60 ++++++++++++++++++++++
.../relax/test_frontend_from_exported_program.py | 49 ++++++++++++++++--
2 files changed, 106 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 0d92576c59..138176155a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1645,12 +1645,72 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
def _var(self, node: fx.Node) -> relax.Var:
+ # `aten.var.correction` (and decomposed `aten.std.*`) carries an
+ # optional `correction` kwarg whose `None` default means 1 (Bessel).
+ # Legacy fx `tensor.var(...)` calls go through the original path
+ # below to keep this fix narrowly scoped.
+ target = node.target
+ if getattr(target, "_overloadname", None) == "correction" or getattr(
+ target, "overload_name", None
+ ) == "correction":
+ return self._var_correction(node)
args = self.retrieve_args(node)
x = args[0]
dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
keepdim = args[2] if len(node.args) > 2 else
node.kwargs.get("keepdim", False)
return self.block_builder.emit(relax.op.variance(x, dim,
keepdims=keepdim))
+    def _var_correction(self, node: fx.Node) -> relax.Var:
+        """Convert `aten.var.correction` (also reached by decomposed `aten.std.*`).
+
+        `relax.op.variance` divides by `n`; the result is rescaled by
+        `n / (n - correction)` to apply the requested correction. A `None`
+        correction means 1 (Bessel), matching the overload's
+        `Scalar? correction = None` schema.
+
+        Raises:
+            NotImplementedError: a non-zero correction is requested but a
+                reduced axis does not have a statically known size.
+        """
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        # `correction` and `keepdim` are keyword-only in the overload schema.
+        keepdim = node.kwargs.get("keepdim", False)
+        correction = node.kwargs.get("correction", None)
+        if correction is None:
+            correction = 1
+        var = self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
+        if correction == 0:
+            return var
+        n = self._reduction_size(x, dim)
+        if n is None:
+            raise NotImplementedError(
+                "var/std with non-zero correction requires statically known "
+                "reduction-axis sizes."
+            )
+        if n - correction <= 0:
+            # PyTorch divides the sum of squared deviations by
+            # max(0, n - correction), i.e. by zero here: the result is inf when
+            # the deviations are non-zero and NaN (0/0) when they are all zero.
+            # Scaling the population variance by +inf reproduces both cases
+            # (x * inf = inf for x > 0, 0 * inf = NaN); a NaN multiplier would
+            # wrongly turn the inf case into NaN. Import succeeds, as in eager.
+            scale = float("inf")
+        else:
+            scale = float(n) / float(n - correction)
+        return self.block_builder.emit(
+            relax.op.multiply(var, relax.const(scale, x.struct_info.dtype))
+        )
+
+    @staticmethod
+    def _reduction_size(x: relax.Expr, dim) -> "int | None":
+        """Static product of reduced-axis sizes; None if any axis is dynamic.
+
+        `dim` may be None (reduce every axis), an int, or a list/tuple of
+        ints; negative axes count from the end. Any other `dim` form, an
+        unknown shape, or a reduced extent that is not a constant `IntImm`
+        yields None.
+        """
+        # The annotation is quoted so it is not evaluated at definition time
+        # (PEP 604 `X | None` is a TypeError on Python < 3.10 otherwise).
+        shape = x.struct_info.shape
+        if shape is None:
+            return None
+        rank = len(shape)
+        if dim is None:
+            axes = list(range(rank))
+        elif isinstance(dim, int):
+            axes = [dim]
+        elif isinstance(dim, (list, tuple)) and all(isinstance(a, int) for a in dim):
+            axes = list(dim)
+        else:
+            return None
+        n = 1
+        for ax in axes:
+            extent = shape[ax + rank if ax < 0 else ax]
+            if not isinstance(extent, tirx.IntImm):
+                return None
+            n *= int(extent.value)
+        return n
+
def _any(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index d5ed2aca7c..e2f9751c15 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7533,6 +7533,7 @@ def test_any():
def test_std():
+ # torch.std(x) defaults to correction=1 (Bessel); decomposes to
var.correction + sqrt.
class Std(Module):
def forward(self, x):
return torch.std(x)
@@ -7545,8 +7546,9 @@ def test_std():
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None,
keepdims=False)
- lv1: R.Tensor((), dtype="float32") = R.sqrt(lv)
- gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+ lv1: R.Tensor((), dtype="float32") = R.multiply(lv,
R.const(15.0 / 14.0, "float32"))
+ lv2: R.Tensor((), dtype="float32") = R.sqrt(lv1)
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv2,)
R.output(gv)
return gv
@@ -7555,6 +7557,7 @@ def test_std():
def test_var():
+ # torch.var(x) defaults to correction=1 (Bessel).
class Var(Module):
def forward(self, x):
return torch.var(x)
@@ -7567,7 +7570,8 @@ def test_var():
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None,
keepdims=False)
- gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+ lv1: R.Tensor((), dtype="float32") = R.multiply(lv,
R.const(15.0 / 14.0, "float32"))
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
R.output(gv)
return gv
@@ -7575,6 +7579,45 @@ def test_var():
verify_model(Var(), example_args, {}, Expected)
+def test_var_correction():
+    # R.variance divides by n, so a non-zero correction c is expected to
+    # appear as a multiply by n / (n - c); correction=0 needs no rescale.
+    class VarCorrection2(Module):
+        def forward(self, x):
+            return torch.var(x, dim=-1, correction=2)
+
+    class VarCorrection0(Module):
+        def forward(self, x):
+            return torch.var(x, dim=1, correction=0)
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            x: R.Tensor((2, 5), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="float32") = R.variance(x, axis=[-1], keepdims=False)
+                # n = 5 reduced elements, correction = 2 -> scale 5 / 3.
+                lv1: R.Tensor((2,), dtype="float32") = R.multiply(lv, R.const(5.0 / 3.0, "float32"))
+                gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            x: R.Tensor((2, 5), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="float32") = R.variance(x, axis=[1], keepdims=False)
+                gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 5, dtype=torch.float32),)
+    verify_model(VarCorrection2(), example_args, {}, Expected2)
+    verify_model(VarCorrection0(), example_args, {}, Expected0)
+
+
def test_prod():
class Prod(Module):
def forward(self, x):