yeandy commented on code in PR #22795:
URL: https://github.com/apache/beam/pull/22795#discussion_r956417937


##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -188,20 +216,24 @@ def __init__(
         Otherwise, it will be CPU.
     """
     self._state_dict_path = state_dict_path
-    if device == 'GPU' and torch.cuda.is_available():
+    if device == 'GPU':
+      logging.info("Device is set to CUDA")

Review Comment:
   I wanted to be very explicit about accelerator detection so that users 
don't have to second-guess what kind of environment they are running in. 
(Because Dataflow runs multiple workers, I can imagine future scenarios in 
which GPU availability is inconsistent across the workers.)
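
A minimal sketch of what explicit detection could look like, not the code in this PR. The helper name `select_device` and its `cuda_available` parameter are hypothetical; the flag is passed in so the sketch runs without PyTorch, and in real code it would presumably come from `torch.cuda.is_available()`.

```python
import logging


def select_device(requested: str, cuda_available: bool) -> str:
    """Return 'cuda' or 'cpu', logging which device was chosen and why.

    Hypothetical helper for illustration only. `cuda_available` stands in
    for the result of `torch.cuda.is_available()` on the current worker.
    """
    if requested == 'GPU':
        if cuda_available:
            logging.info("Device is set to CUDA")
            return 'cuda'
        # Make the fallback loud so a worker without a GPU is easy to spot
        # in the logs instead of silently running on CPU.
        logging.warning(
            "Device was set to GPU, but no CUDA device was found on this "
            "worker. Falling back to CPU.")
        return 'cpu'
    logging.info("Device is set to CPU")
    return 'cpu'
```

Logging on every branch means each worker records its own decision, so an inconsistency between workers shows up in the logs rather than as an unexplained slowdown.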



