gemini-code-assist[bot] commented on code in PR #19512:
URL: https://github.com/apache/tvm/pull/19512#discussion_r3194102120
##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1645,12 +1645,70 @@ def _sum(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
def _var(self, node: fx.Node) -> relax.Var:
+ # `aten.var.correction` (and decomposed `aten.std.*`) carries an
+ # optional `correction` kwarg whose `None` default means 1 (Bessel).
+ # Legacy fx `tensor.var(...)` calls go through the original path
+ # below to keep this fix narrowly scoped.
+ target = node.target
+ if getattr(target, "_overloadname", None) == "correction" or getattr(
+ target, "overload_name", None
+ ) == "correction":
+ return self._var_correction(node)
args = self.retrieve_args(node)
x = args[0]
dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
keepdim = args[2] if len(node.args) > 2 else
node.kwargs.get("keepdim", False)
return self.block_builder.emit(relax.op.variance(x, dim,
keepdims=keepdim))
+ def _var_correction(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+ keepdim = node.kwargs.get("keepdim", False)
+ correction = node.kwargs.get("correction", None)
+ if correction is None:
+ correction = 1
+ var = self.block_builder.emit(relax.op.variance(x, dim,
keepdims=keepdim))
+ if correction == 0:
+ return var
+ n = self._reduction_size(x, dim)
+ if n is None:
+ raise NotImplementedError(
+ "var/std with non-zero correction requires statically known "
+ "reduction-axis sizes."
+ )
+ # PyTorch returns NaN (with a warning) when `n - correction <= 0`;
+ # mirror that semantics rather than failing the import.
+ if n - correction <= 0:
+ scale = float("nan")
+ else:
+ scale = float(n) / float(n - correction)
+ return self.block_builder.emit(
+ relax.op.multiply(var, relax.const(scale, x.struct_info.dtype))
+ )
+
+ @staticmethod
+ def _reduction_size(x: relax.Expr, dim) -> int | None:
+ """Static product of reduced-axis sizes; None if any axis is
dynamic."""
+ shape = x.struct_info.shape
+ if shape is None:
+ return None
+ rank = len(shape)
+ if dim is None:
+ axes = list(range(rank))
+ elif isinstance(dim, int):
+ axes = [dim]
+ else:
+ axes = list(dim)
+ n = 1
+ for ax in axes:
+ ax = ax + rank if ax < 0 else ax
+ s = shape[ax]
+ if not isinstance(s, tvm.tirx.IntImm):
Review Comment:

Use `tirx.IntImm` instead of `tvm.tirx.IntImm` for consistency with other
parts of the file (e.g., lines 2014, 2070) and because `tirx` is already
imported from `tvm` at line 29.
```suggestion
if not isinstance(s, tirx.IntImm):
```
##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1645,12 +1645,70 @@ def _sum(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
def _var(self, node: fx.Node) -> relax.Var:
+ # `aten.var.correction` (and decomposed `aten.std.*`) carries an
+ # optional `correction` kwarg whose `None` default means 1 (Bessel).
+ # Legacy fx `tensor.var(...)` calls go through the original path
+ # below to keep this fix narrowly scoped.
+ target = node.target
+ if getattr(target, "_overloadname", None) == "correction" or getattr(
+ target, "overload_name", None
+ ) == "correction":
+ return self._var_correction(node)
args = self.retrieve_args(node)
x = args[0]
dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
keepdim = args[2] if len(node.args) > 2 else
node.kwargs.get("keepdim", False)
return self.block_builder.emit(relax.op.variance(x, dim,
keepdims=keepdim))
+ def _var_correction(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+ keepdim = node.kwargs.get("keepdim", False)
+ correction = node.kwargs.get("correction", None)
+ if correction is None:
+ correction = 1
Review Comment:

The `correction` argument is read directly from `node.kwargs`, but it
might be an `fx.Node` in the FX graph (e.g., if it is a symbolic value or
computed by another operation). It should be resolved using
`self._retrieve_args` so that it is either a constant Python number or
rejected explicitly. The subsequent logic (`correction == 0`,
`n - correction`) only works with a Python number.
##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1645,12 +1645,70 @@ def _sum(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
def _var(self, node: fx.Node) -> relax.Var:
+ # `aten.var.correction` (and decomposed `aten.std.*`) carries an
+ # optional `correction` kwarg whose `None` default means 1 (Bessel).
+ # Legacy fx `tensor.var(...)` calls go through the original path
+ # below to keep this fix narrowly scoped.
+ target = node.target
+ if getattr(target, "_overloadname", None) == "correction" or getattr(
+ target, "overload_name", None
+ ) == "correction":
+ return self._var_correction(node)
args = self.retrieve_args(node)
x = args[0]
dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
keepdim = args[2] if len(node.args) > 2 else
node.kwargs.get("keepdim", False)
return self.block_builder.emit(relax.op.variance(x, dim,
keepdims=keepdim))
+ def _var_correction(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+ keepdim = node.kwargs.get("keepdim", False)
+ correction = node.kwargs.get("correction", None)
+ if correction is None:
+ correction = 1
+ var = self.block_builder.emit(relax.op.variance(x, dim,
keepdims=keepdim))
+ if correction == 0:
+ return var
+ n = self._reduction_size(x, dim)
+ if n is None:
+ raise NotImplementedError(
+ "var/std with non-zero correction requires statically known "
+ "reduction-axis sizes."
+ )
+ # PyTorch returns NaN (with a warning) when `n - correction <= 0`;
+ # mirror that semantics rather than failing the import.
+ if n - correction <= 0:
+ scale = float("nan")
+ else:
+ scale = float(n) / float(n - correction)
+ return self.block_builder.emit(
+ relax.op.multiply(var, relax.const(scale, x.struct_info.dtype))
+ )
+
+ @staticmethod
+ def _reduction_size(x: relax.Expr, dim) -> int | None:
+ """Static product of reduced-axis sizes; None if any axis is
dynamic."""
+ shape = x.struct_info.shape
+ if shape is None:
+ return None
+ rank = len(shape)
+ if dim is None:
+ axes = list(range(rank))
+ elif isinstance(dim, int):
+ axes = [dim]
+ else:
+ axes = list(dim)
Review Comment:

If `dim` is a `relax.Expr` (which happens when it was an `fx.Node` in the FX
graph), `list(dim)` will raise a `TypeError`. This method is meant to compute
a static reduction size. It should therefore explicitly check whether `dim`
is a constant integer or a collection of integers, and return `None`
otherwise to signal that the size cannot be determined statically.
```suggestion
if dim is None:
axes = list(range(rank))
elif isinstance(dim, int):
axes = [dim]
elif isinstance(dim, (list, tuple)) and all(isinstance(a, int) for a
in dim):
axes = list(dim)
else:
return None
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]