Anand Inguva created BEAM-14337:
-----------------------------------
Summary: Support **kwargs for PyTorch models.
Key: BEAM-14337
URL: https://issues.apache.org/jira/browse/BEAM-14337
Project: Beam
Issue Type: Sub-task
Components: sdk-py-core
Reporter: Anand Inguva
Some PyTorch models (subclasses of torch.nn.Module) take extra parameters in
their forward() call. These extra parameters can be passed as keyword arguments
(for example, by unpacking a dict with **) or as positional arguments.
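A minimal sketch of the two calling styles, assuming a toy model: TinyEncoder below is hypothetical and only loosely mirrors BertModel, whose forward() also accepts an optional attention_mask. Calling the module with a dict unpacked via ** and calling it with the same tensors positionally both reach the same forward() signature:

```python
from typing import Optional

import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical toy model whose forward() takes an extra, optional
    attention_mask argument, loosely mirroring BertModel's signature."""

    def __init__(self, vocab_size: int = 10, dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        hidden = self.embed(input_ids)  # (batch, seq, dim)
        if attention_mask is not None:
            # Zero out the embeddings of masked (padding) positions.
            hidden = hidden * attention_mask.unsqueeze(-1)
        return hidden.sum(dim=1)  # (batch, dim)


torch.manual_seed(0)
model = TinyEncoder().eval()
input_ids = torch.tensor([[1, 2, 3, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0]])

with torch.no_grad():
    # Extra arguments supplied as a dict, unpacked into keyword arguments...
    out_kwargs = model(**{"input_ids": input_ids, "attention_mask": attention_mask})
    # ...or supplied positionally, in forward()'s parameter order.
    out_positional = model(input_ids, attention_mask)

assert torch.equal(out_kwargs, out_positional)
```

An inference wrapper that only ever calls model(batch) can feed the first argument; supporting **kwargs is what lets the remaining forward() parameters be filled.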
Examples of PyTorch models supported by Hugging Face:
https://huggingface.co/bert-base-uncased
[Some torch models on Hugging Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]
For example, BertModel
(https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel):
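{code:python}
from transformers import BertModel

# Tensor1, Tensor2 and Tensor3 are placeholder torch.Tensor objects.
inputs = {
    "input_ids": Tensor1,
    "attention_mask": Tensor2,
    "token_type_ids": Tensor3,
}
model = BertModel.from_pretrained("bert-base-uncased")  # a subclass of torch.nn.Module
outputs = model(**inputs)  # dict entries are passed to forward() as keyword arguments
{code}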
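For batched inference, one possible approach is to stack each key's tensors across examples and unpack the resulting dict into forward(). The sketch below assumes every example is a dict with the same keys; predict_keyed_batch and MaskedSum are hypothetical names for illustration, not part of the Beam or Hugging Face APIs:

```python
from typing import Dict, List

import torch
from torch import nn


class MaskedSum(nn.Module):
    """Hypothetical toy model: forward() needs two named tensor inputs."""

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return (input_ids * attention_mask).sum(dim=1)


def predict_keyed_batch(model: nn.Module, examples: List[Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Hypothetical helper (not a Beam API): stack per-example dicts into one
    batched dict and pass it to model.forward() as keyword arguments."""
    keys = set(examples[0])
    if any(set(ex) != keys for ex in examples):
        raise ValueError("all examples must have the same keys")
    # One batched tensor per keyword argument, with the batch dimension first.
    batched = {k: torch.stack([ex[k] for ex in examples]) for k in keys}
    with torch.no_grad():
        return model(**batched)


model = MaskedSum()
examples = [
    {"input_ids": torch.tensor([1, 2, 3]), "attention_mask": torch.tensor([1, 1, 0])},
    {"input_ids": torch.tensor([4, 5, 6]), "attention_mask": torch.tensor([1, 1, 1])},
]
result = predict_keyed_batch(model, examples)  # per-example masked sums: 3 and 15
```

Keying the batch by argument name, rather than by position, means the helper does not need to know the order of forward()'s parameters.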
--
This message was sent by Atlassian Jira
(v8.20.7#820007)