hgt312 edited a comment on pull request #9375:
URL: https://github.com/apache/tvm/pull/9375#issuecomment-952479498
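@comaniac @masahi I found that the output is incorrect because of in-place
assignments such as `a[...] = b`, as in the previous issues.
In BART, this comes from the function below. The function as a whole is not
in-place (it returns a newly allocated tensor), but it writes into that tensor
through slices internally.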
```
import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
```
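To make the expected result concrete when checking a converted model, here is a minimal sketch of the same computation in plain Python on nested lists. It builds each row fresh instead of writing into a slice. The helper name `shift_tokens_right_ref` and the sample values are hypothetical and are not part of BART or TVM.

```python
from typing import List


def shift_tokens_right_ref(input_ids: List[List[int]], pad_token_id: int,
                           decoder_start_token_id: int) -> List[List[int]]:
    """Hypothetical pure-Python reference: no in-place slice writes."""
    shifted = []
    for row in input_ids:
        if not row:
            # Zero-width rows stay empty, matching the tensor version.
            shifted.append([])
            continue
        # Prepend the start token and drop the last token of the row.
        new_row = [decoder_start_token_id] + list(row[:-1])
        # Replace the -100 label sentinel with the pad token.
        shifted.append([pad_token_id if tok == -100 else tok for tok in new_row])
    return shifted


# Illustrative values only.
example = [[5, 6, -100, 7], [8, -100, 9, 10]]
result = shift_tokens_right_ref(example, pad_token_id=1, decoder_start_token_id=2)
print(result)  # [[2, 5, 6, 1], [2, 8, 1, 9]]
```

Comparing a converted model's output for this step against such a reference is one way to confirm whether the slice assignments were handled correctly.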
Also, I found that after https://github.com/pytorch/pytorch/pull/52063 (torch
version >= 1.9), we can use
`torch._C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, None)` to remove all
the `aten::copy_` ops. The corresponding part of the graph then looks like:
```
%69 : Tensor = onnx::Placeholder[name="index_put_"](%62) # <ipython-input-1-662caefe3c7e>:8:0
  block0():
    %70 : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu) = aten::slice(%shifted_input_ids, %42, %43, %44, %45) # <ipython-input-1-662caefe3c7e>:8:0
    %71 : Float(3, strides=[3], requires_grad=0, device=cpu) = aten::select(%70, %47, %48) # <ipython-input-1-662caefe3c7e>:8:0
    %72 : Float(3, strides=[3], requires_grad=0, device=cpu) = aten::index_put_(%71, %66, %67, %57) # <ipython-input-1-662caefe3c7e>:8:0
    -> (%72)
```
and the subgraph can be converted to ONNX's `index_put`.
Maybe the torch->onnx path will work for these models?