Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/08/05 07:36:14 UTC

[GitHub] [beam] agvdndor commented on issue #22572: [Feature Request]: Allow specification of a custom model inference method for a RunInference ModelHandler

agvdndor commented on issue #22572:
URL: https://github.com/apache/beam/issues/22572#issuecomment-1206141493

   I could imagine three options:
   1. Stick to the current contract and assume that users will subclass the existing handlers to accommodate their model when it falls outside of the contract.
   2. Create a separate GenerationModelHandler. I'm not a fan of this approach. As @yeandy commented, there are a lot of fairly common options out there: `predict_proba`, `apply`, `encode`, `decode`, `generate`... So this might not scale well and could lead to a proliferation of model handlers.
   3. Let the user pass the model_inference_fn during initialization as an optional kwarg.
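
For comparison, option 1 might look roughly like the sketch below. It is only an illustration: `ToyModel`, `BaseSketchHandler`, and `GenerateSketchHandler` are hypothetical stand-ins, not Beam's actual handler classes, and the real `run_inference` signature may differ.

```python
from typing import List


class ToyModel:
    """Hypothetical stand-in for a model with more than one inference method."""

    def __call__(self, batch: List[int]) -> List[int]:
        # Default "forward" behaviour.
        return [x + 1 for x in batch]

    def generate(self, batch: List[int]) -> List[int]:
        # Alternative inference method that the default contract never calls.
        return [x * 10 for x in batch]


class BaseSketchHandler:
    """Hypothetical handler whose contract is to call the model directly."""

    def run_inference(self, batch: List[int], model: ToyModel) -> List[int]:
        return model(batch)


class GenerateSketchHandler(BaseSketchHandler):
    """Option 1: the user subclasses the handler to call `generate` instead."""

    def run_inference(self, batch: List[int], model: ToyModel) -> List[int]:
        return model.generate(batch)


model = ToyModel()
base_result = BaseSketchHandler().run_inference([1, 2, 3], model)
subclass_result = GenerateSketchHandler().run_inference([1, 2, 3], model)
```

The cost of this route is one subclass per inference method per handler type, which is the boilerplate the other options try to avoid.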
   
   Personally, I'd prefer option three. Something like this:
   
   ```
   from transformers import DistilBertConfig, DistilBertForSequenceClassification
   from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
   
   # model_inference_fn is the proposed optional kwarg; it does not exist yet.
   model_handler = PytorchModelHandlerTensor(
       state_dict_path="<path-to-state-dict-file>",
       model_class=DistilBertForSequenceClassification,
       model_params={"config": DistilBertConfig("<path-to-config-file>")},
       model_inference_fn=DistilBertForSequenceClassification.generate)
   ```
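
Inside the handler, option 3 could be as small as storing the function and falling back to a plain model call when none is given. The sketch below assumes a hypothetical `SketchModelHandler` and `ToyModel` so that it runs without Beam or PyTorch; it is not the actual `PytorchModelHandlerTensor` implementation, and the `(model, batch)` calling convention is an assumption.

```python
from typing import Any, Callable, List, Optional


class ToyModel:
    """Hypothetical stand-in for a model with several inference entry points."""

    def __call__(self, batch: List[int]) -> List[int]:
        return [x + 1 for x in batch]

    def generate(self, batch: List[int]) -> List[int]:
        return [x * 10 for x in batch]


def _default_inference_fn(model: Any, batch: List[int]) -> List[int]:
    # Current contract: call the model directly on the batch.
    return model(batch)


class SketchModelHandler:
    """Hypothetical handler that accepts an optional inference function."""

    def __init__(
        self,
        model_inference_fn: Optional[Callable[[Any, List[int]], List[int]]] = None,
    ):
        # Keep today's behaviour unless the user opts in to something else.
        self._inference_fn = model_inference_fn or _default_inference_fn

    def run_inference(self, batch: List[int], model: Any) -> List[int]:
        return self._inference_fn(model, batch)


model = ToyModel()
default_result = SketchModelHandler().run_inference([1, 2, 3], model)

# Passing the unbound method mirrors `DistilBertForSequenceClassification.generate`.
generate_handler = SketchModelHandler(model_inference_fn=ToyModel.generate)
generate_result = generate_handler.run_inference([1, 2, 3], model)
```

Because the default keeps the existing behaviour, handlers constructed without the kwarg would be unaffected.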
   
   What do you think?
    

