leandron commented on a change in pull request #7366:
URL: https://github.com/apache/tvm/pull/7366#discussion_r568443310



##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -285,17 +291,18 @@ def suffixes():
         # Torch Script is a zip file, but can be named pth
         return ["pth", "zip"]
 
-    def load(self, path):
+    def load(self, path, shape_dict=None):
         # pylint: disable=C0415
         import torch
 
         traced_model = torch.jit.load(path)
+        traced_model.eval()  # Switch to inference mode
 
-        inputs = list(traced_model.graph.inputs())[1:]
-        input_shapes = [inp.type().sizes() for inp in inputs]
+        if shape_dict is None:
+            raise TVMCException("--input-shapes must be specified for %s" % self.name())

Review comment:
       A suggestion here would be to move this validation/error message so that
it is checked before the model is loaded (line 297, maybe?). That way the user
doesn't have to wait for a (sometimes big) model to load, only to be presented
with an error message that we could have reported up front.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]