You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bh...@apache.org on 2022/10/20 00:16:20 UTC
[beam] branch master updated: Add PytorchBatchConverter (#23296)
This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new b4af23d8414 Add PytorchBatchConverter (#23296)
b4af23d8414 is described below
commit b4af23d8414c65b8af683a726fead5158d300477
Author: Andy Ye <an...@gmail.com>
AuthorDate: Wed Oct 19 20:16:09 2022 -0400
Add PytorchBatchConverter (#23296)
* Add PytorchBatchConverter
* Refactor to pytorch_type_compatibility files
* Fix test syntax and imports; Lint
* Add main
* Fix torch.cat
* Update sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py
Co-authored-by: Brian Hulette <hu...@gmail.com>
---
.../typehints/pytorch_type_compatibility.py | 140 +++++++++++++++++++++
.../typehints/pytorch_type_compatibility_test.py | 138 ++++++++++++++++++++
2 files changed, 278 insertions(+)
diff --git a/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py b/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py
new file mode 100644
index 00000000000..fbecb6d5105
--- /dev/null
+++ b/sdks/python/apache_beam/typehints/pytorch_type_compatibility.py
@@ -0,0 +1,140 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+import torch
+from apache_beam.typehints import typehints
+from apache_beam.typehints.batch import BatchConverter
+from apache_beam.typehints.batch import N
+
+
class PytorchBatchConverter(BatchConverter):
  """BatchConverter between individual torch.Tensor elements and batches.

  An element is a tensor of shape ``element_shape``. A batch is a tensor of the
  same dtype whose shape is ``element_shape`` with the batch dimension inserted
  at index ``partition_dimension``.
  """
  def __init__(
      self,
      batch_type,
      element_type,
      dtype,
      element_shape=(),
      partition_dimension=0):
    super().__init__(batch_type, element_type)
    self.dtype = dtype
    self.element_shape = element_shape
    self.partition_dimension = partition_dimension

  @staticmethod
  @BatchConverter.register
  def from_typehints(element_type,
                     batch_type) -> Optional['PytorchBatchConverter']:
    """Build a converter for ``(element_type, batch_type)``, or return None.

    Args:
      element_type: a ``torch.dtype`` (interpreted as a 0-d tensor of that
        dtype) or a ``PytorchTensor[...]`` constraint.
      batch_type: ``torch.Tensor`` (interpreted as a 1-D batch of elements of
        the element dtype) or a ``PytorchTensor[...]`` constraint whose shape
        contains the batch dimension ``N``.

    Returns:
      A PytorchBatchConverter, or None if the two hints are not a supported,
      mutually consistent PyTorch combination.
    """
    if not isinstance(element_type, PytorchTypeHint.PytorchTypeConstraint):
      # Only a bare torch.dtype can be promoted to a scalar tensor hint.
      # PytorchTensor[...] accepts any object without raising, so other
      # values (e.g. a Python type like int) must be rejected explicitly.
      if not isinstance(element_type, torch.dtype):
        return None
      element_type = PytorchTensor[element_type, ()]

    if not isinstance(batch_type, PytorchTypeHint.PytorchTypeConstraint):
      if not batch_type == torch.Tensor:
        return None
      # Plain torch.Tensor: a 1-D batch whose only dimension is N.
      batch_type = PytorchTensor[element_type.dtype, (N, )]

    if not batch_type.dtype == element_type.dtype:
      return None

    batch_shape = list(batch_type.shape)
    # Without a batch dimension there is nothing to partition along, and
    # list.index would raise ValueError rather than reporting "no match".
    if N not in batch_shape:
      return None
    partition_dimension = batch_shape.index(N)
    batch_shape.pop(partition_dimension)
    # Dropping the batch dimension must leave exactly the element shape.
    if not tuple(batch_shape) == tuple(element_type.shape):
      return None

    return PytorchBatchConverter(
        batch_type,
        element_type,
        batch_type.dtype,
        element_type.shape,
        partition_dimension)

  def produce_batch(self, elements):
    """Stack a sequence of element tensors into one batch tensor."""
    return torch.stack(elements, dim=self.partition_dimension)

  def explode_batch(self, batch):
    """Convert an instance of B to Generator[E].

    torch.unbind is the exact inverse of torch.stack along the same dim. A
    swapaxes-then-iterate approach would also reorder the remaining axes
    whenever partition_dimension >= 2.
    """
    yield from torch.unbind(batch, dim=self.partition_dimension)

  def combine_batches(self, batches):
    """Concatenate batch tensors along the batch dimension."""
    return torch.cat(batches, dim=self.partition_dimension)

  def get_length(self, batch):
    """Return the number of elements in ``batch``."""
    return batch.size(dim=self.partition_dimension)

  def estimate_byte_size(self, batch):
    """Return the size of the tensor's element data in bytes."""
    return batch.nelement() * batch.element_size()
+
+
class PytorchTypeHint:
  """Factory for PyTorch tensor type hints, used via the ``PytorchTensor``
  instance below.

  ``PytorchTensor[dtype]`` describes a 1-D tensor of ``dtype`` whose only
  dimension is the batch dimension ``N``. ``PytorchTensor[dtype, shape]``
  describes a tensor with an explicit shape, any entry of which may be ``N``.
  """
  class PytorchTypeConstraint(typehints.TypeConstraint):
    """Type constraint matching torch.Tensor instances by dtype and shape."""
    def __init__(self, dtype, shape=()):
      self.dtype = dtype
      # Normalize so list/torch.Size shapes hash and compare equal to the
      # equivalent tuple.
      self.shape = tuple(shape)

    def type_check(self, batch):
      """Raise TypeError unless ``batch`` is a tensor matching this hint.

      Dimensions equal to ``N`` match any size; all other dimensions must
      match exactly, and the tensor's rank must equal ``len(self.shape)``.
      """
      if not isinstance(batch, torch.Tensor):
        raise TypeError(f"Batch {batch!r} is not an instance of torch.Tensor")
      if not batch.dtype == self.dtype:
        raise TypeError(
            f"Batch {batch!r} does not have expected dtype: {self.dtype!r}")

      # Check the rank first. A purely pairwise comparison would raise
      # IndexError for lower-rank tensors and ignore extra trailing dims.
      if batch.dim() != len(self.shape):
        raise TypeError(
            f"Batch {batch!r} does not have expected shape: {self.shape!r}")
      for expected, actual in zip(self.shape, batch.shape):
        if not expected == N and not actual == expected:
          raise TypeError(
              f"Batch {batch!r} does not have expected shape: {self.shape!r}")

    def _consistent_with_check_(self, sub):
      # TODO Check sub against batch type, and element type
      return True

    def __key(self):
      return (self.dtype, self.shape)

    def __eq__(self, other) -> bool:
      if isinstance(other, PytorchTypeHint.PytorchTypeConstraint):
        return self.__key() == other.__key()

      return NotImplemented

    def __hash__(self) -> int:
      return hash(self.__key())

    def __repr__(self):
      if self.shape == (N, ):
        return f'PytorchTensor[{self.dtype!r}]'
      else:
        return f'PytorchTensor[{self.dtype!r}, {self.shape!r}]'

  def __getitem__(self, value):
    """Return a PytorchTypeConstraint for ``[dtype]`` or ``[dtype, shape]``.

    Raises:
      ValueError: if a tuple subscript does not have exactly two entries.
    """
    if not isinstance(value, tuple):
      return self.PytorchTypeConstraint(value, shape=(N, ))
    if len(value) != 2:
      raise ValueError(
          "Expected PytorchTensor[dtype] or PytorchTensor[dtype, shape], "
          f"got PytorchTensor[{value!r}]")
    dtype, shape = value
    return self.PytorchTypeConstraint(dtype, shape=shape)


# Module-level singleton: subscript it to build tensor type hints, e.g.
# PytorchTensor[torch.float32, (N, 10)].
PytorchTensor = PytorchTypeHint()
diff --git a/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py b/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py
new file mode 100644
index 00000000000..e851d4679cc
--- /dev/null
+++ b/sdks/python/apache_beam/typehints/pytorch_type_compatibility_test.py
@@ -0,0 +1,138 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for pytorch_type_compatibility."""
+
+import unittest
+
+import pytest
+from parameterized import parameterized
+from parameterized import parameterized_class
+
+from apache_beam.typehints import typehints
+from apache_beam.typehints.batch import BatchConverter
+from apache_beam.typehints.batch import N
+
+# Protect against environments where pytorch library is not available.
+# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
+try:
+ import torch
+ from apache_beam.typehints.pytorch_type_compatibility import PytorchTensor
+except ImportError:
+ raise unittest.SkipTest('PyTorch dependencies are not installed')
+
+
@parameterized_class([
    {
        'batch_typehint': torch.Tensor,
        'element_typehint': PytorchTensor[torch.int32, ()],
        'batch': torch.tensor(range(100), dtype=torch.int32)
    },
    {
        'batch_typehint': PytorchTensor[torch.int64, (N, 10)],
        'element_typehint': PytorchTensor[torch.int64, (10, )],
        'batch': torch.tensor([list(range(i, i + 10)) for i in range(100)],
                              dtype=torch.int64),
    },
    # Batch dimension is not the leading axis (partition_dimension == 1).
    {
        'batch_typehint': PytorchTensor[torch.float32, (3, N)],
        'element_typehint': PytorchTensor[torch.float32, (3, )],
        'batch': torch.arange(300, dtype=torch.float32).reshape(3, 100),
    },
])
@pytest.mark.uses_pytorch
class PytorchBatchConverterTest(unittest.TestCase):
  """Round-trip tests for the converter built from each typehint pair."""
  def create_batch_converter(self):
    return BatchConverter.from_typehints(
        element_type=self.element_typehint, batch_type=self.batch_typehint)

  def setUp(self):
    self.converter = self.create_batch_converter()
    self.normalized_batch_typehint = typehints.normalize(self.batch_typehint)
    self.normalized_element_typehint = typehints.normalize(
        self.element_typehint)

  def equality_check(self, left, right):
    if isinstance(left, torch.Tensor):
      self.assertTrue(torch.equal(left, right))
    else:
      raise TypeError(f"Encountered unexpected type, left is a {type(left)!r}")

  def test_typehint_validates(self):
    typehints.validate_composite_type_param(self.batch_typehint, '')
    typehints.validate_composite_type_param(self.element_typehint, '')

  def test_type_check_batch(self):
    typehints.check_constraint(self.normalized_batch_typehint, self.batch)

  def test_type_check_element(self):
    for element in self.converter.explode_batch(self.batch):
      typehints.check_constraint(self.normalized_element_typehint, element)

  def test_explode_rebatch(self):
    exploded = list(self.converter.explode_batch(self.batch))
    rebatched = self.converter.produce_batch(exploded)

    typehints.check_constraint(self.normalized_batch_typehint, rebatched)
    self.equality_check(self.batch, rebatched)

  def _split_batch_into_n_partitions(self, num_partitions):
    """Split self.batch into ``num_partitions`` contiguous sub-batches.

    Returns:
      A ``(batches, lengths)`` pair, where ``lengths[i]`` is the number of
      elements placed into ``batches[i]``.
    """
    # Named num_partitions so the imported batch-dimension sentinel N is not
    # shadowed.
    elements = list(self.converter.explode_batch(self.batch))
    num_elements = len(elements)

    bounds = [
        num_elements * i // num_partitions for i in range(num_partitions + 1)
    ]
    element_batches = [
        elements[start:stop] for start, stop in zip(bounds, bounds[1:])
    ]

    lengths = [len(element_batch) for element_batch in element_batches]
    batches = [
        self.converter.produce_batch(element_batch)
        for element_batch in element_batches
    ]

    return batches, lengths

  @parameterized.expand([
      (2, ),
      (3, ),
      (10, ),
  ])
  def test_combine_batches(self, num_partitions):
    batches, _ = self._split_batch_into_n_partitions(num_partitions)

    # Combine the batches, output should be equivalent to the original batch
    combined = self.converter.combine_batches(batches)

    self.equality_check(self.batch, combined)

  @parameterized.expand([
      (2, ),
      (3, ),
      (10, ),
  ])
  def test_get_length(self, num_partitions):
    batches, lengths = self._split_batch_into_n_partitions(num_partitions)

    for batch, expected_length in zip(batches, lengths):
      self.assertEqual(self.converter.get_length(batch), expected_length)

  def test_equals(self):
    # Check both operand orders to catch an asymmetric __eq__.
    self.assertTrue(self.converter == self.create_batch_converter())
    self.assertTrue(self.create_batch_converter() == self.converter)

  def test_hash(self):
    self.assertEqual(hash(self.create_batch_converter()), hash(self.converter))


# Allow running this module directly as well as through pytest.
if __name__ == '__main__':
  unittest.main()