JCBrouwer edited a comment on issue #10223: URL: https://github.com/apache/tvm/issues/10223#issuecomment-1051938367
Hello @masahi, sorry for the slow response; I somehow missed the notification on this one. Thanks for enabling the op! I took a look at running on your branch and was also getting BAD_PARAM errors on both the test case above and my full model. After a bit of digging I noticed that [this change](https://github.com/apache/tvm/compare/main...masahi:conv2d-transpose-group-cudnn?expand=1#diff-ea7aa778342026f1671b0ec3af0e24e01aa4375532c7744e2274c20f885b89d6R303) is incorrect: that argument is the conv_mode, which should be left as 1 (as on the main branch). Changing that back, I'm able to run both the test case and my larger model with grouped conv2d_transpose ops on the cuDNN backend :tada: Sadly I'm still a few FPS shy of my performance target, so I'll have to keep digging for speedups.

RE: support for groups in the regular CUDA backend — do you have a general idea of what kind of changes would be necessary? I'm no expert, but I might be able to figure it out if it's just a matter of adapting the existing grouped conv2d code to grouped conv2d_transpose.
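For anyone following along, here's a minimal sketch (not from the issue, stride/padding simplified to the defaults) of why the batch-folding trick in the test code works at all: a grouped `conv_transpose2d` with `groups=B` over a `(1, B*C, H, W)` input should match applying a separate per-sample transposed convolution in a Python loop.

```python
import torch
from torch.nn.functional import conv_transpose2d

# Hypothetical toy sizes, just for the equivalence check.
B, C, O, H, W, K = 2, 4, 3, 5, 5, 3
x = torch.randn(B, C, H, W)
w = torch.randn(B, C, O, K, K)  # a different weight per batch element

# Grouped form: fold the batch into the channel dim, one group per sample.
grouped = conv_transpose2d(
    x.reshape(1, B * C, H, W),
    w.reshape(B * C, O, K, K),
    groups=B,
).reshape(B, O, H + K - 1, W + K - 1)

# Reference: per-sample transposed convolutions in a loop.
reference = torch.cat([conv_transpose2d(x[i : i + 1], w[i]) for i in range(B)])

assert torch.allclose(grouped, reference, atol=1e-4)
```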
For good measure:

<details>
<summary>The updated test code which now works with the one line change to your PR</summary>

```py
import torch
import tvm.relay
from torch.nn.functional import conv_transpose2d
from tvm import relay
from tvm.contrib import graph_executor


class ModulatedConvTranspose2D(torch.nn.Module):
    def forward(self, x, w, s):
        B, C, H, W = x.shape
        I, O, KH, KW = w.shape

        # weight is different for each input in batch (this is why we want grouped conv transpose)
        w = w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1)
        w = w.reshape(B * I, O, KH, KW)

        x = x.reshape(1, B * C, H, W)
        x = conv_transpose2d(x, w, stride=(2, 2), padding=(1, 1), output_padding=(1, 1), groups=B)
        x = x.reshape(B, O, H * 2, W * 2)

        return x


with torch.inference_mode():
    device = "cuda"
    target = "cuda -libs=cudnn"
    dtype = torch.float16
    tvm_dtype = dtype.__repr__().split(".")[-1]

    b, c, h, w, k = 4, 512, 8, 16, 3

    inputs = torch.rand((b, c, h, w), dtype=dtype, device=device)
    weights = torch.rand((c, c // 2, k, k), dtype=dtype, device=device)
    styles = torch.rand((b), dtype=dtype, device=device)

    torch_mod = torch.jit.trace(ModulatedConvTranspose2D().eval().to(device), (inputs, weights, styles))
    outputs_torch = torch_mod(inputs, weights, styles)
    print("Torch output shape", tuple(outputs_torch.shape))  # (4, 256, 16, 32)

    tvm_mod, tvm_params = relay.frontend.pytorch.from_pytorch(
        torch_mod,
        [
            ("inputs", (tuple(inputs.shape), tvm_dtype)),
            ("weights", (tuple(weights.shape), tvm_dtype)),
            ("styles", (tuple(styles.shape), tvm_dtype)),
        ],
    )

    with tvm.transform.PassContext(opt_level=10):
        lib = relay.build(tvm_mod, target=target, params=tvm_params)

    m = graph_executor.GraphModule(lib["default"](tvm.cuda()))
    m.run(
        inputs=tvm.nd.array(inputs.cpu(), device=tvm.cuda()),
        weights=tvm.nd.array(weights.cpu(), device=tvm.cuda()),
        styles=tvm.nd.array(styles.cpu(), device=tvm.cuda()),
    )
    print("TVM output shape ", m.get_output(0).numpy().shape)  # (4, 256, 16, 32)
```

</details>
