leandron commented on a change in pull request #7366:
URL: https://github.com/apache/tvm/pull/7366#discussion_r568443310
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -285,17 +291,18 @@ def suffixes():
         # Torch Script is a zip file, but can be named pth
         return ["pth", "zip"]
-    def load(self, path):
+    def load(self, path, shape_dict=None):
         # pylint: disable=C0415
         import torch
         traced_model = torch.jit.load(path)
+        traced_model.eval() # Switch to inference mode
-        inputs = list(traced_model.graph.inputs())[1:]
-        input_shapes = [inp.type().sizes() for inp in inputs]
+        if shape_dict is None:
+            raise TVMCException("--input-shapes must be specified for %s" % self.name())
Review comment:
A suggestion here would be to move this validation/error message so that it
is checked before the model is loaded (line 297 maybe?). That way the user
doesn't have to wait for a (sometimes big) model to load, only to be
presented with an error message that we could have checked already.