[
https://issues.apache.org/jira/browse/BEAM-14337?focusedWorklogId=772800&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-772800
]
ASF GitHub Bot logged work on BEAM-14337:
-----------------------------------------
Author: ASF GitHub Bot
Created on: 20/May/22 12:24
Start Date: 20/May/22 12:24
Worklog Time Spent: 10m
Work Description: yeandy commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r878088844
##########
sdks/python/apache_beam/ml/inference/pytorch.py:
##########
@@ -39,25 +42,47 @@ class PytorchInferenceRunner(InferenceRunner):
   def __init__(self, device: torch.device):
     self._device = device
 
-  def run_inference(self, batch: List[torch.Tensor],
-                    model: torch.nn.Module) -> Iterable[PredictionResult]:
+  def run_inference(
+      self,
+      batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+      model: torch.nn.Module,
+      prediction_params: Optional[Dict[str, Any]] = None,
+  ) -> Iterable[PredictionResult]:
     """
     Runs inferences on a batch of Tensors and returns an Iterable of
     Tensor Predictions.
     This method stacks the list of Tensors in a vectorized format to optimize
     the inference call.
     """
+    if prediction_params is None:
+      prediction_params = {}
-    torch_batch = torch.stack(batch)
-    if torch_batch.device != self._device:
-      torch_batch = torch_batch.to(self._device)
-    predictions = model(torch_batch)
+    if isinstance(batch[0], dict):
+      result_dict = defaultdict(list)
+      for el in batch:
+        for k, v in el.items():
+          result_dict[k].append(v)
+      for k in result_dict:
+        batched_values = torch.stack(result_dict[k])
+        if batched_values.device != self._device:
Review Comment:
The original intention was to verify that the examples, which may or may not
be attached to a GPU at creation time, match the `device` param passed during
`PytorchModelLoader` creation. I.e., a user may pass in `device='GPU'` together
with a `torch.Tensor(..., device='cuda')` example, but if no GPU is detected in
the environment, then we have to convert the example back to CPU.
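For context, here is a minimal sketch of the behavior discussed above: stack
dict-keyed examples per key and reconcile their device with the configured
one. The helper names `_resolve_device` and `_stack_batch` are hypothetical
illustrations, not the PR's actual code.
{code:python}
from collections import defaultdict
from typing import Dict, List, Union

import torch


def _resolve_device(requested: str) -> torch.device:
  # Hypothetical helper: honor device='GPU' only when CUDA is actually
  # available; otherwise fall back to CPU, as the review comment describes.
  if requested == 'GPU' and torch.cuda.is_available():
    return torch.device('cuda')
  return torch.device('cpu')


def _stack_batch(
    batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
    device: torch.device,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
  # Dict-keyed examples are stacked per key; plain Tensors in one step.
  if isinstance(batch[0], dict):
    grouped = defaultdict(list)
    for el in batch:
      for k, v in el.items():
        grouped[k].append(v)
    # Tensor.to() returns the Tensor unchanged when it already lives on
    # `device`, so an explicit device-equality check is optional here.
    return {k: torch.stack(vs).to(device) for k, vs in grouped.items()}
  return torch.stack(batch).to(device)
{code}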
Issue Time Tracking
-------------------
Worklog Id: (was: 772800)
Time Spent: 3h 10m (was: 3h)
> Support **kwargs for PyTorch models.
> ------------------------------------
>
> Key: BEAM-14337
> URL: https://issues.apache.org/jira/browse/BEAM-14337
> Project: Beam
> Issue Type: Sub-task
> Components: sdk-py-core
> Reporter: Anand Inguva
> Assignee: Andy Ye
> Priority: P2
> Time Spent: 3h 10m
> Remaining Estimate: 0h
>
> Some PyTorch models that subclass torch.nn.Module take extra
> parameters in the forward function call. These extra parameters can be passed
> as a Dict or as keyword arguments.
> Example of PyTorch models supported by Hugging Face:
> [https://huggingface.co/bert-base-uncased]
> [Some torch models on Hugging
> Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]
> E.g.:
> [https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel]
> {code:python}
> inputs = {
>     "input_ids": Tensor1,
>     "attention_mask": Tensor2,
>     "token_type_ids": Tensor3,
> }
> model = BertModel.from_pretrained("bert-base-uncased")  # a subclass
> # of torch.nn.Module
> outputs = model(**inputs)  # the model's forward method should accept the
> # keys of inputs as keyword arguments.{code}
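> As a hypothetical sketch (names like {{ids_1}} are illustrative, not from
> the issue), each per-example dict that RunInference would batch carries the
> same keys, which get stacked per key before the forward call:
> {code:python}
> import torch
>
> # Two hypothetical examples; keys match BertModel.forward() kwargs.
> ids_1 = torch.ones(8, dtype=torch.long)
> ids_2 = torch.ones(8, dtype=torch.long)
> mask = torch.ones(8, dtype=torch.long)
> types = torch.zeros(8, dtype=torch.long)
> batch = [
>     {"input_ids": ids_1, "attention_mask": mask, "token_type_ids": types},
>     {"input_ids": ids_2, "attention_mask": mask, "token_type_ids": types},
> ]
> # After per-key stacking, the call is effectively:
> #   model(input_ids=torch.stack([ids_1, ids_2]), ...)
> {code}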
>
> The [Transformers|https://pytorch.org/hub/huggingface_pytorch-transformers/]
> models integrated into PyTorch Hub are supported by Hugging Face as well.
>
--
This message was sent by Atlassian Jira
(v8.20.7#820007)