Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/08/01 21:44:59 UTC

[GitHub] [beam] fabito commented on issue #22402: [Feature Request]: Ability to detect (and maybe reload ?) corrupted models when using the RunInference API

fabito commented on issue #22402:
URL: https://github.com/apache/beam/issues/22402#issuecomment-1201754823

   Hi @yeandy,
   
   > Can you please provide the values for self.model_name, self.pretrained, self._device?
   
   Here is my custom implementation of `ModelHandler` (nearly a copy of the existing `TorchModelHandler`):
   
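   ```python
   import logging
   from copy import deepcopy
   from typing import Any, Dict, Iterable, Optional, Sequence

   import open_clip
   import torch
   from torchvision import transforms

   from apache_beam.ml.inference.base import ModelHandler, PredictionResult
   # Private helper from Beam's PyTorch handler that moves a tensor to a device.
   from apache_beam.ml.inference.pytorch_inference import _convert_to_device


   class OpenClipTorchModelHandlerTensor(ModelHandler[torch.Tensor, PredictionResult, torch.nn.Module]):

       def __init__(self, model_name: str = 'ViT-B-32-quickgelu', pretrained: str = 'laion400m_e32', device: str = 'CPU'):
           self.pretrained = pretrained
           self.model_name = model_name
           # Use CUDA only if a GPU was requested and one is available; otherwise run on CPU.
           if device == 'GPU' and torch.cuda.is_available():
               self._device = torch.device('cuda')
           else:
               self._device = torch.device('cpu')

           # Validate the model name early and keep a private copy of its config
           # (used below to look up the expected input image size).
           if model_name in open_clip.factory._MODEL_CONFIGS:
               logging.info(f'Loading {model_name} model config.')
               self.model_cfg = deepcopy(open_clip.factory._MODEL_CONFIGS[model_name])
           else:
               raise ValueError(f'Invalid open clip model name: {model_name}')

       def load_model(self) -> torch.nn.Module:
           model = open_clip.create_model(self.model_name, pretrained=self.pretrained, device=self._device)
           # Keep only the image encoder (visual tower); the text tower is not needed here.
           return model.visual

       def run_inference(
               self,
               batch: Sequence[torch.Tensor],
               model: torch.nn.Module,
               inference_args: Optional[Dict[str, Any]] = None
       ) -> Iterable[PredictionResult]:
           # Stack the per-image tensors into one batch tensor on the model's device.
           batched_tensors = torch.stack(batch)
           batched_tensors = _convert_to_device(batched_tensors, self._device)
           with torch.no_grad():
               predictions = model(batched_tensors)
           # Pair each input tensor with its embedding, in order.
           return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

       def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
           """
           Returns:
               The number of bytes of data for a batch of Tensors.
           """
           # Bytes per element times number of elements, summed over the batch.
           return sum(tensor.element_size() * tensor.nelement() for tensor in batch)

       def get_metrics_namespace(self) -> str:
           """
           Returns:
              A namespace for metrics collected by the RunInference transform.
           """
           return 'RunInferenceOpenClipTorch'

       # Not part of the ModelHandler interface: builds the evaluation-time
       # image preprocessing (second argument is_train=False) for this model's input size.
       def preprocess_transform(self) -> transforms.Compose:
           image_size = self.model_cfg['vision_cfg']['image_size']
           return open_clip.transform.image_transform(image_size, False)
   ```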
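A quick worked check of what `get_num_bytes` should return, written in plain Python so it runs without torch. Each tensor is modelled only by its shape and its bytes per element; `TensorSpec`, `batch_num_bytes` and `first_dim_num_bytes` are hypothetical names for this sketch, and the 3x224x224 float32 shape is only illustrative. The total size is bytes per element times element count. An expression like `sum(el.element_size() for tensor in batch for el in tensor)` instead iterates over each tensor's first dimension and adds one element's size per slice, so it agrees with the total only for 1-D tensors.

```python
import math
from typing import Sequence, Tuple

# Hypothetical model of a tensor: (shape, bytes per element), e.g. 4 for float32.
TensorSpec = Tuple[Tuple[int, ...], int]


def batch_num_bytes(batch: Sequence[TensorSpec]) -> int:
    """Full size: bytes per element times number of elements, summed over the batch."""
    return sum(math.prod(shape) * itemsize for shape, itemsize in batch)


def first_dim_num_bytes(batch: Sequence[TensorSpec]) -> int:
    """What summing element_size() over `for el in tensor` amounts to:
    one element's size for each slice along the first dimension."""
    return sum(shape[0] * itemsize for shape, itemsize in batch)


images = [((3, 224, 224), 4), ((3, 224, 224), 4)]
print(batch_num_bytes(images))      # 1204224
print(first_dim_num_bytes(images))  # 24
```

For two 3x224x224 float32 images the full count is 2 * 150528 * 4 = 1,204,224 bytes, while the first-dimension version reports only 24.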
   
   
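For context on when each handler method runs: roughly, RunInference loads the model once (in its DoFn setup, shared within a worker process), groups elements with `BatchElements`, and then passes every batch to `run_inference` together with that same model object. The dependency-free sketch below mimics only that load-once, run-per-batch call pattern with a toy model; `ToyModelHandler`, `run_inference_sketch` and the local `PredictionResult` stand-in are hypothetical, and the real transform also shares the model across threads and records metrics.

```python
from typing import Any, Callable, Iterable, List, NamedTuple, Sequence


class PredictionResult(NamedTuple):
    """Local stand-in: pairs one input example with its model output."""
    example: Any
    inference: Any


class ToyModelHandler:
    """Hypothetical handler exposing load/run methods shaped like a ModelHandler's."""

    def __init__(self) -> None:
        self.load_calls = 0  # counts how often the model is loaded

    def load_model(self) -> Callable[[float], float]:
        self.load_calls += 1
        # A toy "model" that doubles its input.
        return lambda x: 2 * x

    def run_inference(self, batch: Sequence[float],
                      model: Callable[[float], float]) -> Iterable[PredictionResult]:
        predictions = [model(x) for x in batch]
        # Same pairing as the OpenCLIP handler: one result per input, in order.
        return [PredictionResult(x, y) for x, y in zip(batch, predictions)]


def run_inference_sketch(handler: ToyModelHandler, elements: List[float],
                         batch_size: int) -> List[PredictionResult]:
    """Load the model once, then feed fixed-size batches through it."""
    model = handler.load_model()
    results: List[PredictionResult] = []
    for start in range(0, len(elements), batch_size):
        batch = elements[start:start + batch_size]
        results.extend(handler.run_inference(batch, model))
    return results
```

Calling `run_inference_sketch(ToyModelHandler(), [1.0, 2.0, 3.0], batch_size=2)` processes two batches but calls `load_model` only once, so whatever object `load_model` returned is reused for every later batch.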

