masahi edited a comment on pull request #7231:
URL: https://github.com/apache/tvm/pull/7231#issuecomment-756716845
For the first case,
```
def foo(x):
    x[:5] = 0
    return 2 * x
```
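For illustration (plain Python lists standing in for tensors, not the TorchScript semantics themselves), the in-place slice assignment mutates the caller's buffer, which is what makes this hard to express in a functional IR:

```python
def foo(x):
    # in-place slice assignment: mutates the caller's list, just like
    # the tensor version mutates the input storage
    x[:5] = [0.0] * 5
    return [2.0 * v for v in x]

x = [float(i) for i in range(10)]
y = foo(x)
# x itself is modified as a side effect; y is computed from the mutated x
```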
Using the two passes I mentioned, I get this graph:
```
graph(%x : Float(10:1, requires_grad=0, device=cpu)):
%1 : int = prim::Constant[value=0]() # test.py:8:0
%2 : int = prim::Constant[value=0]() # test.py:8:0
%3 : int = prim::Constant[value=5]() # test.py:8:0
%4 : int = prim::Constant[value=1]() # test.py:8:0
%5 : Float(5:1, requires_grad=0, device=cpu) = aten::slice(%x, %1, %2, %3, %4) # test.py:8:0
%6 : Float(requires_grad=0, device=cpu) = prim::Constant[value={0}]() # test.py:8:0
%7 : int[] = prim::ListConstruct()
%8 : Float(requires_grad=0, device=cpu) = aten::view(%6, %7) # test.py:8:0
%9 : int = prim::Constant[value=5]() # test.py:8:0
%10 : int[] = prim::ListConstruct(%9)
%11 : bool = prim::Constant[value=1]() # test.py:8:0
%12 : Float(5:0, requires_grad=0, device=cpu) = aten::expand(%8, %10, %11) # test.py:8:0
%13 : bool = prim::Constant[value=0]()
%18 : Float(5:0, requires_grad=0, device=cpu) = aten::expand_as(%12, %5) # test.py:8:0
%20 : int = prim::Constant[value=0]()
%21 : int = aten::size(%x, %20)
%22 : int = prim::Constant[value=4]()
%23 : None = prim::Constant()
%24 : None = prim::Constant()
%25 : None = prim::Constant()
%26 : Tensor = aten::arange(%21, %22, %23, %24, %25)
%27 : int = prim::Constant[value=0]()
%28 : Tensor = aten::slice(%26, %27, %2, %3, %4)
%30 : int[] = prim::Constant[value=[-1]]()
%31 : Tensor = aten::view(%28, %30)
%32 : Tensor?[] = prim::ListConstruct(%31)
%33 : Float(5:1, requires_grad=0, device=cpu) = aten::index_put(%x, %32, %18, %13)
%15 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # test.py:9:0
%16 : Float(10:1, requires_grad=0, device=cpu) = aten::mul(%33, %15) # test.py:9:0
return (%16)
```
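Reading off that graph, the pass has rewritten the slice assignment into a functional chain: `aten::arange` builds the row indices, a slice keeps the first five, and `aten::index_put` writes the expanded zeros at those positions. A quick NumPy emulation of that chain (my reading of the aten semantics, not the converter's actual code):

```python
import numpy as np

x = np.arange(10, dtype=np.float32)       # the graph input %x
indices = np.arange(x.shape[0])[0:5]      # %26/%28: aten::arange + aten::slice
updates = np.zeros(5, dtype=np.float32)   # %18: the 0 constant expanded to (5,)
put = x.copy()                            # emulate index_put out of place
put[indices] = updates                    # %33: aten::index_put, accumulate=False
result = 2.0 * put                        # %15/%16: aten::mul by the constant 2
```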
i.e. it seems `torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx` already satisfies some of the requirements you mentioned (the first and third; I'm not sure about the second)?
If we convert this graph to Relay using `scatter_nd`, I think the output tensor would be correct, but we cannot modify the input in place the way torch does.
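To make that concrete, here is a minimal sketch of a functional `scatter_nd`-style update (plain Python, not the actual Relay lowering): the returned tensor is correct, but the caller's input buffer is left untouched, unlike torch's in-place update:

```python
def scatter_1d(data, indices, updates):
    # functional scatter: copy the input, write the updates into the copy,
    # and return it -- the caller's buffer is never aliased or modified
    out = list(data)
    for i, u in zip(indices, updates):
        out[i] = u
    return out

x = [float(i) for i in range(10)]
y = scatter_1d(x, range(5), [0.0] * 5)
# y has zeros in the first five slots; x is unchanged
```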
But you're right: your second example results in
```
graph(%x : Float(10:1, requires_grad=0, device=cpu)):
%20 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # test.py:15:0
%21 : Float(10:1, requires_grad=0, device=cpu) = aten::mul(%x, %20) # test.py:15:0
return (%21)
```
because `torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects` treats the side-effect-only operation as dead code (`y` is not used to compute the output).
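The elimination itself is essentially reachability from the graph output. A toy sketch of that DCE logic (not PyTorch's implementation) on a graph shaped like the second example:

```python
def dce(nodes, output):
    # nodes maps a value name to the names of its inputs; keep only nodes
    # transitively reachable from the graph output
    live, stack = set(), [output]
    while stack:
        n = stack.pop()
        if n in live or n not in nodes:
            continue  # graph inputs like "x" have no producing node
        live.add(n)
        stack.extend(nodes[n])
    return {n: ins for n, ins in nodes.items() if n in live}

# shape of the second example (names are hypothetical): the y chain feeds
# nothing, while the output is aten::mul(%x, 2), so the y chain is deleted
nodes = {
    "y_slice": ["x"],
    "y_put": ["y_slice"],  # side-effect-only chain on y
    "out": ["x"],          # aten::mul(%x, %const_2)
}
```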
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]