[
https://issues.apache.org/jira/browse/MXNET-507?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
]
Xingjian Shi updated MXNET-507:
-------------------------------
Summary: Two problems in the ordering operators (topk, sort, argsort)
(was: Two problems in our ordering operators (topk, sort, argsort))
> Two problems in the ordering operators (topk, sort, argsort)
> ------------------------------------------------------------
>
> Key: MXNET-507
> URL: https://issues.apache.org/jira/browse/MXNET-507
> Project: Apache MXNet
> Issue Type: Bug
> Reporter: Xingjian Shi
> Priority: Major
>
> There are two problems in the ordering operators, i.e., topk, sort, and argsort:
> 1) Only real_t is supported.
> 2) The indices are stored as real_t, which causes errors in the backward
> pass: the gradients are passed to the wrong locations.
> For example, the following code could not be run with the previous implementation:
> ```python
> import mxnet as mx
> import numpy as np
> import mxnet.ndarray as nd
> ctx = mx.cpu()
> a = mx.nd.arange(54686454, ctx=ctx, dtype=np.int32)
> a.attach_grad()
> k = 10
> with mx.autograd.record():
>     b = mx.nd.topk(a, k=k, ret_typ='value')
> b.backward(mx.nd.ones((k,), ctx=ctx, dtype=np.int32))
> a_grad = a.grad.asnumpy()
> # The top-k values of arange are its last k elements, so each of
> # their gradients should be 1.
> for i in range(-1, -k - 1, -1):
>     assert a_grad[i] == 1
> ```
>
> I propose to fix this bug by changing the dtype of the indices to int32.
> However, this change will be backward incompatible.
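> As background for why real_t indices are unsafe, a minimal sketch in plain
> NumPy (not MXNet): float32 has a 24-bit significand, so not every integer
> above 2**24 is exactly representable, and an index stored as float32 can
> round to a neighboring position. The index below is chosen from the
> arange(54686454) array in the example above.
> ```python
> import numpy as np
>
> # float32 carries a 24-bit significand; above 2**24 the spacing between
> # representable values exceeds 1, so some integer indices get rounded.
> idx = 54686453  # a valid index into the arange(54686454) array above
> assert np.float32(idx) != idx  # distorted when stored as float32
> assert np.int32(idx) == idx    # preserved exactly as int32
> ```
> This is consistent with gradients landing at wrong locations once the
> array is large enough, and with int32 being a safe dtype for indices.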
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]