Posted to commits@beam.apache.org by bh...@apache.org on 2022/06/14 15:51:11 UTC
[beam] branch master updated: Split PytorchModelHandler into PytorchModelHandlerTensor and PytorchModelHandlerKeyedTensor (#21810)
This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new a94feb146b7 Split PytorchModelHandler into PytorchModelHandlerTensor and PytorchModelHandlerKeyedTensor (#21810)
a94feb146b7 is described below
commit a94feb146b72e3d7c2589f5f26f2f699c9d2989e
Author: Andy Ye <an...@gmail.com>
AuthorDate: Tue Jun 14 11:51:00 2022 -0400
Split PytorchModelHandler into PytorchModelHandlerTensor and PytorchModelHandlerKeyedTensor (#21810)
* Start splitting the Pytorch handlers
* Remove old PytorchModelHandler
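For orientation, here is a minimal usage sketch of the split API. It is not taken from this commit; the model class, parameters, and state dict path below are hypothetical, invented for illustration.

    import apache_beam as beam
    import torch
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

    # Hypothetical model class, for illustration only.
    class LinearRegression(torch.nn.Module):
        def __init__(self, input_dim=1, output_dim=1):
            super().__init__()
            self.linear = torch.nn.Linear(input_dim, output_dim)

        def forward(self, x):
            return self.linear(x)

    # Plain Tensor examples go through PytorchModelHandlerTensor; examples
    # that are Dicts of Tensors (keyword arguments to forward()) go through
    # PytorchModelHandlerKeyedTensor instead.
    model_handler = PytorchModelHandlerTensor(
        state_dict_path='/tmp/my_state_dict.pth',  # hypothetical path
        model_class=LinearRegression,
        model_params={'input_dim': 1, 'output_dim': 1})

    with beam.Pipeline() as pipeline:
        _ = (
            pipeline
            | beam.Create([torch.Tensor([1.0]), torch.Tensor([2.0])])
            | RunInference(model_handler))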
---
.../apache_beam/ml/inference/pytorch_inference.py | 162 ++++++++++++++-------
.../ml/inference/pytorch_inference_test.py | 23 ++-
.../ml/inference/sklearn_inference_test.py | 4 +-
3 files changed, 130 insertions(+), 59 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index d8ab31b8b70..959bce4778e 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -23,7 +23,6 @@ from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Sequence
-from typing import Union
import torch
from apache_beam.io.filesystems import FileSystems
@@ -31,14 +30,32 @@ from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
-class PytorchModelHandler(ModelHandler[torch.Tensor,
- PredictionResult,
- torch.nn.Module]):
- """ Implementation of the ModelHandler interface for PyTorch.
+def _load_model(
+ model_class: torch.nn.Module, state_dict_path, device, **model_params):
+ model = model_class(**model_params)
+ model.to(device)
+ file = FileSystems.open(state_dict_path, 'rb')
+ model.load_state_dict(torch.load(file))
+ model.eval()
+ return model
- NOTE: This API and its implementation are under development and
- do not provide backward compatibility guarantees.
+
+def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
"""
+  Converts samples to a format matching the given device.
+
+  Note: A user may pass in device='GPU', but if a GPU is not detected in the
+  environment, examples must be converted back to CPU.
+ """
+ if examples.device != device:
+ examples = examples.to(device)
+ return examples
+
+
+class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
+ PredictionResult,
+ torch.nn.Module]):
+ """ Implementation of the ModelHandler interface for PyTorch."""
def __init__(
self,
state_dict_path: str,
@@ -46,7 +63,7 @@ class PytorchModelHandler(ModelHandler[torch.Tensor,
model_params: Dict[str, Any],
device: str = 'CPU'):
"""
- Initializes a PytorchModelHandler
+ Initializes a PytorchModelHandlerTensor
:param state_dict_path: path to the saved dictionary of the model state.
:param model_class: class of the Pytorch model that defines the model
structure.
@@ -67,67 +84,114 @@ class PytorchModelHandler(ModelHandler[torch.Tensor,
def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
- model = self._model_class(**self._model_params)
- model.to(self._device)
- file = FileSystems.open(self._state_dict_path, 'rb')
- model.load_state_dict(torch.load(file))
- model.eval()
- return model
-
- def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
+ return _load_model(
+ self._model_class,
+ self._state_dict_path,
+ self._device,
+ **self._model_params)
+
+ def run_inference(
+ self, batch: Sequence[torch.Tensor], model: torch.nn.Module,
+ **kwargs) -> Iterable[PredictionResult]:
+ """
+ Runs inferences on a batch of Tensors and returns an Iterable of
+ Tensor Predictions.
+
+ This method stacks the list of Tensors in a vectorized format to optimize
+ the inference call.
+ """
+ prediction_params = kwargs.get('prediction_params', {})
+ batched_tensors = torch.stack(batch)
+ batched_tensors = _convert_to_device(batched_tensors, self._device)
+ predictions = model(batched_tensors, **prediction_params)
+ return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+ def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
+ """Returns the number of bytes of data for a batch of Tensors."""
+ return sum((el.element_size() for tensor in batch for el in tensor))
+
+ def get_metrics_namespace(self) -> str:
+ """
+ Returns a namespace for metrics collected by the RunInference transform.
+ """
+ return 'RunInferencePytorch'
+
+
+class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
+ PredictionResult,
+ torch.nn.Module]):
+ """ Implementation of the ModelHandler interface for PyTorch.
+
+ NOTE: This API and its implementation are under development and
+ do not provide backward compatibility guarantees.
+ """
+ def __init__(
+ self,
+ state_dict_path: str,
+ model_class: Callable[..., torch.nn.Module],
+ model_params: Dict[str, Any],
+ device: str = 'CPU'):
"""
- Converts samples to a style matching given device.
+ Initializes a PytorchModelHandlerKeyedTensor
+ :param state_dict_path: path to the saved dictionary of the model state.
+ :param model_class: class of the Pytorch model that defines the model
+ structure.
+ :param device: the device on which you wish to run the model. If
+ ``device = GPU`` then a GPU device will be used if it is available.
+ Otherwise, it will be CPU.
- Note: A user may pass in device='GPU' but if GPU is not detected in the
- environment it must be converted back to CPU.
+ See https://pytorch.org/tutorials/beginner/saving_loading_models.html
+ for details
"""
- if examples.device != self._device:
- examples = examples.to(self._device)
- return examples
+ self._state_dict_path = state_dict_path
+ if device == 'GPU' and torch.cuda.is_available():
+ self._device = torch.device('cuda')
+ else:
+ self._device = torch.device('cpu')
+ self._model_class = model_class
+ self._model_params = model_params
+
+ def load_model(self) -> torch.nn.Module:
+ """Loads and initializes a Pytorch model for processing."""
+ return _load_model(
+ self._model_class,
+ self._state_dict_path,
+ self._device,
+ **self._model_params)
def run_inference(
self,
- batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+ batch: Sequence[Dict[str, torch.Tensor]],
model: torch.nn.Module,
**kwargs) -> Iterable[PredictionResult]:
"""
- Runs inferences on a batch of Tensors and returns an Iterable of
+ Runs inferences on a batch of Keyed Tensors and returns an Iterable of
Tensor Predictions.
- This method stacks the list of Tensors in a vectorized format to optimize
- the inference call.
+  For the same key across all examples, this will stack all Tensor values
+ in a vectorized format to optimize the inference call.
"""
prediction_params = kwargs.get('prediction_params', {})
# If elements in `batch` are provided as dictionaries from key to Tensors,
# then iterate through the batch list, and group Tensors to the same key
- if isinstance(batch[0], dict):
- key_to_tensor_list = defaultdict(list)
- for example in batch:
- for key, tensor in example.items():
- key_to_tensor_list[key].append(tensor)
- key_to_batched_tensors = {}
- for key in key_to_tensor_list:
- batched_tensors = torch.stack(key_to_tensor_list[key])
- batched_tensors = self._convert_to_device(batched_tensors)
- key_to_batched_tensors[key] = batched_tensors
- predictions = model(**key_to_batched_tensors, **prediction_params)
- else:
- # If elements in `batch` are provided as Tensors, then do a regular stack
- batched_tensors = torch.stack(batch)
- batched_tensors = self._convert_to_device(batched_tensors)
- predictions = model(batched_tensors, **prediction_params)
+ key_to_tensor_list = defaultdict(list)
+ for example in batch:
+ for key, tensor in example.items():
+ key_to_tensor_list[key].append(tensor)
+ key_to_batched_tensors = {}
+ for key in key_to_tensor_list:
+ batched_tensors = torch.stack(key_to_tensor_list[key])
+ batched_tensors = _convert_to_device(batched_tensors, self._device)
+ key_to_batched_tensors[key] = batched_tensors
+ predictions = model(**key_to_batched_tensors, **prediction_params)
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
- """Returns the number of bytes of data for a batch of Tensors."""
+ """Returns the number of bytes of data for a batch of Dict of Tensors."""
# If elements in `batch` are provided as dictionaries from key to Tensors
- if isinstance(batch[0], dict):
- return sum(
- (el.element_size() for tensor in batch for el in tensor.values()))
- else:
- # If elements in `batch` are provided as Tensors
- return sum((el.element_size() for tensor in batch for el in tensor))
+ return sum(
+ (el.element_size() for tensor in batch for el in tensor.values()))
def get_metrics_namespace(self) -> str:
"""
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
index ad51a4e77f7..d852dd72bb7 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -37,7 +37,8 @@ try:
import torch
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
- from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
+ from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
+ from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
except ImportError:
raise unittest.SkipTest('PyTorch dependencies are not installed')
@@ -90,7 +91,13 @@ KWARGS_TORCH_PREDICTIONS = [
]
-class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandler):
+class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandlerTensor):
+ def __init__(self, device):
+ self._device = device
+
+
+class TestPytorchModelHandlerKeyedTensorForInferenceOnly(
+ PytorchModelHandlerKeyedTensor):
def __init__(self, device):
self._device = device
@@ -209,7 +216,7 @@ class PytorchRunInferenceTest(unittest.TestCase):
('linear.bias', torch.Tensor([0.5]))]))
model.eval()
- inference_runner = TestPytorchModelHandlerForInferenceOnly(
+ inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
torch.device('cpu'))
predictions = inference_runner.run_inference(KWARGS_TORCH_EXAMPLES, model)
for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS):
@@ -234,7 +241,7 @@ class PytorchRunInferenceTest(unittest.TestCase):
('linear.bias', torch.Tensor([0.5]))]))
model.eval()
- inference_runner = TestPytorchModelHandlerForInferenceOnly(
+ inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
torch.device('cpu'))
predictions = inference_runner.run_inference(
batch=KWARGS_TORCH_EXAMPLES,
@@ -274,7 +281,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
path = os.path.join(self.tmpdir, 'my_state_dict_path')
torch.save(state_dict, path)
- model_handler = PytorchModelHandler(
+ model_handler = PytorchModelHandlerTensor(
state_dict_path=path,
model_class=PytorchLinearRegression,
model_params={
@@ -301,7 +308,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
path = os.path.join(self.tmpdir, 'my_state_dict_path')
torch.save(state_dict, path)
- model_handler = PytorchModelHandler(
+ model_handler = PytorchModelHandlerKeyedTensor(
state_dict_path=path,
model_class=PytorchLinearRegressionKwargsPredictionParams,
model_params={
@@ -334,7 +341,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
gs_pth = 'gs://apache-beam-ml/models/' \
'pytorch_lin_reg_model_2x+0.5_state_dict.pth'
- model_handler = PytorchModelHandler(
+ model_handler = PytorchModelHandlerTensor(
state_dict_path=gs_pth,
model_class=PytorchLinearRegression,
model_params={
@@ -357,7 +364,7 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
path = os.path.join(self.tmpdir, 'my_state_dict_path')
torch.save(state_dict, path)
- model_handler = PytorchModelHandler(
+ model_handler = PytorchModelHandlerTensor(
state_dict_path=path,
model_class=PytorchLinearRegression,
model_params={
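The pipeline tests above all follow the same setup: hand-build a state dict with known weights, save it to a file, and point the handler at that path. A minimal sketch of that setup (the path is illustrative):

    from collections import OrderedDict
    import torch

    # A hand-built state dict with known weights makes predictions
    # deterministic (y = 2x + 0.5 for a 1-in, 1-out linear layer).
    state_dict = OrderedDict([
        ('linear.weight', torch.Tensor([[2.0]])),
        ('linear.bias', torch.Tensor([0.5])),
    ])
    torch.save(state_dict, '/tmp/my_state_dict_path')  # illustrative path

    # The handlers' load_model() then reverses this, in effect:
    #   model = model_class(**model_params)
    #   model.load_state_dict(torch.load(file))
    #   model.eval()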
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
index 2c63de25f99..ecd81d204d6 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -225,9 +225,9 @@ class SkLearnRunInferenceTest(unittest.TestCase):
with self.assertRaisesRegex(AssertionError,
'Unsupported serialization type'):
with tempfile.NamedTemporaryFile() as file:
- model_loader = SklearnModelHandlerNumpy(
+ model_handler = SklearnModelHandlerNumpy(
model_uri=file.name, model_file_type=None)
- model_loader.load_model()
+ model_handler.load_model()
@unittest.skipIf(platform.system() == 'Windows', 'BEAM-14359')
def test_pipeline_pandas(self):