Posted to issues@beam.apache.org by "Anand Inguva (Jira)" <ji...@apache.org> on 2022/04/20 16:58:00 UTC

[jira] [Created] (BEAM-14337) Support **kwargs for PyTorch models.

Anand Inguva created BEAM-14337:
-----------------------------------

             Summary: Support **kwargs for PyTorch models.
                 Key: BEAM-14337
                 URL: https://issues.apache.org/jira/browse/BEAM-14337
             Project: Beam
          Issue Type: Sub-task
          Components: sdk-py-core
            Reporter: Anand Inguva


Some PyTorch models that subclass torch.nn.Module take extra parameters in their forward() call. These extra parameters can be passed as a Dict (i.e. as keyword arguments) or as positional arguments.
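A minimal sketch of the two calling styles, assuming a hypothetical toy model (ToyEncoder is illustrative and not from the issue) whose forward() takes attention_mask and token_type_ids in addition to input_ids. The same parameters can be supplied by unpacking a dict with ** or passed positionally.

{code:python}
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical model whose forward() takes extra parameters besides input_ids."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.type_embed = nn.Embedding(2, hidden)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        x = self.embed(input_ids)
        if token_type_ids is not None:
            x = x + self.type_embed(token_type_ids)
        if attention_mask is not None:
            # Zero out the embeddings at padded positions.
            x = x * attention_mask.unsqueeze(-1)
        return x


model = ToyEncoder().eval()
inputs = {
    "input_ids": torch.tensor([[5, 6, 7, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 0]]),
    "token_type_ids": torch.tensor([[0, 0, 0, 0]]),
}

with torch.no_grad():
    # Extra parameters passed as keyword arguments unpacked from a dict.
    out_kwargs = model(**inputs)
    # The same parameters passed positionally, in forward()'s declared order.
    out_positional = model(
        inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"]
    )
{code}

Both calls produce identical outputs; the dict form is convenient when the keys already match forward()'s parameter names.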

Example of a PyTorch model hosted on Hugging Face: https://huggingface.co/bert-base-uncased

[Some torch models on Hugging Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]

E.g.: https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel

{code:python}
import torch
from transformers import BertModel

# BertModel is a subclass of torch.nn.Module.
model = BertModel.from_pretrained("bert-base-uncased")

# Tensor1, Tensor2, Tensor3: illustrative values of shape (batch_size, sequence_length).
inputs = {
    "input_ids": torch.tensor([[101, 7592, 102]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
    "token_type_ids": torch.tensor([[0, 0, 0]]),
}
outputs = model(**inputs)  # extra forward() parameters passed as keyword arguments
{code}
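One way a model handler could support such models is to batch dict-shaped examples key by key and then call model(**batch). The sketch below assumes every example is a dict with the same keys and equal-shaped tensors per key; batch_keyed_tensors, run_inference_with_kwargs and MaskedSum are hypothetical names for illustration, not part of Beam's API.

{code:python}
from typing import Any, Dict, List, Sequence

import torch
from torch import nn


def batch_keyed_tensors(examples: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: stack per-example tensors key by key into one batch."""
    keys = examples[0].keys()
    for ex in examples:
        if ex.keys() != keys:
            raise ValueError("All examples must have the same keys")
    return {key: torch.stack([ex[key] for ex in examples]) for key in keys}


def run_inference_with_kwargs(
    model: nn.Module,
    examples: Sequence[Dict[str, torch.Tensor]],
    **extra_kwargs: Any,
) -> List[torch.Tensor]:
    """Hypothetical handler step: call model(**batch, **extra_kwargs), then unbatch."""
    batch = batch_keyed_tensors(examples)
    with torch.no_grad():
        outputs = model(**batch, **extra_kwargs)
    # Split the batched output back into one prediction per input example.
    return list(torch.unbind(outputs, dim=0))


class MaskedSum(nn.Module):
    """Toy model: sums input_ids over unmasked positions, then scales the result."""

    def forward(self, input_ids, attention_mask=None, scale: float = 1.0):
        if attention_mask is not None:
            input_ids = input_ids * attention_mask
        return input_ids.sum(dim=-1) * scale


examples = [
    {"input_ids": torch.tensor([1, 2, 3]), "attention_mask": torch.tensor([1, 1, 0])},
    {"input_ids": torch.tensor([4, 5, 6]), "attention_mask": torch.tensor([1, 1, 1])},
]
predictions = run_inference_with_kwargs(MaskedSum(), examples, scale=2.0)
{code}

Here per-example keys (input_ids, attention_mask) are batched, while scale is a call-level keyword argument shared by the whole batch; keeping the two separate avoids trying to stack values that are not per-example tensors.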


--
This message was sent by Atlassian Jira
(v8.20.7#820007)