You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/04/05 22:17:39 UTC

[GitHub] [beam] ryanthompson591 commented on a diff in pull request #17196: [BEAM-13984] Implement RunInference for PyTorch

ryanthompson591 commented on code in PR #17196:
URL: https://github.com/apache/beam/pull/17196#discussion_r843170687


##########
sdks/python/apache_beam/ml/inference/pytorch_impl.py:
##########
@@ -0,0 +1,119 @@
+#

Review Comment:
   Let's rename this to something other than impl.
   
   This will actually be part of the API now.  Some suggestions:
   pytorch
   pytorch_model_loader
   pytorch_inference
   
   When it is called, it might look like this:
   
   inference_runner = beam.ml.inference.pytorch.PyTorchInferenceRunner()
   beam.ml.inference.RunInference(inference_runner)



##########
sdks/python/setup.py:
##########
@@ -234,6 +234,10 @@ def get_version():
     'azure-core >=1.7.0',
 ]
 
+ML_REQUIREMENTS = [
+    'torch >= 1.10.2'

Review Comment:
   Does this mean all Python SDK installations will require PyTorch going forward? Is this a heavy requirement?
   
   If so, let's make sure it's fine to require this for everyone.



##########
sdks/python/apache_beam/ml/inference/pytorch_impl.py:
##########
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import pickle
+import torch
+from torch import nn
+
+from apache_beam.ml.inference.base import InferenceRunner
+from apache_beam.ml.inference.base import ModelLoader
+
+
+class PytorchInferenceRunner(InferenceRunner):
+  """
+  Implements Pytorch inference method
+  """
+  def __init__(self, input_dim: int, device: torch.device):
+    self._input_dim = input_dim
+    self._device = device
+
+  def run_inference(
+      self, batch: List[Union[np.ndarray, torch.Tensor]],
+      model: nn.Module) -> Iterable[torch.Tensor]:
+    """
+    Runs inferences on a batch of examples and returns an Iterable of
+    Predictions."""
+    if batch:
+      if isinstance(batch[0], np.ndarray):
+        batch = torch.Tensor(batch)
+      elif isinstance(batch[0], torch.Tensor):
+        batch = torch.stack(batch)
+      else:
+        raise ValueError("PCollection must be an numpy array or a torch Tensor")
+
+    if batch.device != self._device:
+      batch = batch.to(self._device)
+    return model(batch)
+
+  def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
+    """Returns the number of bytes of data for a batch."""
+    total_size = 0
+    for el in batch:
+      if isinstance(el, np.ndarray):
+        total_size += el.itemsize
+      elif isinstance(el, torch.Tensor):
+        total_size += el.element_size()
+      else:
+        total_size += len(pickle.dumps(el))

Review Comment:
   This would be an unsupported type in this case, right?



##########
sdks/python/apache_beam/ml/inference/pytorch_impl.py:
##########
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import pickle
+import torch
+from torch import nn
+
+from apache_beam.ml.inference.base import InferenceRunner
+from apache_beam.ml.inference.base import ModelLoader
+
+
+class PytorchInferenceRunner(InferenceRunner):
+  """
+  Implements Pytorch inference method
+  """
+  def __init__(self, input_dim: int, device: torch.device):
+    self._input_dim = input_dim
+    self._device = device
+
+  def run_inference(
+      self, batch: List[Union[np.ndarray, torch.Tensor]],
+      model: nn.Module) -> Iterable[torch.Tensor]:
+    """
+    Runs inferences on a batch of examples and returns an Iterable of
+    Predictions."""
+    if batch:
+      if isinstance(batch[0], np.ndarray):
+        batch = torch.Tensor(batch)
+      elif isinstance(batch[0], torch.Tensor):
+        batch = torch.stack(batch)
+      else:
+        raise ValueError("PCollection must be an numpy array or a torch Tensor")
+
+    if batch.device != self._device:
+      batch = batch.to(self._device)
+    return model(batch)
+
+  def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
+    """Returns the number of bytes of data for a batch."""
+    total_size = 0
+    for el in batch:
+      if isinstance(el, np.ndarray):
+        total_size += el.itemsize
+      elif isinstance(el, torch.Tensor):
+        total_size += el.element_size()
+      else:
+        total_size += len(pickle.dumps(el))
+    return total_size
+
+  def get_metrics_namespace(self) -> str:
+    """
+    Returns a namespace for metrics collected by the RunInference transform.
+    """
+    return 'RunInferencePytorch'
+
+
+class PytorchModelLoader(ModelLoader):
+  """Loads a Pytorch Model."""
+  def __init__(
+      self,
+      input_dim: int,

Review Comment:
   Will this be needed if the input is a tensor?  I'm trying to figure out where this number is even used.



##########
sdks/python/apache_beam/ml/inference/pytorch_impl_test.py:
##########
@@ -0,0 +1,221 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import os
+import shutil
+import tempfile
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+import pytest
+import torch
+
+import apache_beam as beam
+from apache_beam.ml.inference import base
+from apache_beam.ml.inference.pytorch_impl import PytorchModelLoader
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+
+class PytorchLinearRegression(torch.nn.Module):
+  def __init__(self, inputSize, outputSize):
+    super().__init__()
+    self.linear = torch.nn.Linear(inputSize, outputSize)
+
+  def forward(self, x):
+    out = self.linear(x)
+    return out
+
+
+class PytorchRunInferenceTest(unittest.TestCase):
+  def setUp(self):
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+  def test_simple_single_tensor_feature(self):
+    with TestPipeline() as pipeline:
+      examples = torch.from_numpy(
+          np.array([1, 5, 3, 10], dtype="float32").reshape(-1, 1))
+      expected = torch.Tensor([example * 2.0 + 0.5 for example in examples])
+
+      state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
+                                ('linear.bias', torch.Tensor([0.5]))])
+      path = os.path.join(self.tmpdir, 'my_state_dict_path')
+      torch.save(state_dict, path)
+
+      input_dim = 1
+      output_dim = 1
+
+      model_loader = PytorchModelLoader(
+          input_dim=input_dim,
+          state_dict_path=path,
+          model_class=PytorchLinearRegression(input_dim, output_dim))
+
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(model_loader)
+      assert_that(actual, equal_to(expected))
+
+  def test_invalid_input_type(self):
+    with self.assertRaisesRegex(
+        ValueError, "PCollection must be an numpy array or a torch Tensor"):
+      with TestPipeline() as pipeline:
+        examples = [1, 5, 3, 10]
+
+        state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
+                                  ('linear.bias', torch.Tensor([0.5]))])
+        path = os.path.join(self.tmpdir, 'my_state_dict_path')
+        torch.save(state_dict, path)
+
+        input_dim = 1
+        output_dim = 1
+
+        model_loader = PytorchModelLoader(
+            input_dim=input_dim,
+            state_dict_path=path,
+            model_class=PytorchLinearRegression(input_dim, output_dim))
+
+        pcoll = pipeline | 'start' >> beam.Create(examples)
+        # pylint: disable=expression-not-assigned
+        pcoll | base.RunInference(model_loader)

