masahi commented on a change in pull request #4944: [Relay, Torch] Clean up and
refactor PyTorch frontend
URL: https://github.com/apache/incubator-tvm/pull/4944#discussion_r384779098
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1016,17 +978,58 @@ def from_pytorch(script_module, input_shapes):
TorchScripted PyTorch graph
Note: We currently only support traces (ie: torch.jit.trace(model,
input))
- shape : Dictionary of input dimensions
+ input_shape : Dictionary of input dimensions
Graph level input shape dictionary
+ The keys should be the same one returned by get_graph_input_names(...)
above
Returns
-------
mod : tvm.relay.Module
The module that optimizations will be performed on.
- params : dict of str to tvm.runtime
- Dict of converted parameters stored in tvm.runtime format
+ params : dict of str to tvm.runtime.NDArray
+ Dict of converted parameters stored in tvm.runtime.ndarray format
"""
- g = Graph(script_module, input_shapes)
- mod, params = g.from_pytorch()
- return mod, params
+ graph = script_module.graph.copy()
Review comment:
Because we modify the input graph in place; without the copy, the caller's `script_module.graph` would be changed. This is not new in this PR: we currently apply the Torch inlining pass in `run_jit_passes(...)` below, and I also do some surgery on the input graph in my QNN implementation.
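To illustrate why the copy matters, here is a minimal pure-Python sketch of the copy-then-mutate pattern. `ToyGraph`, `inline_pass`, and `convert` are hypothetical stand-ins, not TorchScript or TVM APIs. In the real frontend, the diff calls `script_module.graph.copy()`, and the in-place work (e.g. the inlining pass in `run_jit_passes(...)`) is assumed to operate on that copy.

```python
import copy


class ToyGraph:
    """Hypothetical stand-in for a TorchScript graph: just a list of node names."""

    def __init__(self, nodes):
        self.nodes = list(nodes)

    def copy(self):
        # Deep copy so later in-place passes cannot touch the original object.
        return copy.deepcopy(self)


def inline_pass(graph):
    """Hypothetical in-place pass standing in for inlining: replaces each
    'call' node with the nodes of a fake callee body."""
    inlined = []
    for node in graph.nodes:
        if node == "call":
            inlined.extend(["mul", "add"])
        else:
            inlined.append(node)
    graph.nodes[:] = inlined  # mutates the object passed in


def convert(graph, copy_first=True):
    """Hypothetical converter: copy first (as in the diff), then mutate the copy."""
    work = graph.copy() if copy_first else graph
    inline_pass(work)
    return work.nodes


original = ToyGraph(["input", "call", "output"])
convert(original)
print(original.nodes)  # ['input', 'call', 'output'], left untouched

convert(original, copy_first=False)
print(original.nodes)  # ['input', 'mul', 'add', 'output'], caller's graph changed
```

The second call shows what happens without the copy: the pass rewrites the caller's graph, which is the side effect the `graph.copy()` line avoids.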
----------------------------------------------------------------
This is an automated message from the Apache Git Service.