gemini-code-assist[bot] commented on code in PR #19512:
URL: https://github.com/apache/tvm/pull/19512#discussion_r3194102120


##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1645,12 +1645,70 @@ def _sum(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
 
     def _var(self, node: fx.Node) -> relax.Var:
+        # `aten.var.correction` (and decomposed `aten.std.*`) carries an
+        # optional `correction` kwarg whose `None` default means 1 (Bessel).
+        # Legacy fx `tensor.var(...)` calls go through the original path
+        # below to keep this fix narrowly scoped.
+        target = node.target
+        if getattr(target, "_overloadname", None) == "correction" or getattr(
+            target, "overload_name", None
+        ) == "correction":
+            return self._var_correction(node)
         args = self.retrieve_args(node)
         x = args[0]
         dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
         keepdim = args[2] if len(node.args) > 2 else 
node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.variance(x, dim, 
keepdims=keepdim))
 
+    def _var_correction(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = node.kwargs.get("keepdim", False)
+        correction = node.kwargs.get("correction", None)
+        if correction is None:
+            correction = 1
+        var = self.block_builder.emit(relax.op.variance(x, dim, 
keepdims=keepdim))
+        if correction == 0:
+            return var
+        n = self._reduction_size(x, dim)
+        if n is None:
+            raise NotImplementedError(
+                "var/std with non-zero correction requires statically known "
+                "reduction-axis sizes."
+            )
+        # PyTorch returns NaN (with a warning) when `n - correction <= 0`;
+        # mirror that semantics rather than failing the import.
+        if n - correction <= 0:
+            scale = float("nan")
+        else:
+            scale = float(n) / float(n - correction)
+        return self.block_builder.emit(
+            relax.op.multiply(var, relax.const(scale, x.struct_info.dtype))
+        )
+
+    @staticmethod
+    def _reduction_size(x: relax.Expr, dim) -> int | None:
+        """Static product of reduced-axis sizes; None if any axis is 
dynamic."""
+        shape = x.struct_info.shape
+        if shape is None:
+            return None
+        rank = len(shape)
+        if dim is None:
+            axes = list(range(rank))
+        elif isinstance(dim, int):
+            axes = [dim]
+        else:
+            axes = list(dim)
+        n = 1
+        for ax in axes:
+            ax = ax + rank if ax < 0 else ax
+            s = shape[ax]
+            if not isinstance(s, tvm.tirx.IntImm):

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Use `tirx.IntImm` instead of `tvm.tirx.IntImm` for consistency with other 
parts of the file (e.g., lines 2014, 2070) and because `tirx` is already 
imported from `tvm` at line 29.
   
   ```suggestion
               if not isinstance(s, tirx.IntImm):
   ```



##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1645,12 +1645,70 @@ def _sum(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
 
     def _var(self, node: fx.Node) -> relax.Var:
+        # `aten.var.correction` (and decomposed `aten.std.*`) carries an
+        # optional `correction` kwarg whose `None` default means 1 (Bessel).
+        # Legacy fx `tensor.var(...)` calls go through the original path
+        # below to keep this fix narrowly scoped.
+        target = node.target
+        if getattr(target, "_overloadname", None) == "correction" or getattr(
+            target, "overload_name", None
+        ) == "correction":
+            return self._var_correction(node)
         args = self.retrieve_args(node)
         x = args[0]
         dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
         keepdim = args[2] if len(node.args) > 2 else 
node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.variance(x, dim, 
keepdims=keepdim))
 
+    def _var_correction(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = node.kwargs.get("keepdim", False)
+        correction = node.kwargs.get("correction", None)
+        if correction is None:
+            correction = 1

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `correction` argument is read directly from `node.kwargs`, but it might be an `fx.Node` in the FX graph (e.g., if it is a symbolic value or is computed by another operation). It should be resolved using `self._retrieve_args` so that it is either a constant value or handled explicitly, because the subsequent logic (`correction == 0`, `n - correction`) expects a Python number.



##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1645,12 +1645,70 @@ def _sum(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
 
     def _var(self, node: fx.Node) -> relax.Var:
+        # `aten.var.correction` (and decomposed `aten.std.*`) carries an
+        # optional `correction` kwarg whose `None` default means 1 (Bessel).
+        # Legacy fx `tensor.var(...)` calls go through the original path
+        # below to keep this fix narrowly scoped.
+        target = node.target
+        if getattr(target, "_overloadname", None) == "correction" or getattr(
+            target, "overload_name", None
+        ) == "correction":
+            return self._var_correction(node)
         args = self.retrieve_args(node)
         x = args[0]
         dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
         keepdim = args[2] if len(node.args) > 2 else 
node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.variance(x, dim, 
keepdims=keepdim))
 
+    def _var_correction(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = node.kwargs.get("keepdim", False)
+        correction = node.kwargs.get("correction", None)
+        if correction is None:
+            correction = 1
+        var = self.block_builder.emit(relax.op.variance(x, dim, 
keepdims=keepdim))
+        if correction == 0:
+            return var
+        n = self._reduction_size(x, dim)
+        if n is None:
+            raise NotImplementedError(
+                "var/std with non-zero correction requires statically known "
+                "reduction-axis sizes."
+            )
+        # PyTorch returns NaN (with a warning) when `n - correction <= 0`;
+        # mirror that semantics rather than failing the import.
+        if n - correction <= 0:
+            scale = float("nan")
+        else:
+            scale = float(n) / float(n - correction)
+        return self.block_builder.emit(
+            relax.op.multiply(var, relax.const(scale, x.struct_info.dtype))
+        )
+
+    @staticmethod
+    def _reduction_size(x: relax.Expr, dim) -> int | None:
+        """Static product of reduced-axis sizes; None if any axis is 
dynamic."""
+        shape = x.struct_info.shape
+        if shape is None:
+            return None
+        rank = len(shape)
+        if dim is None:
+            axes = list(range(rank))
+        elif isinstance(dim, int):
+            axes = [dim]
+        else:
+            axes = list(dim)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   If `dim` is a `relax.Expr` (which happens if it was an `fx.Node` in the FX graph), `list(dim)` will raise a `TypeError`. Since this method is intended to compute a static reduction size, it should explicitly check whether `dim` is a constant integer or a collection of integers, and return `None` otherwise to indicate that the size cannot be determined statically. The caller then raises its explicit `NotImplementedError` instead of failing with an unrelated `TypeError`.
   
   ```suggestion
           if dim is None:
               axes = list(range(rank))
           elif isinstance(dim, int):
               axes = [dim]
           elif isinstance(dim, (list, tuple)) and all(isinstance(a, int) for a 
in dim):
               axes = list(dim)
           else:
               return None
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to