Posted to commits@beam.apache.org by bh...@apache.org on 2022/06/14 15:51:11 UTC

[beam] branch master updated: Split PytorchModelHandler into PytorchModelHandlerTensor and PytorchModelHandlerKeyedTensor (#21810)

This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a94feb146b7 Split PytorchModelHandler into PytorchModelHandlerTensor and PytorchModelHandlerKeyedTensor (#21810)
a94feb146b7 is described below

commit a94feb146b72e3d7c2589f5f26f2f699c9d2989e
Author: Andy Ye <an...@gmail.com>
AuthorDate: Tue Jun 14 11:51:00 2022 -0400

    Split PytorchModelHandler into PytorchModelHandlerTensor and PytorchModelHandlerKeyedTensor (#21810)
    
    * Start to split the Pytorch handlers
    
    * Remove old PytorchModelHandler
---
 .../apache_beam/ml/inference/pytorch_inference.py  | 162 ++++++++++++++-------
 .../ml/inference/pytorch_inference_test.py         |  23 ++-
 .../ml/inference/sklearn_inference_test.py         |   4 +-
 3 files changed, 130 insertions(+), 59 deletions(-)
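
For readers skimming the diff below: after this change, callers pick the
handler that matches the element type of their input PCollection.
PytorchModelHandlerTensor accepts batches of plain torch.Tensor elements,
while PytorchModelHandlerKeyedTensor accepts batches of
Dict[str, torch.Tensor] and calls the model with the dict keys as keyword
arguments. A minimal usage sketch follows; the model class, dimensions, and
state dict path are hypothetical placeholders, not part of this commit.

    import apache_beam as beam
    import torch
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

    # Hypothetical model class; any torch.nn.Module constructed from
    # model_params works the same way.
    class LinearRegression(torch.nn.Module):
      def __init__(self, input_dim=1, output_dim=1):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

      def forward(self, x):
        return self.linear(x)

    model_handler = PytorchModelHandlerTensor(
        state_dict_path='/tmp/my_state_dict_path',  # placeholder path
        model_class=LinearRegression,
        model_params={'input_dim': 1, 'output_dim': 1})

    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | beam.Create([torch.Tensor([1.0]), torch.Tensor([2.0])])
          | RunInference(model_handler)
          | beam.Map(print))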

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index d8ab31b8b70..959bce4778e 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -23,7 +23,6 @@ from typing import Callable
 from typing import Dict
 from typing import Iterable
 from typing import Sequence
-from typing import Union
 
 import torch
 from apache_beam.io.filesystems import FileSystems
@@ -31,14 +30,32 @@ from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
 
 
-class PytorchModelHandler(ModelHandler[torch.Tensor,
-                                       PredictionResult,
-                                       torch.nn.Module]):
-  """ Implementation of the ModelHandler interface for PyTorch.
+def _load_model(
+    model_class: torch.nn.Module, state_dict_path, device, **model_params):
+  model = model_class(**model_params)
+  model.to(device)
+  file = FileSystems.open(state_dict_path, 'rb')
+  model.load_state_dict(torch.load(file))
+  model.eval()
+  return model
 
-      NOTE: This API and its implementation are under development and
-      do not provide backward compatibility guarantees.
+
+def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
   """
+  Converts examples to a style matching the given device.
+
+  Note: A user may pass in device='GPU', but if a GPU is not detected in the
+  environment, the examples must be converted back to CPU.
+  """
+  if examples.device != device:
+    examples = examples.to(device)
+  return examples
+
+
+class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
+                                             PredictionResult,
+                                             torch.nn.Module]):
+  """ Implementation of the ModelHandler interface for PyTorch."""
   def __init__(
       self,
       state_dict_path: str,
@@ -46,7 +63,7 @@ class PytorchModelHandler(ModelHandler[torch.Tensor,
       model_params: Dict[str, Any],
       device: str = 'CPU'):
     """
-    Initializes a PytorchModelHandler
+    Initializes a PytorchModelHandlerTensor
     :param state_dict_path: path to the saved dictionary of the model state.
     :param model_class: class of the Pytorch model that defines the model
     structure.
@@ -67,67 +84,114 @@ class PytorchModelHandler(ModelHandler[torch.Tensor,
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
-    model = self._model_class(**self._model_params)
-    model.to(self._device)
-    file = FileSystems.open(self._state_dict_path, 'rb')
-    model.load_state_dict(torch.load(file))
-    model.eval()
-    return model
-
-  def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
+    return _load_model(
+        self._model_class,
+        self._state_dict_path,
+        self._device,
+        **self._model_params)
+
+  def run_inference(
+      self, batch: Sequence[torch.Tensor], model: torch.nn.Module,
+      **kwargs) -> Iterable[PredictionResult]:
+    """
+    Runs inferences on a batch of Tensors and returns an Iterable of
+    Tensor Predictions.
+
+    This method stacks the list of Tensors in a vectorized format to optimize
+    the inference call.
+    """
+    prediction_params = kwargs.get('prediction_params', {})
+    batched_tensors = torch.stack(batch)
+    batched_tensors = _convert_to_device(batched_tensors, self._device)
+    predictions = model(batched_tensors, **prediction_params)
+    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+  def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
+    """Returns the number of bytes of data for a batch of Tensors."""
+    return sum((el.element_size() for tensor in batch for el in tensor))
+
+  def get_metrics_namespace(self) -> str:
+    """
+    Returns a namespace for metrics collected by the RunInference transform.
+    """
+    return 'RunInferencePytorch'
+
+
+class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
+                                                  PredictionResult,
+                                                  torch.nn.Module]):
+  """ Implementation of the ModelHandler interface for PyTorch.
+
+      NOTE: This API and its implementation are under development and
+      do not provide backward compatibility guarantees.
+  """
+  def __init__(
+      self,
+      state_dict_path: str,
+      model_class: Callable[..., torch.nn.Module],
+      model_params: Dict[str, Any],
+      device: str = 'CPU'):
     """
-    Converts samples to a style matching given device.
+    Initializes a PytorchModelHandlerKeyedTensor
+    :param state_dict_path: path to the saved dictionary of the model state.
+    :param model_class: class of the Pytorch model that defines the model
+    structure.
+    :param device: the device on which you wish to run the model. If
+    ``device = GPU`` then a GPU device will be used if it is available.
+    Otherwise, it will be CPU.
 
-    Note: A user may pass in device='GPU' but if GPU is not detected in the
-    environment it must be converted back to CPU.
+    See https://pytorch.org/tutorials/beginner/saving_loading_models.html
+    for details
     """
-    if examples.device != self._device:
-      examples = examples.to(self._device)
-    return examples
+    self._state_dict_path = state_dict_path
+    if device == 'GPU' and torch.cuda.is_available():
+      self._device = torch.device('cuda')
+    else:
+      self._device = torch.device('cpu')
+    self._model_class = model_class
+    self._model_params = model_params
+
+  def load_model(self) -> torch.nn.Module:
+    """Loads and initializes a Pytorch model for processing."""
+    return _load_model(
+        self._model_class,
+        self._state_dict_path,
+        self._device,
+        **self._model_params)
 
   def run_inference(
       self,
-      batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+      batch: Sequence[Dict[str, torch.Tensor]],
       model: torch.nn.Module,
       **kwargs) -> Iterable[PredictionResult]:
     """
-    Runs inferences on a batch of Tensors and returns an Iterable of
+    Runs inferences on a batch of Keyed Tensors and returns an Iterable of
     Tensor Predictions.
 
-    This method stacks the list of Tensors in a vectorized format to optimize
-    the inference call.
+    For the same key across all examples, this will stack all Tensor values
+    in a vectorized format to optimize the inference call.
     """
     prediction_params = kwargs.get('prediction_params', {})
 
     # If elements in `batch` are provided as dictionaries from key to Tensors,
     # then iterate through the batch list and group Tensors by key
-    if isinstance(batch[0], dict):
-      key_to_tensor_list = defaultdict(list)
-      for example in batch:
-        for key, tensor in example.items():
-          key_to_tensor_list[key].append(tensor)
-      key_to_batched_tensors = {}
-      for key in key_to_tensor_list:
-        batched_tensors = torch.stack(key_to_tensor_list[key])
-        batched_tensors = self._convert_to_device(batched_tensors)
-        key_to_batched_tensors[key] = batched_tensors
-      predictions = model(**key_to_batched_tensors, **prediction_params)
-    else:
-      # If elements in `batch` are provided as Tensors, then do a regular stack
-      batched_tensors = torch.stack(batch)
-      batched_tensors = self._convert_to_device(batched_tensors)
-      predictions = model(batched_tensors, **prediction_params)
+    key_to_tensor_list = defaultdict(list)
+    for example in batch:
+      for key, tensor in example.items():
+        key_to_tensor_list[key].append(tensor)
+    key_to_batched_tensors = {}
+    for key in key_to_tensor_list:
+      batched_tensors = torch.stack(key_to_tensor_list[key])
+      batched_tensors = _convert_to_device(batched_tensors, self._device)
+      key_to_batched_tensors[key] = batched_tensors
+    predictions = model(**key_to_batched_tensors, **prediction_params)
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
-    """Returns the number of bytes of data for a batch of Tensors."""
+    """Returns the number of bytes of data for a batch of Dict of Tensors."""
     # If elements in `batch` are provided as a dictionaries from key to Tensors
-    if isinstance(batch[0], dict):
-      return sum(
-          (el.element_size() for tensor in batch for el in tensor.values()))
-    else:
-      # If elements in `batch` are provided as Tensors
-      return sum((el.element_size() for tensor in batch for el in tensor))
+    return sum(
+        (el.element_size() for tensor in batch for el in tensor.values()))
 
   def get_metrics_namespace(self) -> str:
     """
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
index ad51a4e77f7..d852dd72bb7 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -37,7 +37,8 @@ try:
   import torch
   from apache_beam.ml.inference.base import PredictionResult
   from apache_beam.ml.inference.base import RunInference
-  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
+  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
+  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
 except ImportError:
   raise unittest.SkipTest('PyTorch dependencies are not installed')
 
@@ -90,7 +91,13 @@ KWARGS_TORCH_PREDICTIONS = [
 ]
 
 
-class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandler):
+class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandlerTensor):
+  def __init__(self, device):
+    self._device = device
+
+
+class TestPytorchModelHandlerKeyedTensorForInferenceOnly(
+    PytorchModelHandlerKeyedTensor):
   def __init__(self, device):
     self._device = device
 
@@ -209,7 +216,7 @@ class PytorchRunInferenceTest(unittest.TestCase):
                      ('linear.bias', torch.Tensor([0.5]))]))
     model.eval()
 
-    inference_runner = TestPytorchModelHandlerForInferenceOnly(
+    inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
         torch.device('cpu'))
     predictions = inference_runner.run_inference(KWARGS_TORCH_EXAMPLES, model)
     for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS):
@@ -234,7 +241,7 @@ class PytorchRunInferenceTest(unittest.TestCase):
                      ('linear.bias', torch.Tensor([0.5]))]))
     model.eval()
 
-    inference_runner = TestPytorchModelHandlerForInferenceOnly(
+    inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
         torch.device('cpu'))
     predictions = inference_runner.run_inference(
         batch=KWARGS_TORCH_EXAMPLES,
@@ -274,7 +281,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
       path = os.path.join(self.tmpdir, 'my_state_dict_path')
       torch.save(state_dict, path)
 
-      model_handler = PytorchModelHandler(
+      model_handler = PytorchModelHandlerTensor(
           state_dict_path=path,
           model_class=PytorchLinearRegression,
           model_params={
@@ -301,7 +308,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
       path = os.path.join(self.tmpdir, 'my_state_dict_path')
       torch.save(state_dict, path)
 
-      model_handler = PytorchModelHandler(
+      model_handler = PytorchModelHandlerKeyedTensor(
           state_dict_path=path,
           model_class=PytorchLinearRegressionKwargsPredictionParams,
           model_params={
@@ -334,7 +341,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
 
       gs_pth = 'gs://apache-beam-ml/models/' \
           'pytorch_lin_reg_model_2x+0.5_state_dict.pth'
-      model_handler = PytorchModelHandler(
+      model_handler = PytorchModelHandlerTensor(
           state_dict_path=gs_pth,
           model_class=PytorchLinearRegression,
           model_params={
@@ -357,7 +364,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
         path = os.path.join(self.tmpdir, 'my_state_dict_path')
         torch.save(state_dict, path)
 
-        model_handler = PytorchModelHandler(
+        model_handler = PytorchModelHandlerTensor(
             state_dict_path=path,
             model_class=PytorchLinearRegression,
             model_params={
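
The test updates above lean on a small pattern worth noting: each
*ForInferenceOnly subclass overrides __init__ to set only _device, so
run_inference can be exercised against an in-memory model without a state
dict on disk. A sketch of the same pattern; the linear model and its weights
are hypothetical, not taken from this commit:

    import torch

    from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

    class InferenceOnlyHandler(PytorchModelHandlerTensor):
      # Skip the parent __init__: no state_dict_path is needed to call
      # run_inference directly.
      def __init__(self, device):
        self._device = device

    # In-memory model with weights assigned directly rather than loaded.
    model = torch.nn.Linear(1, 1)
    model.load_state_dict({
        'weight': torch.Tensor([[2.0]]),
        'bias': torch.Tensor([0.5]),
    })
    model.eval()

    handler = InferenceOnlyHandler(torch.device('cpu'))
    for result in handler.run_inference([torch.Tensor([1.0])], model):
      print(result.example, result.inference)  # 2.0 * 1.0 + 0.5 = 2.5
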
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
index 2c63de25f99..ecd81d204d6 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -225,9 +225,9 @@ class SkLearnRunInferenceTest(unittest.TestCase):
     with self.assertRaisesRegex(AssertionError,
                                 'Unsupported serialization type'):
       with tempfile.NamedTemporaryFile() as file:
-        model_loader = SklearnModelHandlerNumpy(
+        model_handler = SklearnModelHandlerNumpy(
             model_uri=file.name, model_file_type=None)
-        model_loader.load_model()
+        model_handler.load_model()
 
   @unittest.skipIf(platform.system() == 'Windows', 'BEAM-14359')
   def test_pipeline_pandas(self):
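
The sklearn change is a rename only: the local variable model_loader becomes
model_handler, matching the handler terminology used across the inference
APIs. For context, the surrounding test asserts that an unsupported
serialization type fails at load time; a hedged sketch of the supported path,
with a placeholder model path:

    from apache_beam.ml.inference.sklearn_inference import ModelFileType
    from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

    # With a supported model_file_type, load_model() deserializes the model.
    # Passing model_file_type=None, as the test does, raises an AssertionError
    # ('Unsupported serialization type') instead.
    model_handler = SklearnModelHandlerNumpy(
        model_uri='/tmp/my_model.pkl',  # placeholder path
        model_file_type=ModelFileType.PICKLE)
    model = model_handler.load_model()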