[ 
https://issues.apache.org/jira/browse/BEAM-14337?focusedWorklogId=772840&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-772840
 ]

ASF GitHub Bot logged work on BEAM-14337:
-----------------------------------------

                Author: ASF GitHub Bot
            Created on: 20/May/22 13:41
            Start Date: 20/May/22 13:41
    Worklog Time Spent: 10m 
      Work Description: ryanthompson591 commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r878143536


##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -39,7 +41,12 @@ class FakeInferenceRunner(base.InferenceRunner):
   def __init__(self, clock=None):
     self._mock_clock = clock
 
-  def run_inference(self, batch: Any, model: Any) -> Iterable[Any]:
+  def run_inference(
+      self,
+      batch: Any,
+      model: Any,
+      prediction_params: Optional[Dict[str, Any]] = None,

Review Comment:
   Here's some code to demonstrate what I mean.
   
   def pytorch_impl(prediction_params=None):
     print('got prediction_params ' + str(prediction_params))
   
   def tf_impl():
     # Does not have the prediction_params arg; should not crash when called
     # without it, but should crash if called with a prediction_params argument.
     print('tf_impl')
   
   mode = 'tf'
   
   def base_do_fn(**kwargs):
     if mode == 'tf':
       tf_impl(**kwargs)
     elif mode == 'pytorch':
       pytorch_impl(**kwargs)
   
   
   def base_run_inference(**kwargs):
     base_do_fn(**kwargs)
   
   base_run_inference()
   
   print('expect pass')
   mode = 'pytorch'
   base_run_inference(prediction_params={'prediction_param': 'value'})
   
   print('expect failure')
   mode = 'tf'
   base_run_inference(prediction_params={'prediction_param': 'value'})
   
   ----------------
   What I'm saying is that this approach allows forwarding of any named 
parameter whatsoever. With the current approach, any single parameter that any 
single implementation wants forces us to modify the interface, and therefore 
every implementation of the interface going forward.
   
   I think here, and going forward, unless a named parameter applies to all 
models, we should not put it into the interface for all models.
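   
   As a rough sketch of the alternative (the method signatures here are 
hypothetical, not the actual Beam API): the base interface accepts arbitrary 
keyword arguments so callers can forward any named parameter, and each 
framework-specific runner declares only the parameters it actually supports.
   
   class InferenceRunner:
     def run_inference(self, batch, model, **kwargs):
       # Base interface: accepts arbitrary keyword arguments so the caller
       # can forward any named parameter to the concrete implementation.
       raise NotImplementedError
   
   class PytorchInferenceRunner(InferenceRunner):
     # Declares the one extra parameter it actually supports.
     def run_inference(self, batch, model, prediction_params=None):
       forward_kwargs = prediction_params or {}
       return [model(example, **forward_kwargs) for example in batch]
   
   class TFInferenceRunner(InferenceRunner):
     # Takes no extra parameters; calling it with prediction_params raises
     # a TypeError, which is the desired failure mode from the demo above.
     def run_inference(self, batch, model):
       return [model(example) for example in batch]
   
   This way only the PyTorch runner ever has to mention prediction_params, 
and adding a parameter for one framework never touches the interface or the 
other implementations.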





Issue Time Tracking
-------------------

    Worklog Id:     (was: 772840)
    Time Spent: 4h 40m  (was: 4.5h)

> Support **kwargs for PyTorch models.
> ------------------------------------
>
>                 Key: BEAM-14337
>                 URL: https://issues.apache.org/jira/browse/BEAM-14337
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-py-core
>            Reporter: Anand Inguva
>            Assignee: Andy Ye
>            Priority: P2
>          Time Spent: 4h 40m
>  Remaining Estimate: 0h
>
> Some PyTorch models inheriting from torch.nn.Module take extra 
> parameters in their forward function call. These extra parameters can be 
> passed as a Dict or as keyword arguments. 
> Example of PyTorch models supported by Hugging Face -> 
> [https://huggingface.co/bert-base-uncased]
> [Some torch models on Hugging 
> Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]
> E.g.: 
> [https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel]
> {code:python}
> inputs = {
>     "input_ids": Tensor1,
>     "attention_mask": Tensor2,
>     "token_type_ids": Tensor3,
> }
> # BertModel is a subclass of torch.nn.Module.
> model = BertModel.from_pretrained("bert-base-uncased")
> # The model's forward method accepts the keys of inputs as keyword arguments.
> outputs = model(**inputs){code}
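> For concreteness, a minimal runnable version of the above (assuming the 
> {{transformers}} package is installed; the tokenizer produces exactly the 
> dict of tensors shown above):
> {code:python}
> from transformers import BertModel, BertTokenizer
> 
> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
> # Returns a dict-like object with input_ids, attention_mask, token_type_ids.
> inputs = tokenizer("Hello, world!", return_tensors="pt")
> model = BertModel.from_pretrained("bert-base-uncased")
> # The keys of inputs are passed as keyword arguments to model.forward.
> outputs = model(**inputs){code}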
>  
> The [Transformers|https://pytorch.org/hub/huggingface_pytorch-transformers/] 
> integration in PyTorch Hub is provided by Hugging Face as well. 
>  



--
This message was sent by Atlassian Jira
(v8.20.7#820007)
