swjng opened a new pull request, #19512:
URL: https://github.com/apache/tvm/pull/19512
## Motivation
The PyTorch frontend's `_var` ignored the `correction` kwarg of
`aten.var.correction`. `torch.export.run_decompositions()` rewrites
both `aten.std.correction` and `aten.std.dim` into
`aten.var.correction(..., correction=<value>) → sqrt`, so every
`torch.std`/`torch.var` call lands in `_var` — but the correction
value was dropped on the floor. The variance was therefore always
divided by `n` regardless of what the user requested.
Minimal repro (vs PyTorch eager):
```
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 2.0, 2.0, 2.0, 2.0]])
torch.std(x, dim=1, unbiased=True)
# ref: [1.5811, 0.0]  sqrt(2.5)
# tvm: [1.4142, 0.0]  sqrt(2.0) -- correction silently set to 0
```
The same omission shows up for explicit `torch.var(x, correction=k)`
and any model that relies on the documented Bessel default.
## Fix
Route `aten.var.correction` (identified by `OpOverload._overloadname`,
not a substring match) to a new `_var_correction` helper. It reads
`correction` from `node.kwargs`, treats `None` as 1 to match the
overload's `Scalar? correction = None` schema, and scales the
existing `relax.op.variance` output by `n / (n - correction)` when
`correction != 0`.
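The rescaling identity can be checked without TVM. Below is a minimal sketch in plain Python, using the first row of the repro above. `rescale_population_variance` is a hypothetical name chosen for illustration, not the helper added in this PR:

```python
import math
import statistics


def rescale_population_variance(pop_var, n, correction):
    """Turn a divide-by-n variance into a divide-by-(n - correction) one."""
    if correction == 0:
        return pop_var  # population variance needs no rescaling
    return pop_var * n / (n - correction)


row = [1.0, 2.0, 3.0, 4.0, 5.0]
pop_var = statistics.pvariance(row)  # divides by n
corrected = rescale_population_variance(pop_var, len(row), 1)

# Agrees with the stdlib sample (Bessel-corrected) variance.
assert math.isclose(corrected, statistics.variance(row))
# And with the eager reference value for std in the repro.
assert math.isclose(math.sqrt(corrected), 1.5811, abs_tol=1e-4)
```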
When `n - correction <= 0`, the multiplier is set to NaN rather than
raising — this mirrors PyTorch's documented
`max(0, N - correction)` semantics (eager produces NaN with a warning,
not an error).
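The multiplier logic described above might look roughly like the following. This is a sketch under stated assumptions, not the code in this PR: `correction_multiplier` is a hypothetical name, and it works on plain Python numbers rather than Relax expressions.

```python
import math


def correction_multiplier(n, correction=None):
    """Factor applied to a divide-by-n variance.

    `None` is treated as 1, following the `Scalar? correction = None`
    schema described above.
    """
    if correction is None:
        correction = 1
    denom = n - correction
    if denom <= 0:
        # Degenerate case: return NaN instead of raising.
        return math.nan
    return n / denom
```

For example, `correction_multiplier(15)` gives `15/14`, and `correction_multiplier(1)` gives NaN because no degrees of freedom remain.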
Reduction-axis sizes are read from `x.struct_info.shape`. Dynamic
(symbolic) sizes raise `NotImplementedError`; static shapes cover the
common `torch.export` flow.
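As an illustration of the static-shape requirement, here is a hedged sketch that computes `n` from a shape. It assumes the shape has already been converted to a Python tuple in which static dimensions are `int` and symbolic ones are anything else (a string here). `reduction_size` is a hypothetical helper, not a TVM API:

```python
import math


def reduction_size(shape, dims=None):
    """Product of the reduced axis sizes; reduces over all axes if dims is None."""
    rank = len(shape)
    axes = range(rank) if dims is None else [d % rank for d in dims]
    sizes = [shape[a] for a in axes]
    if not all(isinstance(s, int) for s in sizes):
        # Assumption: non-int entries stand in for symbolic dimensions.
        raise NotImplementedError("dynamic reduction-axis size is not supported")
    return math.prod(sizes)


assert reduction_size((3, 5), dims=[1]) == 5
assert reduction_size((3, 5)) == 15
```

Only the reduced axes need to be static in this sketch: `reduction_size(("n", 5), dims=[1])` still returns 5, while reducing over axis 0 raises.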
The legacy fx path through `_var` is intentionally left alone — it has
a separate preexisting bug (it reads `args[2]` as `keepdim` even when
that slot is `unbiased`), but fixing that here would expand the scope
of this PR beyond the `correction` semantics.
## Notes
- `_std` is also registered for `"std.correction"` but is unreachable on
the default exported-program path because `aten.std.*` always
decomposes to `var.correction + sqrt` before dispatch. Sparse-tensor
exports that skip `run_decompositions` still hit the old `_std`; that
path is out of scope for this fix.
- Existing `test_std`/`test_var` encoded the buggy `correction=0` IR
for `torch.var(x)` (which defaults to Bessel) and have been updated
to expect the correct `R.multiply(var, R.const(15/14))`. New
`test_var_correction` covers explicit `correction=2` and `correction=0`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]