[ 
https://issues.apache.org/jira/browse/BEAM-14337?focusedWorklogId=772882&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-772882
 ]

ASF GitHub Bot logged work on BEAM-14337:
-----------------------------------------

                Author: ASF GitHub Bot
            Created on: 20/May/22 15:06
            Start Date: 20/May/22 15:06
    Worklog Time Spent: 10m 
      Work Description: yeandy commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r878255796


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -96,7 +107,9 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
         pcoll
         # TODO(BEAM-14044): Hook into the batching DoFn APIs.
         | beam.BatchElements()
-        | beam.ParDo(_RunInferenceDoFn(self._model_loader, self._clock)))
+        | beam.ParDo(
+            _RunInferenceDoFn(
+                self._model_loader, self._prediction_params, self._clock)))

Review Comment:
   Note: The code has since been changed slightly to use `**kwargs` instead of 
`self._prediction_params`.
   
   Passing these extra prediction-related parameters as side inputs makes 
more conceptual sense: the `batch` examples and the `prediction_params` 
kwargs are two different kinds of input that are passed together into the 
PyTorch model's predict call, which in turn passes both to the model's 
`forward` function.
   ```
   from torch import nn

   # Inside PytorchInferenceRunner()
   def run_inference(batch, model, **kwargs):
       # Default to an empty dict so `**` unpacking works when no params are given.
       prediction_params = kwargs.get('prediction_params', {})
       return model(batch, **prediction_params)
   
   # Calling the model dispatches to the PyTorch model's forward() ...
   class PytorchModel(nn.Module):
       def forward(self, batch, **prediction_params):
           ...
   ```
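
A minimal, torch-free sketch of the same flow, in case it helps to see it run end to end. `FakeModel`, `scale`, and `offset` are hypothetical names invented for illustration; the class only imitates how calling an `nn.Module` dispatches to `forward`, and this is not the actual Beam implementation.

```python
from typing import Any, List


class FakeModel:
    """Hypothetical stand-in for a torch.nn.Module subclass (no torch needed)."""

    def forward(
        self, batch: List[float], scale: float = 1.0, offset: float = 0.0
    ) -> List[float]:
        # Extra prediction parameters arrive here as keyword arguments.
        return [x * scale + offset for x in batch]

    def __call__(self, *args: Any, **kwargs: Any) -> List[float]:
        # Imitates nn.Module: calling the model runs forward().
        return self.forward(*args, **kwargs)


def run_inference(batch, model, **kwargs):
    # Fall back to an empty dict so models without extra parameters still work.
    prediction_params = kwargs.get('prediction_params') or {}
    return model(batch, **prediction_params)


model = FakeModel()
plain = run_inference([1.0, 2.0], model)
scaled = run_inference(
    [1.0, 2.0], model, prediction_params={'scale': 2.0, 'offset': 1.0})
```

Here `plain` is computed with the defaults declared on `forward`, while `scaled` picks up the two extra parameters from the `prediction_params` dict.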





Issue Time Tracking
-------------------

    Worklog Id:     (was: 772882)
    Time Spent: 5.5h  (was: 5h 20m)

> Support **kwargs for PyTorch models.
> ------------------------------------
>
>                 Key: BEAM-14337
>                 URL: https://issues.apache.org/jira/browse/BEAM-14337
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-py-core
>            Reporter: Anand Inguva
>            Assignee: Andy Ye
>            Priority: P2
>          Time Spent: 5.5h
>  Remaining Estimate: 0h
>
> Some PyTorch models that subclass torch.nn.Module take extra 
> parameters in the forward function call. These extra parameters can be passed 
> as a dict or as positional arguments. 
> Example of PyTorch models supported by Hugging Face -> 
> [https://huggingface.co/bert-base-uncased]
> [Some torch models on Hugging 
> Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]
> Eg: 
> [https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel]
> {code:python}
> from transformers import BertModel
>
> # Tensor1, Tensor2 and Tensor3 are placeholders for the input tensors.
> inputs = {
>      "input_ids": Tensor1,
>      "attention_mask": Tensor2,
>      "token_type_ids": Tensor3,
> }
> # BertModel is a subclass of torch.nn.Module.
> model = BertModel.from_pretrained("bert-base-uncased")
> # The model's forward method receives the keys of inputs as keyword arguments.
> outputs = model(**inputs)
> {code}
>  
> [Transformers|https://pytorch.org/hub/huggingface_pytorch-transformers/], 
> as integrated in PyTorch, is supported by Hugging Face as well. 
>  



--
This message was sent by Atlassian Jira
(v8.20.7#820007)
