Anand Inguva created BEAM-14337:
-----------------------------------

             Summary: Support **kwargs for PyTorch models.
                 Key: BEAM-14337
                 URL: https://issues.apache.org/jira/browse/BEAM-14337
             Project: Beam
          Issue Type: Sub-task
          Components: sdk-py-core
            Reporter: Anand Inguva


Some PyTorch models that subclass torch.nn.Module take extra parameters in
their forward() call. These extra parameters can be passed as a dict (unpacked
into keyword arguments) or as positional arguments.
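A minimal sketch of the two calling styles, using a hypothetical toy module (`ToyEncoder`, not part of Beam or Hugging Face) whose forward() takes the same extra parameter names as BertModel. The assumption here is only that the extra parameters are ordinary tensor arguments; the outputs match whether they are passed positionally or as a dict unpacked with `**`.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical model whose forward() takes extra parameters beyond input_ids."""

    def __init__(self, vocab_size: int = 100, dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        x = self.embed(input_ids)  # shape: (batch, seq, dim)
        if token_type_ids is not None:
            # Shift embeddings by segment id, just so the argument is used.
            x = x + token_type_ids.unsqueeze(-1).float()
        if attention_mask is not None:
            # Zero out padded positions.
            x = x * attention_mask.unsqueeze(-1).float()
        return x


torch.manual_seed(0)
model = ToyEncoder().eval()
input_ids = torch.tensor([[5, 7, 9, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0]])
token_type_ids = torch.zeros_like(input_ids)

with torch.no_grad():
    # Extra parameters passed positionally ...
    out_positional = model(input_ids, attention_mask, token_type_ids)
    # ... or as a dict unpacked into keyword arguments.
    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }
    out_kwargs = model(**inputs)
```

A model handler that only ever calls `model(batch)` with a single tensor cannot feed the second form, which is the gap this issue describes.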

Example of a PyTorch model hosted on Hugging Face:
https://huggingface.co/bert-base-uncased

[Some PyTorch models on Hugging Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]

For example, BertModel:
https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel

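{code:python}
import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")  # a subclass of torch.nn.Module

# Keys match keyword arguments of BertModel.forward(); the tensors here are
# placeholders (batch of 1, sequence length 8).
inputs = {
    "input_ids": torch.randint(0, model.config.vocab_size, (1, 8)),
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
    "token_type_ids": torch.zeros(1, 8, dtype=torch.long),
}
outputs = model(**inputs)
{code}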
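One possible way for an inference wrapper to support such models, sketched under assumptions: each element is a dict of equally shaped tensors keyed by forward() argument name, the per-key tensors are stacked into a batch, and the model is called with `model(**batched)`. The helper `run_keyed_batch` and the toy `MaskedSum` model are hypothetical illustrations, not Beam's RunInference API.

```python
from typing import Dict, List

import torch
from torch import nn


def run_keyed_batch(model: nn.Module, batch: List[Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Hypothetical helper: stack per-key tensors, then call model(**kwargs)."""
    keys = batch[0].keys()
    # Every example must provide the same keyword arguments.
    for example in batch:
        if example.keys() != keys:
            raise ValueError("all examples must have the same keys")
    # Stack each key across examples into one batched tensor.
    batched = {key: torch.stack([example[key] for example in batch]) for key in keys}
    with torch.no_grad():
        return model(**batched)


class MaskedSum(nn.Module):
    """Toy model: sums the input values where the mask is 1."""

    def forward(self, values, mask):
        return (values * mask).sum(dim=-1)


examples = [
    {"values": torch.tensor([1.0, 2.0, 3.0]), "mask": torch.tensor([1.0, 1.0, 0.0])},
    {"values": torch.tensor([4.0, 5.0, 6.0]), "mask": torch.tensor([0.0, 1.0, 1.0])},
]
result = run_keyed_batch(MaskedSum(), examples)
```

Stacking per key keeps each example's keyword names intact, so the same helper works for any model whose forward() accepts tensor keyword arguments; models that also need non-tensor arguments would need a separate mechanism.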

--
This message was sent by Atlassian Jira
(v8.20.7#820007)