[
https://issues.apache.org/jira/browse/BEAM-14337?focusedWorklogId=769920&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-769920
]
ASF GitHub Bot logged work on BEAM-14337:
-----------------------------------------
Author: ASF GitHub Bot
Created on: 12/May/22 21:47
Start Date: 12/May/22 21:47
Worklog Time Spent: 10m
Work Description: robertwb commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r871834743
##########
sdks/python/apache_beam/ml/inference/pytorch.py:
##########
@@ -45,11 +46,22 @@ def run_inference(self, batch: List[torch.Tensor],
This method stacks the list of Tensors in a vectorized format to optimize
the inference call.
"""
-
- batch = torch.stack(batch)
- if batch.device != self._device:
- batch = batch.to(self._device)
- predictions = model(batch)
+ if isinstance(batch[0], dict):
Review Comment:
+1. If PytorchInferenceRunner (and its associated ModelLoader) were typed, we
would have one that takes tensors and one that takes dicts of tensors, and this
typing would carry all the way through to the PCollections being PCollections of
tensors vs. PCollections of dicts of tensors.
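The per-key stacking that the patch introduces for dict inputs can be illustrated without PyTorch installed. This is a minimal sketch, not Beam's actual implementation: `stack_batch` is a hypothetical name, and plain lists stand in for the `torch.stack` calls the real code would make.

```python
def stack_batch(batch):
    """Group a batch of examples for a single vectorized forward call.

    If the examples are dicts (e.g. {"input_ids": ..., "attention_mask": ...}),
    collect the values key by key so each key maps to one stacked batch;
    a PyTorch version would apply torch.stack to each group of values.
    Otherwise, treat the whole batch as one stackable sequence.
    """
    if isinstance(batch[0], dict):
        # Per-key grouping; real code: {k: torch.stack(vals) for k, vals in ...}
        return {key: [example[key] for example in batch] for key in batch[0]}
    # Real code: torch.stack(batch)
    return list(batch)


batch = [
    {"input_ids": 1, "attention_mask": 10},
    {"input_ids": 2, "attention_mask": 20},
]
stacked = stack_batch(batch)
# Each key now maps to the values collected across the whole batch.
```

This also shows why the reviewer's typing point matters: a runner typed for `List[torch.Tensor]` and one typed for `List[Dict[str, torch.Tensor]]` would make the `isinstance` branch unnecessary.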
Issue Time Tracking
-------------------
Worklog Id: (was: 769920)
Time Spent: 1h 50m (was: 1h 40m)
> Support **kwargs for PyTorch models.
> ------------------------------------
>
> Key: BEAM-14337
> URL: https://issues.apache.org/jira/browse/BEAM-14337
> Project: Beam
> Issue Type: Sub-task
> Components: sdk-py-core
> Reporter: Anand Inguva
> Assignee: Andy Ye
> Priority: P2
> Time Spent: 1h 50m
> Remaining Estimate: 0h
>
> Some PyTorch models subclassing torch.nn.Module take extra
> parameters in their forward call. These extra parameters can be passed
> as a dict of keyword arguments or as positional arguments.
> Example of PyTorch models supported by Hugging Face:
> [https://huggingface.co/bert-base-uncased]
> [Some PyTorch models on Hugging
> Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]
> E.g.:
> [https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel]
> {code:python}
> inputs = {
>     "input_ids": Tensor1,
>     "attention_mask": Tensor2,
>     "token_type_ids": Tensor3,
> }
> # BertModel is a subclass of torch.nn.Module
> model = BertModel.from_pretrained("bert-base-uncased")
> # The model's forward method must accept the dict keys as keyword arguments.
> outputs = model(**inputs){code}
>
> The [Transformers|https://pytorch.org/hub/huggingface_pytorch-transformers/]
> models integrated into the PyTorch Hub are supported by Hugging Face as well.
>
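The `**kwargs` calling convention described above can be sketched without transformers or torch installed. `ToyModel` is a hypothetical stand-in for a `torch.nn.Module` subclass such as `BertModel`; the point is that the dict keys must match the parameter names of `forward`.

```python
class ToyModel:
    """Toy stand-in for a torch.nn.Module whose forward takes keyword args."""

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        # A real model would run the network; here we just echo what arrived.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    # torch.nn.Module routes __call__ to forward; mimic that here.
    __call__ = forward


inputs = {
    "input_ids": [101, 2023],
    "attention_mask": [1, 1],
    "token_type_ids": [0, 0],
}
model = ToyModel()
# ** unpacks the dict, so each key becomes a keyword argument to forward.
outputs = model(**inputs)
```

If a key in `inputs` does not match a parameter name of `forward`, the call raises `TypeError`, which is why supporting dict inputs end-to-end (the subject of this issue) requires the keys to line up with the model's signature.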
--
This message was sent by Atlassian Jira
(v8.20.7#820007)