jrmccluskey commented on code in PR #24062:
URL: https://github.com/apache/beam/pull/24062#discussion_r1020461900
##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -100,6 +112,21 @@ def _convert_to_result(
   return [PredictionResult(x, y) for x, y in zip(batch, predictions)]


+def default_tensor_inference_fn(
+    batch: Sequence[torch.Tensor],
+    model: torch.nn.Module,
+    device: str,
+    inference_args: Optional[Dict[str,
+                                  Any]] = None) -> Iterable[PredictionResult]:
+  # torch.no_grad() mitigates GPU memory issues
+  # https://github.com/apache/beam/issues/22811
+  with torch.no_grad():
+    batched_tensors = torch.stack(batch)
+    batched_tensors = _convert_to_device(batched_tensors, device)
+    predictions = model(batched_tensors, **inference_args)
+    return _convert_to_result(batch, predictions)

Review Comment:
   The convenience function kind of works, although you end up with an oddly named function defined within the scope of the convenience function; since Python has no multi-line lambdas, that's unavoidable. The function winds up looking something like this:
   ```
   with torch.no_grad():
     batched_tensors = torch.stack(batch)
     batched_tensors = _convert_to_device(batched_tensors, device)
     pred_fn = getattr(model, model_fn)
     predictions = pred_fn(batched_tensors, **inference_args)
     return _convert_to_result(batch, predictions)
   ```
   Although the more I think about it, the more I think we should just provide a `generate` function users can pass, since we know that routing is a major motivating example for this feature. There's also something to be said for making `_convert_to_device` and `_convert_to_result` available to users as building blocks for their custom functions.
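   For illustration, here is a rough sketch of what a user-written, `generate`-routed inference function could look like if those helpers were exposed as building blocks. The function name, the import of the currently-private `_convert_to_device` / `_convert_to_result` helpers, the assumption that the model exposes a `generate()` method (e.g. a Hugging Face generation model), and the `inference_fn=` hookup are all illustrative assumptions, not the final API:
   ```python
   import torch
   from typing import Any, Dict, Iterable, Optional, Sequence

   from apache_beam.ml.inference.base import PredictionResult
   # Assumes these helpers get promoted to public building blocks as
   # suggested above; today they are private to pytorch_inference.
   from apache_beam.ml.inference.pytorch_inference import (
       _convert_to_device, _convert_to_result)


   def generate_tensor_inference_fn(
       batch: Sequence[torch.Tensor],
       model: torch.nn.Module,
       device: str,
       inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
     """Like default_tensor_inference_fn, but routes through model.generate()."""
     inference_args = inference_args or {}
     # torch.no_grad() mitigates GPU memory issues, same as the default fn.
     with torch.no_grad():
       batched_tensors = torch.stack(batch)
       batched_tensors = _convert_to_device(batched_tensors, device)
       # Route through generate() instead of forward(); assumes the model
       # actually defines a generate() method.
       predictions = model.generate(batched_tensors, **inference_args)
       return _convert_to_result(batch, predictions)


   # Hypothetical usage, assuming the handler gains an inference_fn parameter
   # as proposed in this PR:
   # PytorchModelHandlerTensor(
   #     state_dict_path=...,
   #     model_class=...,
   #     inference_fn=generate_tensor_inference_fn)
   ```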