Review Comment:
   Just curious. How long does this test take to run?
   
   This seems like a larger integration test.  I wonder if there's a simpler unit test.
   
   Just to be clear, I like the test, I'm just curious if it's heavy or not.



##########
sdks/python/apache_beam/ml/inference/pytorch_impl.py:
##########
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import pickle
+import torch
+from torch import nn
+
+from apache_beam.ml.inference.base import InferenceRunner
+from apache_beam.ml.inference.base import ModelLoader
+
+
+class PytorchInferenceRunner(InferenceRunner):
+  """
+  Implements Pytorch inference method
+  """
+  def __init__(self, input_dim: int, device: torch.device):
+    self._input_dim = input_dim
+    self._device = device
+
+  def run_inference(
+      self, batch: List[Union[np.ndarray, torch.Tensor]],
+      model: nn.Module) -> Iterable[torch.Tensor]:
+    """
+    Runs inferences on a batch of examples and returns an Iterable of
+    Predictions."""
+    if batch:
+      if isinstance(batch[0], np.ndarray):
+        batch = torch.Tensor(batch)
+      elif isinstance(batch[0], torch.Tensor):
+        batch = torch.stack(batch)
+      else:
+        raise ValueError("PCollection must be an numpy array or a torch Tensor")

Review Comment:
   Why not also include the offending type in the error message, so the user can see what was wrong?



