damccorm commented on code in PR #24062: URL: https://github.com/apache/beam/pull/24062#discussion_r1020571767
##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -100,6 +112,21 @@ def _convert_to_result(
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
 
+def default_tensor_inference_fn(
+    batch: Sequence[torch.Tensor],
+    model: torch.nn.Module,
+    device: str,
+    inference_args: Optional[Dict[str,
+                                  Any]] = None) -> Iterable[PredictionResult]:
+  # torch.no_grad() mitigates GPU memory issues
+  # https://github.com/apache/beam/issues/22811
+  with torch.no_grad():
+    batched_tensors = torch.stack(batch)
+    batched_tensors = _convert_to_device(batched_tensors, device)
+    predictions = model(batched_tensors, **inference_args)
+    return _convert_to_result(batch, predictions)

Review Comment:
   > although you get a funky named function defined within the scope of the convenience function. No multi-line lambdas will do that. The function winds up looking something like this:

   Yeah, that's what I was imagining here. Note that while this is funky for us, it's a smooth user experience because they don't have to be aware of the `getattr` call.

   > Although the more I think about it, the more I think we should just provide a generate function users can pass, since we know using that routing instead is a major motivating example for this feature.

   I agree, though I think `generate` is just one example and we want this to be as easy as possible for similar functions as well. I'd probably argue that a generic `make_tensor_inference_override_fn` (maybe with better naming 😅) does that better than us trying to hit all the functions a user could provide. We don't necessarily have to do that, but if you disagree, let's get more people's voices/opinions involved.
