zehuichen123 opened a new issue #18519:
URL: https://github.com/apache/incubator-mxnet/issues/18519
Hi, I am trying to define a custom op to replace `SoftmaxOutput` because I
want to weight the loss value before `mx.make_loss`. However, I encountered the
exception shown below:
```
Traceback (most recent call last):
  File "detection_train.py", line 307, in <module>
    train_net(parse_args())
  File "detection_train.py", line 289, in train_net
    profile=profile
  File "/mnt/truenas/scratch/czh/kl_baseline/core/detection_module.py", line 1010, in fit
Error in CustomOp.backward: Traceback (most recent call last):
  File ".../mxnet_xyxy/python/mxnet/operator.py", line 1022, in backward_entry
    stype=stype))
  File ".../mxnet_xyxy/python/mxnet/ndarray/sparse.py", line 1187, in _ndarray_cls
    raise Exception("unknown storage type: %s"%stype)
Exception: unknown storage type: -1
```
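For reference when reading the last frame: if I read MXNet's Python frontend correctly, its storage-type codes are -1 (undefined), 0 (default/dense), 1 (row_sparse) and 2 (csr), and `_ndarray_cls` first tries to resolve an undefined code from the array handle itself. So the error suggests that one of the arrays handed to `backward_entry` had no storage type at all. The sketch below is a simplified, hypothetical re-creation of that dispatch (the `STYPE_*` constants and `ndarray_cls_sketch` are my own names, not MXNet's) showing how `-1` ends up in the final `raise`:

```python
# Storage-type codes, assumed to match MXNet's Python frontend.
STYPE_UNDEFINED = -1
STYPE_DEFAULT = 0
STYPE_ROW_SPARSE = 1
STYPE_CSR = 2


def ndarray_cls_sketch(stype):
    """Hypothetical, simplified stand-in for mxnet.ndarray.sparse._ndarray_cls.

    Maps an already-resolved storage-type code to the name of the class that
    would wrap the array; any code without a matching class (including -1)
    falls through to the same exception seen in the traceback.
    """
    if stype == STYPE_DEFAULT:
        return "NDArray"
    if stype == STYPE_CSR:
        return "CSRNDArray"
    if stype == STYPE_ROW_SPARSE:
        return "RowSparseNDArray"
    raise Exception("unknown storage type: %s" % stype)


try:
    ndarray_cls_sketch(STYPE_UNDEFINED)
except Exception as exc:
    print(exc)  # unknown storage type: -1
```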
The version of mxnet is 1.6.0.
Here is the custom op I defined:
```
import numpy as np
import mxnet as mx
import time


class WeightSoftmaxOperator(mx.operator.CustomOp):
    def __init__(self):
        super().__init__()

    def forward(self, is_train, req, in_data, out_data, aux):
        data = in_data[0]
        label = in_data[1]
        data = mx.nd.softmax(data)
        loss = - label * mx.nd.log(data + 1e-8)
        self.assign(out_data[0], req[0], loss)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        data = in_data[0]
        label = in_data[1]
        d_grad = mx.nd.softmax(data) - label
        self.assign(in_grad[0], req[0], d_grad)
        self.assign(in_grad[1], req[1], mx.nd.zeros_like(label))


@mx.operator.register('weight_softmax')
class WeightSoftmaxProp(mx.operator.CustomOpProp):
    def __init__(self):
        super().__init__(need_top_grad=True)

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        return [in_shape[0], in_shape[1]], [in_shape[0]]

    def create_operator(self, ctx, shapes, dtypes):
        return WeightSoftmaxOperator()

    def declare_backward_dependency(self, out_grad, in_data, out_data):
        deps = []
        if self.need_top_grad_:
            deps.extend(out_grad)
        deps.extend(in_data)
        deps.extend(out_data)
        return deps
```
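To sanity-check the math in `backward` independently of MXNet, here is a NumPy sketch (not MXNet code). It assumes each label row sums to 1 (e.g. one-hot) and that the upstream gradient is all ones, as when the element-wise loss is simply summed; under those assumptions `softmax(data) - label` is the gradient of the summed loss. The per-sample `weight` array is a hypothetical addition to illustrate the weighted case, where the gradient is just scaled by the weight. The helper names are illustrative only.

```python
import numpy as np


def softmax(x):
    # Subtract the row max for numerical stability before exponentiating.
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def weighted_ce_forward(data, label, weight):
    # Element-wise loss as in WeightSoftmaxOperator.forward, scaled per sample.
    # `weight` (shape (N,)) is a hypothetical per-sample weight.
    p = softmax(data)
    return -weight[:, None] * label * np.log(p + 1e-8)


def weighted_ce_backward(data, label, weight):
    # Gradient of the summed loss w.r.t. data, assuming label rows sum to 1:
    # the softmax(data) - label rule from backward(), times the weight.
    return weight[:, None] * (softmax(data) - label)


def numeric_grad(f, x, eps=1e-5):
    # Central finite differences of the scalar function f at x.
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp = x.copy()
        xp[idx] += eps
        xm = x.copy()
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g


rng = np.random.default_rng(0)
data = rng.normal(size=(4, 5))
label = np.eye(5)[rng.integers(0, 5, size=4)]  # one-hot labels, shape (4, 5)
weight = np.array([1.0, 0.5, 2.0, 1.5])

analytic = weighted_ce_backward(data, label, weight)
numeric = numeric_grad(lambda d: weighted_ce_forward(d, label, weight).sum(), data)
print(np.max(np.abs(analytic - numeric)))
```

The printed maximum difference should be tiny, limited by the `1e-8` offset inside the log and the finite-difference error. Note that the op's `backward` does not use `out_grad`, so it only agrees with this check when the upstream gradient really is all ones.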