sxjscience opened a new issue #19043:
URL: https://github.com/apache/incubator-mxnet/issues/19043
The gradient of `np.pad` is wrong. See the following reproducible example:
MXNet Implementation:
```python
import numpy as np
import mxnet as mx
mx.npx.set_np()
ctx = mx.gpu()
a = mx.np.ones((3, 3, 3), ctx=ctx)
mult = np.random.normal(0, 1, (3, 3, 3))
a.attach_grad()
with mx.autograd.record():
    b = mx.np.pad(a[:, 1:], ((0, 0), (0, 1), (0, 0))) * mx.np.array(mult, ctx=ctx)
    b = b.sum()
b.backward()
print(a.grad)
```
Output:
```
[[[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]] @gpu(0)
```
Jax Implementation:
```python
from jax import grad
import jax.numpy as jnp
import numpy as np
mult = np.random.normal(0, 1, (3, 3, 3))
a = jnp.ones((3, 3, 3))
def f(x):
    b = jnp.pad(x[:, 1:], ((0, 0), (0, 1), (0, 0))) * jnp.array(mult)
    return b.sum()
print(grad(f)(a))
```
Output:
```
[[[ 0.          0.          0.        ]
  [ 0.3545383  -0.84326786 -0.31482664]
  [ 1.0994871  -1.230104    2.8007567 ]]

 [[ 0.          0.          0.        ]
  [ 1.0447861  -0.16119051 -0.39860427]
  [-0.7756538   0.5314936   1.4601654 ]]

 [[ 0.          0.          0.        ]
  [ 0.37878916 -2.0777514   0.96676654]
  [ 0.45230922  0.3094176  -0.43687683]]]
```
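As a sanity check on which output is right, here is a minimal NumPy sketch of the gradient I would expect for this example. It assumes constant (zero) padding, and `pad_backward` and `expected_grad` are hypothetical helper names used only for illustration; they are not MXNet or JAX APIs. The idea is that the backward pass of a constant pad crops the incoming gradient back to the unpadded region, so the gradient should carry the values of `mult` rather than ones.

```python
import numpy as np

def pad_backward(out_grad, pad_width):
    """Backward of a constant-mode pad: crop the upstream gradient
    to the region that came from the unpadded input."""
    slices = tuple(
        slice(before, out_grad.shape[axis] - after)
        for axis, (before, after) in enumerate(pad_width)
    )
    return out_grad[slices]

def expected_grad(mult, a_shape, pad_width=((0, 0), (0, 1), (0, 0))):
    # Forward pass: b = pad(a[:, 1:], pad_width) * mult; loss = b.sum().
    # The gradient arriving at the pad output is therefore `mult` itself.
    grad_sliced = pad_backward(mult, pad_width)  # shape (3, 2, 3) here
    grad_a = np.zeros(a_shape, dtype=mult.dtype)
    grad_a[:, 1:] = grad_sliced                  # a[:, 0] never reaches the loss
    return grad_a

rng = np.random.default_rng(0)
mult = rng.normal(0, 1, (3, 3, 3))
grad_a = expected_grad(mult, (3, 3, 3))
print(grad_a)
```

Under these assumptions the first row along axis 1 is zero and the remaining rows equal `mult[:, :2]`, which has the same structure as the JAX output and not the all-ones result from MXNet.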
Basically, the following lines are not correct:
https://github.com/apache/incubator-mxnet/blob/b0c39f7ea983639093c63d7d2486bbef083a55d6/src/operator/numpy/np_pad_op-inl.h#L544-L551
We should change them to:
```
KERNEL_ASSIGN(out[i], req, a[i]);
```
In addition, I do not know why we need the `using namespace mxnet_op;` statement there.
@CassiniXu