damccorm commented on code in PR #25321:
URL: https://github.com/apache/beam/pull/25321#discussion_r1098766011
##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -174,6 +186,9 @@ def __init__(
Otherwise, it will be CPU.
inference_fn: the inference function to use during RunInference.
default=_default_tensor_inference_fn
+ use_torch_script_format: When `use_torch_script_format` is set to `True`,
+ the model will be loaded using `torch.jit.load()`, and the
+ `model_class` and `model_params` arguments will be disregarded.
Review Comment:
> I think `use_torch_script_format` is the wrong name. I can rename it to
something like `load_as_torchscript_model`; any suggestions?
Maybe just `is_torchscript`? An alternate approach would be to take a
`torchscript_model_path` parameter: instead of a bool, it would take a
string, and we would require `state_dict_path` and `model_class` to be
`None`. I might lean towards that approach; a rough sketch is below.
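
Something like this (hypothetical sketch only; `torchscript_model_path` is the
proposed parameter, not existing Beam API, and the real handler's other
arguments are elided):

```python
from typing import Callable, Optional


class PytorchModelHandlerTensor:  # sketch, not the real Beam class body
  def __init__(
      self,
      state_dict_path: Optional[str] = None,
      model_class: Optional[Callable] = None,
      torchscript_model_path: Optional[str] = None,  # proposed parameter
  ):
    if torchscript_model_path is not None:
      # A TorchScript archive is self-contained, so it must not be
      # combined with a state_dict path or a model class.
      if state_dict_path is not None or model_class is not None:
        raise ValueError(
            'torchscript_model_path is mutually exclusive with '
            'state_dict_path and model_class.')
    elif state_dict_path is None or model_class is None:
      raise ValueError(
          'Either torchscript_model_path, or both state_dict_path and '
          'model_class, must be provided.')
    self._torchscript_model_path = torchscript_model_path
    self._state_dict_path = state_dict_path
    self._model_class = model_class
```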
> When the user enables this, we call `torch.jit.load()`, which accepts `.pt`,
`.pth`, and also the new zip format torch is going to make the default soon:
https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference
This still needs to be saved in the TorchScript format, though, right? It
would not work with a `.pt`/`.pth` file that just contains the `state_dict`.
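
For concreteness, a minimal standalone sketch of why the two save formats are
not interchangeable, even though both commonly use the `.pt`/`.pth` extensions
(plain PyTorch, nothing Beam-specific; file names are made up):

```python
import torch

model = torch.nn.Linear(4, 2)

# TorchScript: the compiled model is serialized as a self-contained archive,
# so torch.jit.load() can rebuild it without access to the model class.
torch.jit.script(model).save('model_ts.pt')
restored = torch.jit.load('model_ts.pt')

# state_dict: only the weights are saved. Loading requires instantiating the
# class first, and torch.jit.load() on this file fails, regardless of the
# .pt/.pth extension.
torch.save(model.state_dict(), 'weights.pt')
fresh = torch.nn.Linear(4, 2)
fresh.load_state_dict(torch.load('weights.pt'))
# torch.jit.load('weights.pt')  # RuntimeError: not a TorchScript archive
```

So the docstring should probably say explicitly that the file must have been
produced by `torch.jit.save()` / `ScriptModule.save()`, not `torch.save()` of
a `state_dict`.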
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]