jrmccluskey commented on code in PR #24062: URL: https://github.com/apache/beam/pull/24062#discussion_r1021725284
##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########

@@ -100,6 +112,21 @@ def _convert_to_result(
   return [PredictionResult(x, y) for x, y in zip(batch, predictions)]


+def default_tensor_inference_fn(
+    batch: Sequence[torch.Tensor],
+    model: torch.nn.Module,
+    device: str,
+    inference_args: Optional[Dict[str, Any]] = None
+) -> Iterable[PredictionResult]:
+  # torch.no_grad() mitigates GPU memory issues:
+  # https://github.com/apache/beam/issues/22811
+  with torch.no_grad():
+    batched_tensors = torch.stack(batch)
+    batched_tensors = _convert_to_device(batched_tensors, device)
+    # Guard against inference_args being None so **-expansion never fails.
+    predictions = model(batched_tensors, **(inference_args or {}))
+  return _convert_to_result(batch, predictions)

Review Comment:
   The helper route isn't that bad; it is a little verbose on our end, but it should make user routing easier. A first run at it was just pushed.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at: [email protected]
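The diff above adds a default inference function that a model handler can fall back to when the user does not supply one. Below is a minimal, framework-free sketch of that pluggable inference-fn pattern, assuming no torch or Beam installation: `PredictionResult`, `default_inference_fn`, `ModelHandler`, and `double_model` here are illustrative stand-ins, not the actual Beam API.

```python
from typing import Any, Callable, Dict, Iterable, NamedTuple, Optional, Sequence


class PredictionResult(NamedTuple):
  # Mirrors Beam's (example, inference) pairing in a simplified form.
  example: Any
  inference: Any


def default_inference_fn(
    batch: Sequence[float],
    model: Callable[..., Sequence[float]],
    inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
  # Guard against None so the **-expansion below never fails.
  inference_args = inference_args or {}
  predictions = model(batch, **inference_args)
  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]


class ModelHandler:
  """Accepts an optional custom inference_fn, defaulting to the helper."""

  def __init__(self, model, inference_fn=default_inference_fn):
    self._model = model
    self._inference_fn = inference_fn

  def run_inference(self, batch, inference_args=None):
    # Route every call through the configured inference function.
    return self._inference_fn(batch, self._model, inference_args)


def double_model(batch, scale=2.0):
  # Toy "model": scales each element of the batch.
  return [x * scale for x in batch]


handler = ModelHandler(double_model)
results = list(handler.run_inference([1.0, 2.0], inference_args={"scale": 3.0}))
# results pairs each input with its scaled prediction.
```

The design point under review is exactly this indirection: the handler stays small, and users who need custom batching or device placement swap in their own function instead of subclassing the handler.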