##########
sdks/python/apache_beam/ml/inference/pytorch_impl.py:
##########
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import pickle
+import torch
+from torch import nn
+
+from apache_beam.ml.inference.base import InferenceRunner
+from apache_beam.ml.inference.base import ModelLoader
+
+
+class PytorchInferenceRunner(InferenceRunner):
+  """
+  Implements Pytorch inference method
+  """
+  def __init__(self, input_dim: int, device: torch.device):
+    self._input_dim = input_dim
+    self._device = device
+
+  def run_inference(
+      self, batch: List[Union[np.ndarray, torch.Tensor]],
+      model: nn.Module) -> Iterable[torch.Tensor]:
+    """
+    Runs inferences on a batch of examples and returns an Iterable of
+    Predictions."""
+    if batch:
+      if isinstance(batch[0], np.ndarray):
+        batch = torch.Tensor(batch)
+      elif isinstance(batch[0], torch.Tensor):
+        batch = torch.stack(batch)
+      else:
+        raise ValueError("PCollection must be an numpy array or a torch Tensor")
+
+    if batch.device != self._device:
+      batch = batch.to(self._device)
+    return model(batch)
+
+  def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
+    """Returns the number of bytes of data for a batch."""
+    total_size = 0
+    for el in batch:
+      if isinstance(el, np.ndarray):
+        total_size += el.itemsize
+      elif isinstance(el, torch.Tensor):
+        total_size += el.element_size()
+      else:
+        total_size += len(pickle.dumps(el))
+    return total_size
+
+  def get_metrics_namespace(self) -> str:
+    """
+    Returns a namespace for metrics collected by the RunInference transform.
+    """
+    return 'RunInferencePytorch'
+
+
+class PytorchModelLoader(ModelLoader):
+  """Loads a Pytorch Model."""
+  def __init__(
+      self,
+      input_dim: int,
+      state_dict_path: str,

Review Comment:
   Is state_dict_path just another way of saying the saved model?



##########
sdks/python/apache_beam/ml/inference/pytorch_impl.py:
##########
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import pickle
+import torch
+from torch import nn
+
+from apache_beam.ml.inference.base import InferenceRunner
+from apache_beam.ml.inference.base import ModelLoader
+
+
+class PytorchInferenceRunner(InferenceRunner):
+  """
+  Implements Pytorch inference method
+  """
+  def __init__(self, input_dim: int, device: torch.device):

Review Comment:
   Let's lengthen the name input_dim. Is this input_dimension?



##########
sdks/python/apache_beam/ml/inference/pytorch_impl.py:
##########
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import pickle
+import torch
+from torch import nn

Review Comment:
   I prefer just importing torch, since I find torch.nn easier to understand than nn. But if a torch user would prefer this shorthand, then maybe it's fine.



##########
sdks/python/apache_beam/ml/inference/pytorch_impl.py:
##########
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import pickle
+import torch
+from torch import nn
+
+from apache_beam.ml.inference.base import InferenceRunner
+from apache_beam.ml.inference.base import ModelLoader
+
+
+class PytorchInferenceRunner(InferenceRunner):
+  """
+  Implements Pytorch inference method
+  """
+  def __init__(self, input_dim: int, device: torch.device):
+    self._input_dim = input_dim
+    self._device = device
+
+  def run_inference(
+      self, batch: List[Union[np.ndarray, torch.Tensor]],
+      model: nn.Module) -> Iterable[torch.Tensor]:
+    """
+    Runs inferences on a batch of examples and returns an Iterable of
+    Predictions."""
+    if batch:
+      if isinstance(batch[0], np.ndarray):
+        batch = torch.Tensor(batch)

Review Comment:
   We were talking today, and the ideal would be to have a transform that converts to the right type of data before we even get this deep into the transform.
   
   Why do we always need to convert? Can the batch input type just be the same as what the model expects?



##########
sdks/python/apache_beam/ml/inference/pytorch_impl.py:
##########
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import pickle
+import torch
+from torch import nn
+
+from apache_beam.ml.inference.base import InferenceRunner
+from apache_beam.ml.inference.base import ModelLoader
+
+
+class PytorchInferenceRunner(InferenceRunner):
+  """
+  Implements Pytorch inference method
+  """
+  def __init__(self, input_dim: int, device: torch.device):
+    self._input_dim = input_dim
+    self._device = device
+
+  def run_inference(
+      self, batch: List[Union[np.ndarray, torch.Tensor]],
+      model: nn.Module) -> Iterable[torch.Tensor]:
+    """
+    Runs inferences on a batch of examples and returns an Iterable of
+    Predictions."""
+    if batch:

Review Comment:
   I prefer:
   if not batch:
     return []



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org