[ https://issues.apache.org/jira/browse/BEAM-14337?focusedWorklogId=772645&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-772645 ]

ASF GitHub Bot logged work on BEAM-14337:
-----------------------------------------

                Author: ASF GitHub Bot
            Created on: 20/May/22 02:14
            Start Date: 20/May/22 02:14
    Worklog Time Spent: 10m 
      Work Description: ryanthompson591 commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r877675380


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -96,7 +107,9 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
         pcoll
         # TODO(BEAM-14044): Hook into the batching DoFn APIs.
         | beam.BatchElements()
-        | beam.ParDo(_RunInferenceDoFn(self._model_loader, self._clock)))
+        | beam.ParDo(
+            _RunInferenceDoFn(
+                self._model_loader, self._prediction_params, self._clock)))

Review Comment:
   This is also what I preferred. But after a long discussion we settled on 
making these a side input.
   
   The reason was that these static parameters are conceptually different: 
they are parameters for the inputs rather than model-loading parameters, so 
they should not be passed in through the model loader. If you want to rehash 
this, I suggest setting up a meeting.
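   The design described in this comment can be illustrated with a minimal, Beam-free sketch. All names here (ModelLoader, RunInferenceSketch, scale) are hypothetical and are not Beam's API: the loader only knows how to build the model, while per-input parameters are supplied separately at call time, which is the role the side input plays in the PR. In a real pipeline that value would be passed as an extra ParDo argument, e.g. via beam.pvalue.AsSingleton.

```python
from typing import Any, Callable, Dict, Iterator, List, Optional


class ModelLoader:
    """Hypothetical loader: it only knows how to build the model."""

    def __init__(self, model_factory: Callable[[], Callable[..., Any]]):
        self._model_factory = model_factory

    def load_model(self) -> Callable[..., Any]:
        return self._model_factory()


class RunInferenceSketch:
    """Hypothetical stand-in for _RunInferenceDoFn; not Beam's actual class."""

    def __init__(self, model_loader: ModelLoader):
        # Only model-loading configuration is fixed at construction time.
        self._model_loader = model_loader
        self._model: Optional[Callable[..., Any]] = None

    def setup(self) -> None:
        # Load the model once, independently of any per-input parameters.
        self._model = self._model_loader.load_model()

    def process(
        self, batch: List[Any], prediction_params: Optional[Dict[str, Any]] = None
    ) -> Iterator[Any]:
        # In a pipeline, prediction_params would be the side-input value,
        # e.g. beam.ParDo(dofn, beam.pvalue.AsSingleton(params_pcoll)).
        assert self._model is not None, "setup() must run before process()"
        params = prediction_params or {}
        for example in batch:
            # Extra parameters are forwarded to the model as keyword arguments.
            yield self._model(example, **params)


def scale(x: float, factor: float = 1) -> float:
    """Toy 'model' whose extra keyword argument plays the role of a forward param."""
    return x * factor


dofn = RunInferenceSketch(ModelLoader(lambda: scale))
dofn.setup()
print(list(dofn.process([1, 2, 3], {"factor": 10})))  # [10, 20, 30]
```

   Keeping the parameters out of the loader means the same loaded model can be reused with different per-input settings without reconstructing the loader.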
   





Issue Time Tracking
-------------------

    Worklog Id:     (was: 772645)
    Time Spent: 2h 20m  (was: 2h 10m)

> Support **kwargs for PyTorch models.
> ------------------------------------
>
>                 Key: BEAM-14337
>                 URL: https://issues.apache.org/jira/browse/BEAM-14337
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-py-core
>            Reporter: Anand Inguva
>            Assignee: Andy Ye
>            Priority: P2
>          Time Spent: 2h 20m
>  Remaining Estimate: 0h
>
> Some PyTorch models that subclass torch.nn.Module take extra parameters in 
> their forward method. These extra parameters can be passed as a dict 
> (unpacked with ** into keyword arguments) or as positional arguments. 
> Example of a PyTorch model supported by Hugging Face: 
> [https://huggingface.co/bert-base-uncased]
> [Some torch models on Hugging Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]
> E.g.: 
> [https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel]
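> {code:python}
> import torch
> from transformers import BertModel
>
> # BertModel is a subclass of torch.nn.Module.
> model = BertModel.from_pretrained("bert-base-uncased")
>
> # Example tensors with batch size 1 and sequence length 3;
> # 101 and 102 are the [CLS] and [SEP] token ids in this vocabulary.
> inputs = {
>     "input_ids": torch.tensor([[101, 7592, 102]]),
>     "attention_mask": torch.tensor([[1, 1, 1]]),
>     "token_type_ids": torch.tensor([[0, 0, 0]]),
> }
> # ** unpacks the dict, so the model's forward method receives its keys
> # as keyword arguments.
> outputs = model(**inputs){code}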
>  
> [Transformers|https://pytorch.org/hub/huggingface_pytorch-transformers/] 
> models from Hugging Face are also available through PyTorch Hub. 
>  
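To make the two calling conventions in the issue description concrete, here is a dependency-free sketch. FakeBertLikeModel is a hypothetical stand-in for a torch.nn.Module (it only echoes its arguments, and its forward parameter names follow the issue's example), and run_model is a hypothetical helper showing one way an inference runner might pass the batch positionally and forward any extra parameters as keyword arguments.

```python
from typing import Any, Dict, Optional


class FakeBertLikeModel:
    """Hypothetical stand-in for a torch.nn.Module with extra forward params.

    It only echoes its inputs so that the calling conventions are visible.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # Like torch.nn.Module, calling the object dispatches to forward.
        return self.forward(*args, **kwargs)

    def forward(
        self,
        input_ids: Any,
        attention_mask: Optional[Any] = None,
        token_type_ids: Optional[Any] = None,
    ) -> Dict[str, Any]:
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }


def run_model(model: Any, batch: Any, extra_kwargs: Optional[Dict[str, Any]] = None) -> Any:
    """Hypothetical helper: batch goes first, extra params are unpacked as kwargs."""
    return model(batch, **(extra_kwargs or {}))


model = FakeBertLikeModel()
inputs = {"input_ids": [101, 102], "attention_mask": [1, 1], "token_type_ids": [0, 0]}

# 1. The whole dict unpacked as keyword arguments.
by_keyword = model(**inputs)
# 2. The same values passed positionally, in forward's parameter order.
by_position = model([101, 102], [1, 1], [0, 0])
# 3. Batch passed positionally, remaining parameters as keyword arguments.
mixed = run_model(model, [101, 102], {"attention_mask": [1, 1], "token_type_ids": [0, 0]})

print(by_keyword == by_position == mixed)  # True
```

Positional passing depends on the order of forward's parameters, whereas a dict unpacked with ** depends only on the parameter names, which is why a dict of keyword arguments is the more robust way to carry extra model inputs.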



--
This message was sent by Atlassian Jira
(v8.20.7#820007)
