jhlee525 opened a new pull request, #15502:
URL: https://github.com/apache/tvm/pull/15502
Although #9375 was rejected, I tried a different way to support the
`aten::copy_` op.
`aten::copy_` behaves differently from other in-place ops: it is "purely
in-place". For other in-place nodes, the output graph (`torch.Graph`) still
relays the node's output to its users, so a DAG can be constructed. With
`aten::copy_`, however, the graph keeps referring to the mutated tensor
itself, which leaves all of the mutations dangling.
For example, a torch module like
```
import torch


class Test(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        # In-place slice assignment; this is traced as aten::copy_
        x[:5, :5] = x[:5, :5] + 1
        return x
```
generates the graph
```
graph(%self : __torch__.Test,
      %x : Float(10, 10, strides=[10, 1], requires_grad=0, device=cpu)):
  %4 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %5 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %6 : int = prim::Constant[value=5]() # /home/jhlee/tvm/test.py:10:0
  %7 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %8 : Float(5, 10, strides=[10, 1], requires_grad=0, device=cpu) = aten::slice(%x, %4, %5, %6, %7)
  %9 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %10 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %11 : int = prim::Constant[value=5]() # /home/jhlee/tvm/test.py:10:0
  %12 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %13 : Float(5, 5, strides=[10, 1], requires_grad=0, device=cpu) = aten::slice(%8, %9, %10, %11, %12)
  %14 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
  %15 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %16 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cpu) = aten::add(%13, %14, %15)
  %17 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %18 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %19 : int = prim::Constant[value=5]() # /home/jhlee/tvm/test.py:10:0
  %20 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %21 : Float(5, 10, strides=[10, 1], requires_grad=0, device=cpu) = aten::slice(%x, %17, %18, %19, %20)
  %22 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %23 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %24 : int = prim::Constant[value=5]() # /home/jhlee/tvm/test.py:10:0
  %25 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %26 : Float(5, 5, strides=[10, 1], requires_grad=0, device=cpu) = aten::slice(%21, %22, %23, %24, %25)
  %27 : bool = prim::Constant[value=0]()
  %28 : Float(5, 5, strides=[10, 1], requires_grad=0, device=cpu) = aten::copy_(%26, %16, %27)
  return (%x)
```
which returns `%x` itself.
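To make the dangling mutation concrete, here is a minimal sketch that models the graph above as a plain dictionary and walks backward from the returned value. The dictionary layout and the `reachable_from` helper are hypothetical illustrations, not the `torch.Graph` API; constants are left out for brevity.

```python
# Toy model of the graph above: value id -> (op name, input value ids).
# This is an illustration only, not the torch.Graph API.
graph = {
    "x": ("input", []),
    "8": ("aten::slice", ["x"]),
    "13": ("aten::slice", ["8"]),
    "16": ("aten::add", ["13"]),
    "21": ("aten::slice", ["x"]),
    "26": ("aten::slice", ["21"]),
    "28": ("aten::copy_", ["26", "16"]),
}
outputs = ["x"]  # the graph ends with `return (%x)`


def reachable_from(outputs, graph):
    """Collect every value reachable by following inputs backward from the outputs."""
    seen = set()
    stack = list(outputs)
    while stack:
        value = stack.pop()
        if value in seen:
            continue
        seen.add(value)
        stack.extend(graph[value][1])
    return seen


live = reachable_from(outputs, graph)
dangling = sorted(v for v in graph if v not in live)
print("live:", live)
print("dangling:", dangling)
```

In this toy model only `x` is reachable from the return value, so `%28` (the `aten::copy_` node) and everything feeding it are disconnected from the output unless the output is redirected.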
My approach to this problem is:
1. In `from_pytorch`, insert a pass (`_redirect_inplace_output`) that redirects
the output of `aten::copy_`. It runs after `_run_jit_passes` is called and
operates *at the torch level (`torch.Graph`)*.
2. When handling an `aten::copy_` node, walk up through its parents to collect
the `aten::select` and `aten::slice` nodes and generate the indices of the
source from them. For this I referred to the behavior of the [torch -> ONNX
conversion](https://github.com/pytorch/pytorch/blob/v2.0.1/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp#L175)
in the PyTorch repository.
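The idea behind step 2 can be sketched in plain Python. This is a simplified toy model, not the code in this PR and not the PyTorch or TVM API: the `Node` class, `collect_indices`, and `functional_copy` are hypothetical names, and the sketch assumes step-1 slices with each base dimension sliced at most once. It walks from the view passed to `aten::copy_` up to the base tensor, turns the slice/select chain into one index per base dimension, and then builds a new value instead of mutating the base.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union


@dataclass
class Node:
    """Hypothetical stand-in for a graph node; only the fields this sketch needs."""
    kind: str                      # "input", "aten::slice", or "aten::select"
    parent: Optional["Node"] = None
    dim: int = 0
    start: int = 0                 # used by aten::slice
    end: int = 0                   # used by aten::slice
    index: int = 0                 # used by aten::select


Index = Union[slice, int]


def collect_indices(view: Node, ndim: int) -> Tuple[Node, List[Index]]:
    """Walk up from a view to its base and return (base, one index per base dim)."""
    chain = []
    node = view
    while node.kind in ("aten::slice", "aten::select"):
        chain.append(node)
        node = node.parent
    chain.reverse()  # process from the base toward the view

    indices: List[Index] = [slice(None)] * ndim
    live_dims = list(range(ndim))  # view dim i -> base dim live_dims[i]
    for n in chain:
        base_dim = live_dims[n.dim]
        if n.kind == "aten::slice":
            indices[base_dim] = slice(n.start, n.end)
        else:  # aten::select removes the dimension from the view
            indices[base_dim] = n.index
            del live_dims[n.dim]
    return node, indices


def functional_copy(base, indices, src):
    """Return a new 2-D nested list with base[rows, cols] replaced by src."""
    rows, cols = indices  # both must be slices in this sketch
    out = [row[:] for row in base]
    for i, r in enumerate(range(*rows.indices(len(base)))):
        out[r][cols] = src[i]
    return out


# Mirror x[:5, :5] = x[:5, :5] + 1 from the example module.
x = Node("input")
s0 = Node("aten::slice", parent=x, dim=0, start=0, end=5)
s1 = Node("aten::slice", parent=s0, dim=1, start=0, end=5)
base, idx = collect_indices(s1, ndim=2)

data = [[0.0] * 10 for _ in range(10)]
src = [[v + 1 for v in row[idx[1]]] for row in data[idx[0]]]
result = functional_copy(data, idx, src)
```

Because `functional_copy` returns a new value, later users (including the graph's return) can be pointed at `result` rather than at the original tensor, which is the kind of redirection step 1 is for.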
I'm not familiar with making PRs to this repository, so please let me know
if you have any feedback or questions.