damccorm commented on code in PR #24062:
URL: https://github.com/apache/beam/pull/24062#discussion_r1020392397


##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -100,6 +112,21 @@ def _convert_to_result(
   return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
 
+def default_tensor_inference_fn(
+    batch: Sequence[torch.Tensor],
+    model: torch.nn.Module,
+    device: str,
+    inference_args: Optional[Dict[str,
+                                  Any]] = None) -> Iterable[PredictionResult]:
+  # torch.no_grad() mitigates GPU memory issues
+  # https://github.com/apache/beam/issues/22811
+  with torch.no_grad():
+    batched_tensors = torch.stack(batch)
+    batched_tensors = _convert_to_device(batched_tensors, device)
+    predictions = model(batched_tensors, **inference_args)
+    return _convert_to_result(batch, predictions)

Review Comment:
   Looking at this, my only concern is that the simple customization (calling 
`model.generate(...)` instead of `model(...)`, or something similar) is now 
harder to do performantly.
   
   I wonder if there's a way to still make that easy - one option would be to 
add a convenience function like 
`make_tensor_inference_override_fn(model_function: str)` that generates a 
lambda. The user code would then look like: `my_tensor_inference_fn = 
make_tensor_inference_override_fn('generate')`, which would set 
`my_tensor_inference_fn` to an anonymous function that does:
   
   ```python
     with torch.no_grad():
       batched_tensors = torch.stack(batch)
       batched_tensors = _convert_to_device(batched_tensors, device)
       predictions = model.generate(batched_tensors, **inference_args)
       return _convert_to_result(batch, predictions)
   ```
   
   Does that make sense/sound reasonable?
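
   A minimal sketch of that factory, framework-free so it runs standalone: it 
dispatches to the named method via `getattr`. The helper name 
`make_tensor_inference_override_fn` and the `DummyModel` class are 
illustrative only; the real version would stack the batch, move it to 
`device`, and run under `torch.no_grad()` as in the snippet above.

   ```python
   from typing import Any, Callable, Dict, Optional, Sequence


   def make_tensor_inference_override_fn(model_function: str) -> Callable:
     """Builds an inference fn that calls `model.<model_function>(...)`.

     Sketch only: the real Beam version would also stack the batch into a
     single tensor, move it to `device`, and wrap the call in torch.no_grad().
     """
     def inference_fn(
         batch: Sequence[Any],
         model: Any,
         device: str,
         inference_args: Optional[Dict[str, Any]] = None):
       inference_args = inference_args or {}
       # Look up the requested method by name, e.g. model.generate.
       method = getattr(model, model_function)
       return [method(x, **inference_args) for x in batch]
     return inference_fn


   class DummyModel:
     """Stand-in for torch.nn.Module, for illustration."""
     def generate(self, x, scale=1):
       return x * scale


   my_tensor_inference_fn = make_tensor_inference_override_fn('generate')
   print(my_tensor_inference_fn([1, 2, 3], DummyModel(), 'cpu', {'scale': 2}))
   ```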



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
