This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 9b6c805e1cd Use correct device ordinal when GPU is detected (#31951)
9b6c805e1cd is described below
commit 9b6c805e1cdb911eb11f5ade137a92b087edcf86
Author: Danny McCormick <[email protected]>
AuthorDate: Tue Jul 23 19:01:06 2024 +0200
Use correct device ordinal when GPU is detected (#31951)
---
sdks/python/apache_beam/ml/inference/huggingface_inference.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
index 91efcdd76a2..2934a536291 100644
--- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py
+++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
@@ -677,7 +677,7 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
self._load_pipeline_args['device'] = 'cpu'
else:
if is_gpu_available_torch():
- self._load_pipeline_args['device'] = 'cuda:1'
+ self._load_pipeline_args['device'] = 'cuda:0'
else:
_LOGGER.warning(
"HuggingFaceModelHandler specified a 'GPU' device, "