wkcn commented on issue #18333: URL: https://github.com/apache/incubator-mxnet/issues/18333#issuecomment-631608741
Hi @John1231983, using `mx.nd.pick` simplifies the code. Both computations below give the same value:

```python
import mxnet as mx

B, C, H, W = 2, 2, 3, 4
x = mx.random.uniform(-1, 1, shape=(B, C, H, W))
target = mx.random.randint(0, C, shape=(B, 1, H, W))  # target of size Bx1xHxW
batch_size = B

# Approach 1: one-hot mask multiplied with the log-probabilities.
f = mx.nd.softmax(x, axis=1)  # softmax over the class axis C (default axis is -1)
target_squeeze = mx.nd.squeeze(target, axis=1)  # size BxHxW
target_squeeze = mx.nd.one_hot(target_squeeze, depth=C, on_value=-1.0, off_value=0.0)
# Transpose from BxHxWxC to BxCxHxW
target_squeeze = mx.nd.transpose(target_squeeze, axes=(0, 3, 1, 2))
# Log of the probabilities f
f_log = mx.nd.log(f)
f_sum = mx.nd.sum(target_squeeze * f_log) / batch_size
print(f_sum)

# Approach 2: pick the target-class score directly.
lscore = -mx.nd.log_softmax(x, axis=1)
target_squeeze = mx.nd.squeeze(target, axis=1)  # size BxHxW
t_sum = mx.nd.pick(lscore, target_squeeze, axis=1).sum() / batch_size
print(t_sum)  # equals f_sum up to floating-point rounding
```
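The two approaches agree because multiplying by a one-hot mask and summing over the class axis keeps exactly one entry per pixel, which is what `pick` gathers directly. For readers without MXNet installed, here is a minimal NumPy sketch of the same identity. It assumes `np.take_along_axis` as a stand-in for `mx.nd.pick` along axis 1, and the helpers `log_softmax`, `one_hot_ce` and `pick_ce` are hypothetical names introduced only for illustration.

```python
import numpy as np

def log_softmax(x, axis):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def one_hot_ce(x, target, num_classes):
    # Hypothetical helper: target has shape (B, H, W); build a (B, C, H, W) one-hot mask.
    mask = np.eye(num_classes)[target]   # (B, H, W, C)
    mask = mask.transpose(0, 3, 1, 2)    # (B, C, H, W)
    return -(mask * log_softmax(x, axis=1)).sum() / x.shape[0]

def pick_ce(x, target):
    # Hypothetical helper: gather the target-class log-probability at every pixel,
    # a NumPy stand-in for mx.nd.pick(..., axis=1).
    lp = log_softmax(x, axis=1)
    picked = np.take_along_axis(lp, target[:, None, :, :], axis=1)  # (B, 1, H, W)
    return -picked.sum() / x.shape[0]

rng = np.random.default_rng(0)
B, C, H, W = 2, 2, 3, 4
x = rng.uniform(-1, 1, size=(B, C, H, W))
target = rng.integers(0, C, size=(B, H, W))

a = one_hot_ce(x, target, C)
b = pick_ce(x, target)
print(a, b)
```

The `pick`-style version never materializes the BxCxHxW mask, which is the main saving when C is large.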
