Posted to commits@beam.apache.org by da...@apache.org on 2023/05/11 13:08:01 UTC
[beam] branch master updated: Add max/min batch size to tf model handlers (#26651)
This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 8424480b550 Add max/min batch size to tf model handlers (#26651)
8424480b550 is described below
commit 8424480b5508f41f4945d0b7bdcac1524321edb3
Author: Danny McCormick <da...@google.com>
AuthorDate: Thu May 11 09:07:53 2023 -0400
Add max/min batch size to tf model handlers (#26651)
---
.../ml/inference/tensorflow_inference.py | 20 ++++
.../ml/inference/tensorflow_inference_test.py | 125 ++++++++++++++++++++-
2 files changed, 140 insertions(+), 5 deletions(-)
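For readers skimming the diff below: the commit's user-facing surface is just two new optional constructor arguments on each handler. A minimal usage sketch follows (the model path './mult2_model' is hypothetical; any saved Keras/TF model works, and the batch sizes shown are arbitrary):

import apache_beam as beam
import numpy
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy

# min_batch_size / max_batch_size are the arguments added by this commit;
# both are optional and may be set independently.
model_handler = TFModelHandlerNumpy(
    model_uri='./mult2_model',  # hypothetical path to a saved model
    min_batch_size=8,
    max_batch_size=32)

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([numpy.array([1.0, 2.0, 3.0], dtype='float32')])
      | RunInference(model_handler)
      | beam.Map(print))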
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index 976614e5d46..cfa36f05f61 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -100,6 +100,8 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
load_model_args: Optional[Dict[str, Any]] = None,
custom_weights: str = "",
inference_fn: TensorInferenceFn = default_numpy_inference_fn,
+ min_batch_size: Optional[int] = None,
+ max_batch_size: Optional[int] = None,
**kwargs):
"""Implementation of the ModelHandler interface for Tensorflow.
@@ -134,6 +136,11 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
self._env_vars = kwargs.get('env_vars', {})
self._load_model_args = {} if not load_model_args else load_model_args
self._custom_weights = custom_weights
+ self._batching_kwargs = {}
+ if min_batch_size is not None:
+ self._batching_kwargs['min_batch_size'] = min_batch_size
+ if max_batch_size is not None:
+ self._batching_kwargs['max_batch_size'] = max_batch_size
def load_model(self) -> tf.Module:
"""Loads and initializes a Tensorflow model for processing."""
@@ -193,6 +200,9 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
pass
+ def batch_elements_kwargs(self):
+ return self._batching_kwargs
+
class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
tf.Module]):
@@ -205,6 +215,8 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
load_model_args: Optional[Dict[str, Any]] = None,
custom_weights: str = "",
inference_fn: TensorInferenceFn = default_tensor_inference_fn,
+ min_batch_size: Optional[int] = None,
+ max_batch_size: Optional[int] = None,
**kwargs):
"""Implementation of the ModelHandler interface for Tensorflow.
@@ -240,6 +252,11 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
self._env_vars = kwargs.get('env_vars', {})
self._load_model_args = {} if not load_model_args else load_model_args
self._custom_weights = custom_weights
+ self._batching_kwargs = {}
+ if min_batch_size is not None:
+ self._batching_kwargs['min_batch_size'] = min_batch_size
+ if max_batch_size is not None:
+ self._batching_kwargs['max_batch_size'] = max_batch_size
def load_model(self) -> tf.Module:
"""Loads and initializes a tensorflow model for processing."""
@@ -299,3 +316,6 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
pass
+
+ def batch_elements_kwargs(self):
+ return self._batching_kwargs
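How the new arguments take effect: RunInference calls the handler's batch_elements_kwargs() and forwards the result to beam.BatchElements, the transform that groups elements before inference. A standalone sketch of that batching behavior (illustrative values; the final batch in a bundle can still be smaller than min_batch_size when elements run out):

import apache_beam as beam

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([1, 2, 3, 4, 5])
      | beam.BatchElements(min_batch_size=2, max_batch_size=2)
      | beam.Map(print))  # prints lists such as [1, 2], [3, 4], [5]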
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
index 8f0f6d06d40..4651f672591 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
@@ -17,6 +17,9 @@
# pytype: skip-file
+import os
+import shutil
+import tempfile
import unittest
from typing import Any
from typing import Dict
@@ -30,12 +33,17 @@ import pytest
try:
import tensorflow as tf
+ import apache_beam as beam
+ from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.ml.inference.sklearn_inference_test import _compare_prediction_result
- from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult
+ from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult, RunInference
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy, TFModelHandlerTensor
from apache_beam.ml.inference import tensorflow_inference, utils
+ from apache_beam.testing.util import assert_that, equal_to
except ImportError:
- raise unittest.SkipTest('Tensorflow dependencies are not installed')
+ raise unittest.SkipTest(
+ 'Tensorflow dependencies are not installed. ' +
+ 'Make sure you have both tensorflow and tensorflow_hub installed.')
class FakeTFNumpyModel:
@@ -50,8 +58,14 @@ class FakeTFTensorModel:
return tf.math.multiply(input, 10)
+def _create_mult2_model():
+ inputs = tf.keras.Input(shape=(3))
+ outputs = tf.keras.layers.Lambda(lambda x: x * 2, dtype='float32')(inputs)
+ return tf.keras.Model(inputs=inputs, outputs=outputs)
+
+
def _compare_tensor_prediction_result(x, y):
- return tf.math.equal(x.inference, y.inference)
+ return tf.reduce_all(tf.math.equal(x.inference, y.inference))
def fake_inference_fn(
@@ -65,6 +79,12 @@ def fake_inference_fn(
@pytest.mark.uses_tf
class TFRunInferenceTest(unittest.TestCase):
+ def setUp(self):
+ self.tmpdir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdir)
+
def test_predict_numpy(self):
fake_model = FakeTFNumpyModel()
inference_runner = TFModelHandlerNumpy(
@@ -99,6 +119,93 @@ class TFRunInferenceTest(unittest.TestCase):
for actual, expected in zip(inferences, expected_predictions):
self.assertTrue(_compare_tensor_prediction_result(actual, expected))
+ def test_predict_tensor_with_batch_size(self):
+ model = _create_mult2_model()
+ model_path = os.path.join(self.tmpdir, 'mult2')
+ tf.keras.models.save_model(model, model_path)
+ with TestPipeline() as pipeline:
+
+ def fake_batching_inference_fn(
+ model: tf.Module,
+ batch: Union[Sequence[numpy.ndarray], Sequence[tf.Tensor]],
+ inference_args: Dict[str, Any],
+ model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+ if len(batch) != 2:
+ raise Exception(
+ f'Expected batch of size 2, received batch of size {len(batch)}')
+ batch = tf.stack(batch, axis=0)
+ predictions = model(batch)
+ return utils._convert_to_result(batch, predictions, model_id)
+
+ model_handler = TFModelHandlerTensor(
+ model_uri=model_path,
+ inference_fn=fake_batching_inference_fn,
+ min_batch_size=2,
+ max_batch_size=2)
+ examples = [
+ tf.convert_to_tensor(numpy.array([1.1, 2.2, 3.3], dtype='float32')),
+ tf.convert_to_tensor(
+ numpy.array([10.1, 20.2, 30.3], dtype='float32')),
+ tf.convert_to_tensor(
+ numpy.array([100.1, 200.2, 300.3], dtype='float32')),
+ tf.convert_to_tensor(
+ numpy.array([200.1, 300.2, 400.3], dtype='float32')),
+ ]
+ expected_predictions = [
+ PredictionResult(ex, pred) for ex,
+ pred in zip(examples, [tf.math.multiply(n, 2) for n in examples])
+ ]
+
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ predictions = pcoll | RunInference(model_handler)
+ assert_that(
+ predictions,
+ equal_to(
+ expected_predictions,
+ equals_fn=_compare_tensor_prediction_result))
+
+ def test_predict_numpy_with_batch_size(self):
+ model = _create_mult2_model()
+ model_path = os.path.join(self.tmpdir, 'mult2_numpy')
+ tf.keras.models.save_model(model, model_path)
+ with TestPipeline() as pipeline:
+
+ def fake_batching_inference_fn(
+ model: tf.Module,
+ batch: Sequence[numpy.ndarray],
+ inference_args: Dict[str, Any],
+ model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+ if len(batch) != 2:
+ raise Exception(
+ f'Expected batch of size 2, received batch of size {len(batch)}')
+ vectorized_batch = numpy.stack(batch, axis=0)
+ predictions = model.predict(vectorized_batch, **inference_args)
+ return utils._convert_to_result(batch, predictions, model_id)
+
+ model_handler = TFModelHandlerNumpy(
+ model_uri=model_path,
+ inference_fn=fake_batching_inference_fn,
+ min_batch_size=2,
+ max_batch_size=2)
+ examples = [
+ numpy.array([1.1, 2.2, 3.3], dtype='float32'),
+ numpy.array([10.1, 20.2, 30.3], dtype='float32'),
+ numpy.array([100.1, 200.2, 300.3], dtype='float32'),
+ numpy.array([200.1, 300.2, 400.3], dtype='float32'),
+ ]
+ expected_predictions = [
+ PredictionResult(ex, pred) for ex,
+ pred in zip(examples, [numpy.multiply(n, 2) for n in examples])
+ ]
+
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ predictions = pcoll | RunInference(model_handler)
+ assert_that(
+ predictions,
+ equal_to(
+ expected_predictions,
+ equals_fn=_compare_tensor_prediction_result))
+
def test_predict_tensor_with_args(self):
fake_model = FakeTFTensorModel()
inference_runner = TFModelHandlerTensor(
@@ -161,11 +268,20 @@ class TFRunInferenceTest(unittest.TestCase):
for actual, expected in zip(inferences, expected_predictions):
self.assertTrue(_compare_tensor_prediction_result(actual[1], expected[1]))
+
+@pytest.mark.uses_tf
+class TFRunInferenceTestWithMocks(unittest.TestCase):
+ def setUp(self):
+ self._load_model = tensorflow_inference._load_model
+ tensorflow_inference._load_model = unittest.mock.MagicMock()
+
+ def tearDown(self):
+ tensorflow_inference._load_model = self._load_model
+
def test_load_model_args(self):
load_model_args = {'compile': False, 'custom_objects': {'optimizer': 1}}
model_handler = TFModelHandlerNumpy(
"dummy_model", load_model_args=load_model_args)
- tensorflow_inference._load_model = unittest.mock.MagicMock()
model_handler.load_model()
tensorflow_inference._load_model.assert_called_with(
"dummy_model", "", load_model_args)
@@ -176,7 +292,6 @@ class TFRunInferenceTest(unittest.TestCase):
"dummy_model",
custom_weights="dummy_weights",
load_model_args=load_model_args)
- tensorflow_inference._load_model = unittest.mock.MagicMock()
model_handler.load_model()
tensorflow_inference._load_model.assert_called_with(
"dummy_model", "dummy_weights", load_model_args)