gemini-code-assist[bot] commented on code in PR #18609:
URL: https://github.com/apache/tvm/pull/18609#discussion_r2645657700
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -199,7 +199,7 @@ def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var:
         # Determine axes for instance norm (all spatial dimensions after channel)
         dim = len(self.shape_of(x))
-        axes = list(range(2, dim))
+        axes = [0] + list(range(2, dim))
Review Comment:

The comment on line 200 is now misleading. It says the axes are for
`instance_norm` and cover 'all spatial dimensions after channel', but with
axis 0 included the reduction now also spans the batch dimension, making the
computation behave like batch normalization rather than instance
normalization. Please consider updating the comment to reflect this, for
example:
```python
        # Determine axes for normalization, including batch and all spatial dimensions.
```
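
For context, here is a minimal NumPy sketch (illustrative only, not part of
the PR) of how the two axis choices differ on a 4-D NCHW tensor:
```python
import numpy as np

x = np.random.randn(8, 3, 16, 16)  # hypothetical (N, C, H, W) input
dim = x.ndim

instance_axes = tuple(range(2, dim))  # (2, 3): spatial dims only, per sample
batch_axes = (0,) + instance_axes     # (0, 2, 3): batch plus spatial dims

# Instance norm keeps one statistic per (sample, channel) pair...
print(x.mean(axis=instance_axes).shape)  # (8, 3)
# ...while including axis 0 pools over the batch: one statistic per channel.
print(x.mean(axis=batch_axes).shape)     # (3,)
```
The shapes of the resulting statistics make the semantic difference concrete:
the new axes list computes per-channel batch statistics, not per-instance
ones.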
##########
python/tvm/topi/nn/batch_norm.py:
##########
@@ -111,26 +111,22 @@ def batch_norm(
     shape = [1] * len(data.shape)
     shape[axis] = data.shape[axis]
-    reduce_axes = list(range(len(data.shape)))
-    reduce_axes.remove(axis)
-    shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1)
-
-    data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
-    data_mean_rs = topi.reshape(data_mean, shape)
-    data_var = (
-        topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod
-    )
-    data_var_rs = topi.reshape(data_var, shape)
-
     if training:
+        reduce_axes = list(range(len(data.shape)))
+        reduce_axes.remove(axis)
+        shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1)
+        data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
+        data_mean_rs = topi.reshape(data_mean, shape)
+        data_var = (
+            topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod
+        )
+        data_var_rs = topi.reshape(data_var, shape)
+        out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
+    else:
         moving_mean_rs = topi.reshape(moving_mean, shape)
         moving_var_rs = topi.reshape(moving_var, shape)
-
         out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
Review Comment:

This change introduces a subtle scoping hazard. The variables `reduce_axes`,
`shape_prod`, `data_mean`, and `data_mean_rs` are now bound only inside the
`if training:` block, but they are used again later in the function, inside a
second `if training:` block (lines 135-144). Because Python scopes local
variables to the function rather than to the block, this does not fail today:
whenever the later block executes, the earlier one has already bound the
names. It does, however, make correctness depend silently on the two guards
staying identical; if they ever diverge, the later use will raise an
`UnboundLocalError` (a `NameError`) at runtime.
While the intention to compute these values only during training is good for
performance, it makes the function's structure fragile. The suggested change
below removes the hazard by moving the definitions back outside the `if
training:` block; this reverts the performance optimization but keeps the
function robust. A more comprehensive refactoring could merge the two `if
training:` blocks to achieve both.
Additionally, `data_var` is calculated twice during training (here and again
on line 137). This redundancy could also be removed in such a refactoring.
```suggestion
    reduce_axes = list(range(len(data.shape)))
    reduce_axes.remove(axis)
    shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1)
    data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
    data_mean_rs = topi.reshape(data_mean, shape)
    if training:
        data_var = (
            topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod
        )
        data_var_rs = topi.reshape(data_var, shape)
        out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
    else:
        moving_mean_rs = topi.reshape(moving_mean, shape)
        moving_var_rs = topi.reshape(moving_var, shape)
        out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
```
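
To make the scoping behavior concrete, a standalone sketch (hypothetical
names, not TVM code) of the pattern in question:
```python
def f(training: bool) -> float:
    if training:
        mean = 1.0  # bound only on the training path
    if training:
        # Python scopes locals to the function, not the block, so 'mean'
        # is visible here whenever this guard matches the one above.
        return mean + 1.0
    # Fragile: any path that reads 'mean' without binding it fails at runtime.
    return mean  # raises UnboundLocalError when training is False

print(f(True))  # 2.0
# f(False) would raise UnboundLocalError ('mean' referenced before assignment)
```
This is why the duplicated guard works today but is easy to break in a later
edit.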