Xingjian Shi created MXNET-507:
----------------------------------
Summary: Two problems in our ordering operators (topk, sort,
argsort)
Key: MXNET-507
URL: https://issues.apache.org/jira/browse/MXNET-507
Project: Apache MXNet
Issue Type: Bug
Reporter: Xingjian Shi
There are two problems in the ordering operators, i.e., topk, sort, and argsort:
1) Only real_t is supported.
2) The indices are stored as real_t. Because real_t is float32 by default, it
cannot represent every integer above 2^24 exactly, so large indices get rounded.
This causes errors in the backward pass, where the gradients are passed to the
wrong locations.
For example, the following code fails in the previous version:
```python
import mxnet as mx
import numpy as np

ctx = mx.cpu()
# Large enough that the top-k indices exceed 2**24
a = mx.nd.arange(54686454, ctx=ctx, dtype=np.int32)
a.attach_grad()
k = 10
with mx.autograd.record():
    b = mx.nd.topk(a, k=k, ret_typ='value')
b.backward(mx.nd.ones((k,), ctx=ctx, dtype=np.int32))
a_grad = a.grad.asnumpy()
# The k largest values are the last k elements, so each should get gradient 1
for i in range(-1, -k - 1, -1):
    assert a_grad[i] == 1
```
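The index rounding behind problem 2 can be illustrated without MXNet. The sketch below is not the operator code: it assumes real_t is a 32-bit IEEE 754 float and uses the standard-library `struct` module as a stand-in for storing an index in that type. The helper name `through_float32` is hypothetical.

```python
import struct

def through_float32(index):
    """Round-trip an integer through a 32-bit float (assumed stand-in for real_t)."""
    return int(struct.unpack("f", struct.pack("f", float(index)))[0])

n = 54686454  # array length from the example above
k = 10

# Positions of the k largest elements of arange(n)
true_indices = list(range(n - k, n))

# The same positions after being stored as float32
float_indices = [through_float32(i) for i in true_indices]

# Count how many indices no longer point at the original location
mismatches = sum(t != f for t, f in zip(true_indices, float_indices))
print(mismatches, "of", k, "indices change when stored as float32")

# A 32-bit signed integer holds the same indices without loss
int32_ok = all(
    struct.unpack("i", struct.pack("i", i))[0] == i for i in true_indices
)
print("int32 round trip exact:", int32_ok)
```

Any index that changes in the round trip would send its gradient to a different element in the backward pass, which is consistent with the failing assertion in the example.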
I propose to fix this bug by changing the dtype of the indices to int32.
However, this change is backward incompatible.