Posted to issues@beam.apache.org by "Anand Inguva (Jira)" <ji...@apache.org> on 2022/04/20 16:58:00 UTC
[jira] [Created] (BEAM-14337) Support **kwargs for PyTorch models.
Anand Inguva created BEAM-14337:
-----------------------------------
Summary: Support **kwargs for PyTorch models.
Key: BEAM-14337
URL: https://issues.apache.org/jira/browse/BEAM-14337
Project: Beam
Issue Type: Sub-task
Components: sdk-py-core
Reporter: Anand Inguva
Some PyTorch models that subclass torch.nn.Module take extra parameters in their forward() method. These extra parameters can be passed as a Dict (unpacked into keyword arguments with **) or as positional arguments.
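A minimal sketch in plain PyTorch of the two calling styles described above. ToyEncoder is a hypothetical model that only mirrors the parameter names of BertModel.forward() (input_ids, attention_mask, token_type_ids). It shows that the same extra parameters can be supplied positionally or as a Dict unpacked with **, and that both calls produce the same output.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical model whose forward() takes extra parameters,
    loosely mirroring BertModel's input_ids / attention_mask / token_type_ids."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.type_embed = nn.Embedding(2, hidden)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        x = self.embed(input_ids)
        if token_type_ids is not None:
            x = x + self.type_embed(token_type_ids)
        if attention_mask is not None:
            # Zero out padded positions before pooling.
            x = x * attention_mask.unsqueeze(-1)
        return x.sum(dim=1)


torch.manual_seed(0)
model = ToyEncoder().eval()
input_ids = torch.tensor([[5, 6, 7, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0]])
token_type_ids = torch.tensor([[0, 0, 1, 1]])

with torch.no_grad():
    # The extra parameters passed positionally...
    out_positional = model(input_ids, attention_mask, token_type_ids)
    # ...or as a Dict unpacked into keyword arguments.
    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }
    out_kwargs = model(**inputs)

print(torch.equal(out_positional, out_kwargs))  # True
```

The keyword form is the one that generalizes: a caller holding a Dict does not need to know the positional order of forward()'s parameters.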
Example of a PyTorch model hosted on Hugging Face: https://huggingface.co/bert-base-uncased
[Some torch models on Hugging Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]
E.g., BertModel: https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel
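{code:python}
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Tensor1, Tensor2 and Tensor3: the tokenizer returns all three as PyTorch tensors.
encoded = tokenizer("Hello, world!", return_tensors="pt")
inputs = {
    "input_ids": encoded["input_ids"],
    "attention_mask": encoded["attention_mask"],
    "token_type_ids": encoded["token_type_ids"],
}
model = BertModel.from_pretrained("bert-base-uncased")  # a subclass of torch.nn.Module
outputs = model(**inputs)  # keys are passed to forward() as keyword arguments
{code}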
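A minimal sketch of what **kwargs support could look like inside a batching inference step, assuming each element is a Dict[str, torch.Tensor] keyed by forward() parameter names. run_inference_with_kwargs and MaskedSum are hypothetical names for illustration only, not Beam's API. Tensors that share a key are stacked into one batch tensor with torch.stack, and the batched Dict is unpacked into the model call.

```python
from typing import Dict, List

import torch
from torch import nn


def run_inference_with_kwargs(
    model: nn.Module, batch: List[Dict[str, torch.Tensor]]
) -> torch.Tensor:
    """Hypothetical helper: batch per-example keyword tensors and call the model.

    Each element of `batch` maps forward() parameter names to tensors of the
    same shape. Tensors sharing a key are stacked along a new batch dimension,
    and the resulting dict is unpacked into forward() as keyword arguments.
    """
    keys = batch[0].keys()
    if any(example.keys() != keys for example in batch):
        raise ValueError("every example must provide the same keyword arguments")
    batched = {key: torch.stack([example[key] for example in batch]) for key in keys}
    with torch.no_grad():
        return model(**batched)


class MaskedSum(nn.Module):
    """Hypothetical stand-in for a model with an extra forward() parameter."""

    def forward(self, input_ids, attention_mask=None):
        values = input_ids.float()
        if attention_mask is not None:
            values = values * attention_mask
        return values.sum(dim=1)


examples = [
    {"input_ids": torch.tensor([1, 2, 3]), "attention_mask": torch.tensor([1, 1, 0])},
    {"input_ids": torch.tensor([4, 5, 6]), "attention_mask": torch.tensor([1, 1, 1])},
]
result = run_inference_with_kwargs(MaskedSum(), examples)
print(result.tolist())  # [3.0, 15.0]
```

Requiring every example to carry the same keys keeps the stacking well defined; a real handler might instead accept optional keys or pad variable-length tensors.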
--
This message was sent by Atlassian Jira
(v8.20.7#820007)