sxjscience commented on issue #17823: [Operator] Add `index_add` or `index_update` to numpy extension
URL: https://github.com/apache/incubator-mxnet/issues/17823#issuecomment-613767769

@ZheyuYe @JiangZhaoh @yzhliu @haojin2

To understand the problem, let's consider two use cases. The first can be solved via `gather_nd`; the second cannot be solved with the existing MXNet operators.

## Take elements at specific locations from the input data

`out[i, j, ...] = data[i, positions[i, j], ...]`

In GluonNLP, the `positions` are the masked locations in the input at which we need to calculate the loss, and `data` holds the mapped hidden states of the sequences. With advanced indexing + the imperative API, we can do something like this:

```python
import mxnet as mx
mx.npx.set_np()

data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
out = data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions]
print(out.asnumpy().shape)
```

In order to make the network hybridizable, we can implement it via `gather_nd`:

```python
import numpy as np
from mxnet.util import use_np


@use_np
def select_vectors_by_position(F, data, positions):
    """Select each batch with the given positions.

    Once advanced indexing can be hybridized, we can revise the implementation.

    out[i, j, :] = data[i, positions[i, j], :]

    Parameters
    ----------
    F
    data
        Input tensor of contextualized token embeddings
        Shape (batch_size, seq_length, units)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_sel_positions).
        For each sample in the batch, the values in this tensor must not
        exceed the length of the sequence.

    Returns
    -------
    out
        The selection result.
        Shape (batch_size, num_sel_positions, units)
    """
    # Here, we use gather_nd to select the output from data.
    # We need to compute
    #   out[i, j, :] = data[i, positions[i, j], :]
    # Thus, construct an indices tensor with shape (2, batch_size, num_sel_positions), where
    #   indices[0, i, j] = i
    #   indices[1, i, j] = positions[i, j]
    # Then, out = gather_nd(data, indices)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = F.np.expand_dims(F.npx.arange_like(positions, axis=0),
                                 axis=1).astype(np.int32)
    batch_idx = batch_idx + F.np.zeros_like(positions)
    indices = F.np.stack([batch_idx, positions])
    out = F.npx.gather_nd(data, indices)
    return out
```

## Update elements at specific locations of the input data

For example, we may need to replace the elements at selected locations with our own generated values, i.e.,

`data[i, positions[i, j], ...] = update_val[i, j, ...]`

With advanced indexing + the imperative API, we can do something like this:

```python
import mxnet as mx
mx.npx.set_np()

data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
update_val = mx.np.random.normal(0, 1, (5, 4, 5, 5))
data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
print(data.asnumpy().shape)
# or, for index_add semantics:
data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] += update_val
print(data.asnumpy().shape)
```

However, we cannot wrap it in an autograd recording scope:

```python
import mxnet as mx
mx.npx.set_np()

data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
update_val = mx.np.random.normal(0, 1, (5, 4, 5, 5))
data.attach_grad()
with mx.autograd.record():
    data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
mx.npx.waitall()
```

Error message:

```
MXNetError: Traceback (most recent call last):
  File "src/imperative/imperative.cc", line 203
MXNetError: Check failed: AGInfo::IsNone(*output): Assigning to NDArrays that are already in a computational graph will cause undefined behavior when evaluating gradients. Please call backward first to clear the graph or do this outside of a record section. Also note that you cannot use inplace operations like +=, *=, relu(x, out=x), y[idx]=x, etc inside a record section.
```

We will need to find a workaround for this use case.
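As a stopgap, the update can be expressed functionally, via a one-hot mask and `where`, instead of an in-place assignment, so that no NDArray is mutated inside the recording scope. Below is a minimal NumPy sketch of the intended semantics; the function name `index_update_via_mask`, the restriction to 3-D inputs, and the assumption of unique positions per row are mine, not part of the proposal. The same steps map onto broadcasting, `expand_dims`, and `where` in hybridizable MXNet code:

```python
import numpy as np


def index_update_via_mask(data, positions, update_val):
    """Functional form of data[i, positions[i, j], :] = update_val[i, j, :].

    Returns a new array instead of mutating `data`.
    Assumes data: (batch, seq, units), positions: (batch, n) with unique
    entries per row, update_val: (batch, n, units).
    """
    batch, seq, units = data.shape
    # One-hot mask over the sequence axis:
    #   mask[i, j, k] = 1 iff positions[i, j] == k
    mask = (positions[:, :, None] == np.arange(seq)[None, None, :]).astype(data.dtype)
    # Scatter the updates into a (batch, seq, units) tensor:
    #   scattered[i, k, :] = sum_j mask[i, j, k] * update_val[i, j, :]
    scattered = np.einsum('ijk,iju->iku', mask, update_val)
    # selected[i, k, 0] > 0 exactly where sequence position k receives an update
    selected = mask.sum(axis=1)[:, :, None]
    return np.where(selected > 0, scattered, data)
```

For `index_add` semantics, `data + scattered` would replace the final `where` selection. Since every step here is an ordinary differentiable operator, gradients flow to both `data` and `update_val` without hitting the in-place assignment check.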
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]

With regards,
Apache Git Services
