junrushao1994 commented on a change in pull request #25: URL: https://github.com/apache/tvm-rfcs/pull/25#discussion_r698094475
File path: rfcs/0025-add-pytorch-tvm.md (diff hunk `@@ -0,0 +1,265 @@`)

- Feature Name: PyTorchTVM
- Start Date: 2021-08-24
- RFC PR: [apache/tvm-rfcs#0025](https://github.com/apache/tvm-rfcs/pull/25)
- GitHub Issue: TODO

# Summary
[summary]: #summary

This RFC adds a `PyTorchTVM` module that compiles TorchScript to TVM and makes the accelerated module usable from within PyTorch.

To make TVM more accessible to PyTorch users, we propose a `PyTorchTVM` module that supports the following workflow:
1. Convert a TorchScript module to a TVM graph.
2. Build and tune the TVM graph.
3. Export the tuned TVM graph as a PyTorch op.
4. Trace the TVM PyTorch op together with the other PyTorch modules via `torch.jit.trace`, then save/load/serve it as a normal PyTorch model.

# Motivation
[motivation]: #motivation

The PyTorch framework is increasingly being adopted for both research and production. At the same time, PyTorch lacks an effective inference-acceleration toolchain, which is a major concern in industry. Existing acceleration paths include:

* PyTorch → ONNX → TensorRT/TVM
* PyTorch → TorchScript → TensorRT/TVM

From our perspective, both ONNX and TensorRT have limitations:

* ONNX cannot cover all models with dynamic control flow (e.g. `for` loops).
* TensorRT can only accelerate certain standard networks.

We therefore propose using TVM to accelerate PyTorch model inference.

# Guide-level explanation
[guide-level-explanation]: #guide-level-explanation

As an example, consider an end-to-end ResNet classification model consisting of three parts:

1. Image reading
2. Image transforms
3. ResNet model inference

```
class Predictor(nn.Module):

    def __init__(self, tvm_module=None):
        super().__init__()
        self.resnet18 = resnet18(pretrained=True, progress=False).eval()
        self.transforms = nn.Sequential(
            T.Resize([256, ]),  # A single int inside a list due to TorchScript type restrictions
            T.CenterCrop(224),
            T.ConvertImageDtype(torch.half),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        )

    def forward(self, image_path: List[str]) -> torch.Tensor:
        with torch.no_grad():
            images: List[torch.Tensor] = []
            for path in image_path:
                img = read_image(path)
                images.append(img)
            x = torch.stack(images).cuda().half()
            x = self.transforms(x)
            print(x.shape)
            y_pred = self.resnet18(x)
            return y_pred.argmax(dim=1)
```

We choose to accelerate the ResNet model with PyTorchTVM:

```
from tvm.contrib.pt_op import PyTorchTVMModule, compile

print("compile...")
option = {
    "input_infos": [
        ("x", (1, 3, 224, 224)),
    ],
    "default_dtype": "float16",
    "export_dir": "pytorch_compiled",
    "num_outputs": 1,
    "tuning_n_trials": 0,  # set to zero to skip tuning
    "tuning_log_file": "tuning.log",
}
x = torch.randn(1, 3, 224, 224).cuda().half()
resnet_jit = torch.jit.trace(model.resnet18, x)
resnet_tvm = compile(resnet_jit, option)
```

Then we can use the accelerated TVM module directly in PyTorch, and also use `torch.jit.script` together with the other two parts.

Review comment:
```suggestion
The TVM-accelerated `resnet_tvm` module can be used directly in PyTorch, or integrated into TorchScript with `torch.jit.script` along with all other PyTorch-native operations.
```
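The `option` dict passed to `compile` is the main configuration surface in the snippet above. As a minimal sketch (this helper is hypothetical, not part of the proposed `tvm.contrib.pt_op` API), one might wrap the defaults shown in the RFC example so that callers only override what they need:

```python
# Hypothetical convenience helper, NOT part of the proposed PyTorchTVM API:
# builds a compile-option dict with the defaults used in the RFC example,
# letting callers override individual fields.
DEFAULT_OPTIONS = {
    "default_dtype": "float16",
    "num_outputs": 1,
    "tuning_n_trials": 0,       # 0 skips tuning, matching the example above
    "tuning_log_file": "tuning.log",
}


def make_compile_options(input_infos, export_dir, **overrides):
    """Merge user overrides onto the defaults shown above."""
    opts = dict(DEFAULT_OPTIONS)
    opts["input_infos"] = input_infos
    opts["export_dir"] = export_dir
    opts.update(overrides)
    return opts


# Example: same input shape as the RFC snippet, but with tuning enabled.
option = make_compile_options(
    [("x", (1, 3, 224, 224))],
    "pytorch_compiled",
    tuning_n_trials=2000,
)
print(option["tuning_n_trials"])  # 2000
```

The resulting dict has the same shape as the literal `option` dict in the example, so it could be passed to `compile` unchanged.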
