You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bh...@apache.org on 2022/10/20 00:16:20 UTC

[beam] branch master updated: Add PytorchBatchConverter (#23296)

This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b4af23d8414 Add PytorchBatchConverter (#23296)
b4af23d8414 is described below

commit b4af23d8414c65b8af683a726fead5158d300477
Author: Andy Ye <an...@gmail.com>
AuthorDate: Wed Oct 19 20:16:09 2022 -0400

    Add PytorchBatchConverter (#23296)
    
    * Add PytorchBatchConverter
    
    * Refactor to pytorch_type_compatibility files
    
    * Fix test syntax and imports; Lint
    
    * Add main
    
    * Fix torch.cat
    
    * Update sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py
    
    Co-authored-by: Brian Hulette <hu...@gmail.com>
---
 .../typehints/pytorch_type_compatibility.py        | 140 +++++++++++++++++++++
 .../typehints/pytorch_type_compatibility_test.py   | 138 ++++++++++++++++++++
 2 files changed, 278 insertions(+)

diff --git a/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py b/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py
new file mode 100644
index 00000000000..fbecb6d5105
--- /dev/null
+++ b/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py
@@ -0,0 +1,140 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+import torch
+from apache_beam.typehints import typehints
+from apache_beam.typehints.batch import BatchConverter
+from apache_beam.typehints.batch import N
+
+
+class PytorchBatchConverter(BatchConverter):
+  def __init__(
+      self,
+      batch_type,
+      element_type,
+      dtype,
+      element_shape=(),
+      partition_dimension=0):
+    super().__init__(batch_type, element_type)
+    self.dtype = dtype
+    self.element_shape = element_shape
+    self.partition_dimension = partition_dimension
+
+  @staticmethod
+  @BatchConverter.register
+  def from_typehints(element_type,
+                     batch_type) -> Optional['PytorchBatchConverter']:
+    if not isinstance(element_type, PytorchTypeHint.PytorchTypeConstraint):
+      try:
+        element_type = PytorchTensor[element_type, ()]
+      except TypeError:
+        # TODO: Is there a better way to detect if element_type is a dtype?
+        return None
+
+    if not isinstance(batch_type, PytorchTypeHint.PytorchTypeConstraint):
+      if not batch_type == torch.Tensor:
+        # TODO: Include explanation for mismatch?
+        return None
+      batch_type = PytorchTensor[element_type.dtype, (N, )]
+
+    if not batch_type.dtype == element_type.dtype:
+      return None
+    batch_shape = list(batch_type.shape)
+    partition_dimension = batch_shape.index(N)
+    batch_shape.pop(partition_dimension)
+    if not tuple(batch_shape) == element_type.shape:
+      return None
+
+    return PytorchBatchConverter(
+        batch_type,
+        element_type,
+        batch_type.dtype,
+        element_type.shape,
+        partition_dimension)
+
+  def produce_batch(self, elements):
+    return torch.stack(elements, dim=self.partition_dimension)
+
+  def explode_batch(self, batch):
+    """Convert an instance of B to Generator[E]."""
+    yield from torch.swapaxes(batch, self.partition_dimension, 0)
+
+  def combine_batches(self, batches):
+    return torch.cat(batches, dim=self.partition_dimension)
+
+  def get_length(self, batch):
+    return batch.size(dim=self.partition_dimension)
+
+  def estimate_byte_size(self, batch):
+    return batch.nelement() * batch.element_size()
+
+
+class PytorchTypeHint():
+  class PytorchTypeConstraint(typehints.TypeConstraint):
+    def __init__(self, dtype, shape=()):
+      self.dtype = dtype
+      self.shape = shape
+
+    def type_check(self, batch):
+      if not isinstance(batch, torch.Tensor):
+        raise TypeError(f"Batch {batch!r} is not an instance of torch.Tensor")
+      if not batch.dtype == self.dtype:
+        raise TypeError(
+            f"Batch {batch!r} does not have expected dtype: {self.dtype!r}")
+
+      for dim in range(len(self.shape)):
+        if not self.shape[dim] == N and not batch.shape[dim] == self.shape[dim]:
+          raise TypeError(
+              f"Batch {batch!r} does not have expected shape: {self.shape!r}")
+
+    def _consistent_with_check_(self, sub):
+      # TODO Check sub against batch type, and element type
+      return True
+
+    def __key(self):
+      return (self.dtype, self.shape)
+
+    def __eq__(self, other) -> bool:
+      if isinstance(other, PytorchTypeHint.PytorchTypeConstraint):
+        return self.__key() == other.__key()
+
+      return NotImplemented
+
+    def __hash__(self) -> int:
+      return hash(self.__key())
+
+    def __repr__(self):
+      if self.shape == (N, ):
+        return f'PytorchTensor[{self.dtype!r}]'
+      else:
+        return f'PytorchTensor[{self.dtype!r}, {self.shape!r}]'
+
+  def __getitem__(self, value):
+    if isinstance(value, tuple):
+      if len(value) == 2:
+        dtype, shape = value
+        return self.PytorchTypeConstraint(dtype, shape=shape)
+      else:
+        raise ValueError
+    else:
+      dtype = value
+      return self.PytorchTypeConstraint(dtype, shape=(N, ))
+
+
+PytorchTensor = PytorchTypeHint()
diff --git a/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py b/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py
new file mode 100644
index 00000000000..e851d4679cc
--- /dev/null
+++ b/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py
@@ -0,0 +1,138 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for pytorch_type_compatibility."""
+
+import unittest
+
+import pytest
+from parameterized import parameterized
+from parameterized import parameterized_class
+
+from apache_beam.typehints import typehints
+from apache_beam.typehints.batch import BatchConverter
+from apache_beam.typehints.batch import N
+
+# Protect against environments where pytorch library is not available.
+# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
+try:
+  import torch
+  from apache_beam.typehints.pytorch_type_compatibility import PytorchTensor
+except ImportError:
+  raise unittest.SkipTest('PyTorch dependencies are not installed')
+
+
+@parameterized_class([
+    {
+        'batch_typehint': torch.Tensor,
+        'element_typehint': PytorchTensor[torch.int32, ()],
+        'batch': torch.tensor(range(100), dtype=torch.int32)
+    },
+    {
+        'batch_typehint': PytorchTensor[torch.int64, (N, 10)],
+        'element_typehint': PytorchTensor[torch.int64, (10, )],
+        'batch': torch.tensor([list(range(i, i + 10)) for i in range(100)],
+                              dtype=torch.int64),
+    },
+])
+@pytest.mark.uses_pytorch
+class PytorchBatchConverterTest(unittest.TestCase):
+  def create_batch_converter(self):
+    return BatchConverter.from_typehints(
+        element_type=self.element_typehint, batch_type=self.batch_typehint)
+
+  def setUp(self):
+    self.converter = self.create_batch_converter()
+    self.normalized_batch_typehint = typehints.normalize(self.batch_typehint)
+    self.normalized_element_typehint = typehints.normalize(
+        self.element_typehint)
+
+  def equality_check(self, left, right):
+    if isinstance(left, torch.Tensor):
+      self.assertTrue(torch.equal(left, right))
+    else:
+      raise TypeError(f"Encountered unexpected type, left is a {type(left)!r}")
+
+  def test_typehint_validates(self):
+    typehints.validate_composite_type_param(self.batch_typehint, '')
+    typehints.validate_composite_type_param(self.element_typehint, '')
+
+  def test_type_check_batch(self):
+    typehints.check_constraint(self.normalized_batch_typehint, self.batch)
+
+  def test_type_check_element(self):
+    for element in self.converter.explode_batch(self.batch):
+      typehints.check_constraint(self.normalized_element_typehint, element)
+
+  def test_explode_rebatch(self):
+    exploded = list(self.converter.explode_batch(self.batch))
+    rebatched = self.converter.produce_batch(exploded)
+
+    typehints.check_constraint(self.normalized_batch_typehint, rebatched)
+    self.equality_check(self.batch, rebatched)
+
+  def _split_batch_into_n_partitions(self, N):
+    elements = list(self.converter.explode_batch(self.batch))
+
+    # Split elements into N contiguous partitions
+    element_batches = [
+        elements[len(elements) * i // N:len(elements) * (i + 1) // N]
+        for i in range(N)
+    ]
+
+    lengths = [len(element_batch) for element_batch in element_batches]
+    batches = [
+        self.converter.produce_batch(element_batch)
+        for element_batch in element_batches
+    ]
+
+    return batches, lengths
+
+  @parameterized.expand([
+      (2, ),
+      (3, ),
+      (10, ),
+  ])
+  def test_combine_batches(self, N):
+    batches, _ = self._split_batch_into_n_partitions(N)
+
+    # Combine the batches, output should be equivalent to the original batch
+    combined = self.converter.combine_batches(batches)
+
+    self.equality_check(self.batch, combined)
+
+  @parameterized.expand([
+      (2, ),
+      (3, ),
+      (10, ),
+  ])
+  def test_get_length(self, N):
+    batches, lengths = self._split_batch_into_n_partitions(N)
+
+    for batch, expected_length in zip(batches, lengths):
+      self.assertEqual(self.converter.get_length(batch), expected_length)
+
+  def test_equals(self):
+    self.assertTrue(self.converter == self.create_batch_converter())
+    self.assertTrue(self.create_batch_converter() == self.converter)
+
+  def test_hash(self):
+    self.assertEqual(hash(self.create_batch_converter()), hash(self.converter))
+
+
+if __name__ == '__main__':
+  # Run the tests with unittest when this module is executed directly.
+  unittest.main()