apivovarov commented on pull request #7513:
URL: https://github.com/apache/tvm/pull/7513#issuecomment-786330018
I prepared a test model that uses both the destination tensor and the `copy_`
output tensor after the `copy_` operator.
The PyTorch and TVM outputs are the same.
What other test models can we try?
```
import torch
import tvm
from tvm import relay
import numpy as np
class Net(torch.nn.Module):
    """Toy model that reads a tensor both before and after an in-place ``copy_``.

    It exercises the aliasing of ``Tensor.copy_``:

    * ``V1`` is computed from ``B`` before the copy.
    * ``V2`` is computed from ``B`` after the copy (``B`` was modified in place).
    * ``V3`` is computed from ``C``, the tensor returned by ``copy_``, which is
      ``B`` itself.

    A faithful conversion must therefore treat ``B`` as zeros for ``V1``, but
    as the copied ``values`` for ``V2`` and ``V3``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, values):
        """Return ``V1 + V2 + V3``.

        For a 1-D ``values`` of length n the result has shape (n, n), and every
        row equals ``2 * values + 6`` (that is, ``1 + (values + 2) + (values + 3)``).

        NOTE(review): written for a 1-D ``values``; other ranks were not checked.
        """
        scale = values.shape[0]
        A = torch.zeros(values.shape)  # zeros in the default float dtype
        B = torch.stack([A] * scale)  # shape (scale, *values.shape), all zeros
        V1 = B + 1  # reads B BEFORE the in-place copy -> all ones
        C = B.copy_(values)  # overwrites B in place (values broadcast); C is B
        V2 = B + 2  # reads B AFTER the copy -> values + 2
        V3 = C + 3  # reads copy_'s return value -> values + 3
        D = V1 + V2 + V3
        return D
# Run the model eagerly, trace it, convert it to Relay, and evaluate it with TVM.
net = Net()
a = torch.tensor([0, 1, 2, 6])
net(a)  # eager PyTorch reference result
# example_inputs is meant as a tuple of inputs; `(a)` would just be `a`.
traced_net = torch.jit.trace(net, (a,))
ctx = tvm.cpu(0)
target = 'llvm'
shape_list = [("input0", [4])]
mod, params = relay.frontend.from_pytorch(traced_net, shape_list)
# Compile at opt_level=3 to check that the converted module builds.
# NOTE(review): `lib` is not used below; evaluation goes through the executor instead.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
func = mod['main']
# NOTE(review): `ctx=` matches the TVM API used here; later TVM releases
# renamed this argument (to `device`) -- confirm against the installed version.
intrp = relay.create_executor("graph", ctx=ctx, target=target)
ff = intrp.evaluate(func)
ff([0, 1, 2, 6])
<tvm.nd.NDArray shape=(4, 4), cpu(0)>
array([[ 6., 8., 10., 18.],
[ 6., 8., 10., 18.],
[ 6., 8., 10., 18.],
[ 6., 8., 10., 18.]], dtype=float32)
```
Relay graph:
```
print(func)
fn (%input0: Tensor[(4), int64]) {
%0 = full(0, shape=[4], dtype="float32");
%1 = (%0, %0, %0, %0);
%2 = stack(%1);
%3 = add(%2, 1f);
%4 = cast(%input0, dtype="float32");
%5 = (%4, %4, %4, %4);
%6 = stack(%5);
%7 = add(%6, 2f);
%8 = add(%3, %7);
%9 = add(%6, 3f);
add(%8, %9)
}
```
Torch graph:
```
print(traced_net.graph)
graph(%self : __torch__.Net,
%values : Long(4, strides=[1], requires_grad=0, device=cpu)):
%7 : int = prim::Constant[value=0]() # <stdin>:6:0
%8 : int = aten::size(%values, %7) # <stdin>:6:0
%9 : Long(device=cpu) = prim::NumToTensor(%8)
%10 : int = aten::Int(%9)
%11 : int[] = prim::ListConstruct(%10)
%12 : int = prim::Constant[value=6]() # <stdin>:6:0
%13 : None = prim::Constant()
%14 : Device = prim::Constant[value="cpu"]() # <stdin>:6:0
%15 : bool = prim::Constant[value=0]() # <stdin>:6:0
%A : Float(4, strides=[1], requires_grad=0, device=cpu) = aten::zeros(%11,
%12, %13, %14, %15) # <stdin>:6:0
%17 : Tensor[] = prim::ListConstruct(%A, %A, %A, %A)
%18 : int = prim::Constant[value=0]() # <stdin>:7:0
%B.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) =
aten::stack(%17, %18) # <stdin>:7:0
%20 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() #
<stdin>:8:0
%21 : int = prim::Constant[value=1]() # <stdin>:8:0
%V1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) =
aten::add(%B.1, %20, %21) # <stdin>:8:0
%23 : bool = prim::Constant[value=0]()
%B : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) =
aten::copy_(%B.1, %values, %23) # <stdin>:9:0
%25 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() #
<stdin>:10:0
%26 : int = prim::Constant[value=1]() # <stdin>:10:0
%V2 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) =
aten::add(%B, %25, %26) # <stdin>:10:0
%28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={3}]() #
<stdin>:11:0
%29 : int = prim::Constant[value=1]() # <stdin>:11:0
%V3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) =
aten::add(%B, %28, %29) # <stdin>:11:0
%31 : int = prim::Constant[value=1]() # <stdin>:12:0
%32 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) =
aten::add(%V1, %V2, %31) # <stdin>:12:0
%33 : int = prim::Constant[value=1]() # <stdin>:12:0
%34 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) =
aten::add(%32, %V3, %33) # <stdin>:12:0
return (%34)
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]