You are viewing a plain-text version of this content. The link to the canonical version (originally attached to the word "here") was not preserved in this copy.
Posted to commits@beam.apache.org by da...@apache.org on 2023/05/11 13:08:01 UTC

[beam] branch master updated: Add max/min batch size to tf model handlers (#26651)

This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8424480b550 Add max/min batch size to tf model handlers (#26651)
8424480b550 is described below

commit 8424480b5508f41f4945d0b7bdcac1524321edb3
Author: Danny McCormick <da...@google.com>
AuthorDate: Thu May 11 09:07:53 2023 -0400

    Add max/min batch size to tf model handlers (#26651)
---
 .../ml/inference/tensorflow_inference.py           |  20 ++++
 .../ml/inference/tensorflow_inference_test.py      | 125 ++++++++++++++++++++-
 2 files changed, 140 insertions(+), 5 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index 976614e5d46..cfa36f05f61 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -100,6 +100,8 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
       load_model_args: Optional[Dict[str, Any]] = None,
       custom_weights: str = "",
       inference_fn: TensorInferenceFn = default_numpy_inference_fn,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for Tensorflow.
 
@@ -134,6 +136,11 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
     self._env_vars = kwargs.get('env_vars', {})
     self._load_model_args = {} if not load_model_args else load_model_args
     self._custom_weights = custom_weights
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
 
   def load_model(self) -> tf.Module:
     """Loads and initializes a Tensorflow model for processing."""
@@ -193,6 +200,9 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
     pass
 
+  def batch_elements_kwargs(self):
+    return self._batching_kwargs
+
 
 class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
                                         tf.Module]):
@@ -205,6 +215,8 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
       load_model_args: Optional[Dict[str, Any]] = None,
       custom_weights: str = "",
       inference_fn: TensorInferenceFn = default_tensor_inference_fn,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for Tensorflow.
 
@@ -240,6 +252,11 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
     self._env_vars = kwargs.get('env_vars', {})
     self._load_model_args = {} if not load_model_args else load_model_args
     self._custom_weights = custom_weights
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
 
   def load_model(self) -> tf.Module:
     """Loads and initializes a tensorflow model for processing."""
@@ -299,3 +316,6 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
 
   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
     pass
+
+  def batch_elements_kwargs(self):
+    return self._batching_kwargs
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
index 8f0f6d06d40..4651f672591 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
@@ -17,6 +17,9 @@
 
 # pytype: skip-file
 
+import os
+import shutil
+import tempfile
 import unittest
 from typing import Any
 from typing import Dict
@@ -30,12 +33,17 @@ import pytest
 
 try:
   import tensorflow as tf
+  import apache_beam as beam
+  from apache_beam.testing.test_pipeline import TestPipeline
   from apache_beam.ml.inference.sklearn_inference_test import _compare_prediction_result
-  from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult
+  from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult, RunInference
   from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy, TFModelHandlerTensor
   from apache_beam.ml.inference import tensorflow_inference, utils
+  from apache_beam.testing.util import assert_that, equal_to
 except ImportError:
-  raise unittest.SkipTest('Tensorflow dependencies are not installed')
+  raise unittest.SkipTest(
+      'Tensorflow dependencies are not installed. ' +
+      'Make sure you have both tensorflow and tensorflow_hub installed.')
 
 
 class FakeTFNumpyModel:
@@ -50,8 +58,14 @@ class FakeTFTensorModel:
     return tf.math.multiply(input, 10)
 
 
+def _create_mult2_model():
+  inputs = tf.keras.Input(shape=(3))
+  outputs = tf.keras.layers.Lambda(lambda x: x * 2, dtype='float32')(inputs)
+  return tf.keras.Model(inputs=inputs, outputs=outputs)
+
+
 def _compare_tensor_prediction_result(x, y):
-  return tf.math.equal(x.inference, y.inference)
+  return tf.reduce_all(tf.math.equal(x.inference, y.inference))
 
 
 def fake_inference_fn(
@@ -65,6 +79,12 @@ def fake_inference_fn(
 
 @pytest.mark.uses_tf
 class TFRunInferenceTest(unittest.TestCase):
+  def setUp(self):
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
   def test_predict_numpy(self):
     fake_model = FakeTFNumpyModel()
     inference_runner = TFModelHandlerNumpy(
@@ -99,6 +119,93 @@ class TFRunInferenceTest(unittest.TestCase):
     for actual, expected in zip(inferences, expected_predictions):
       self.assertTrue(_compare_tensor_prediction_result(actual, expected))
 
+  def test_predict_tensor_with_batch_size(self):
+    model = _create_mult2_model()
+    model_path = os.path.join(self.tmpdir, 'mult2')
+    tf.keras.models.save_model(model, model_path)
+    with TestPipeline() as pipeline:
+
+      def fake_batching_inference_fn(
+          model: tf.Module,
+          batch: Union[Sequence[numpy.ndarray], Sequence[tf.Tensor]],
+          inference_args: Dict[str, Any],
+          model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+        if len(batch) != 2:
+          raise Exception(
+              f'Expected batch of size 2, received batch of size {len(batch)}')
+        batch = tf.stack(batch, axis=0)
+        predictions = model(batch)
+        return utils._convert_to_result(batch, predictions, model_id)
+
+      model_handler = TFModelHandlerTensor(
+          model_uri=model_path,
+          inference_fn=fake_batching_inference_fn,
+          min_batch_size=2,
+          max_batch_size=2)
+      examples = [
+          tf.convert_to_tensor(numpy.array([1.1, 2.2, 3.3], dtype='float32')),
+          tf.convert_to_tensor(
+              numpy.array([10.1, 20.2, 30.3], dtype='float32')),
+          tf.convert_to_tensor(
+              numpy.array([100.1, 200.2, 300.3], dtype='float32')),
+          tf.convert_to_tensor(
+              numpy.array([200.1, 300.2, 400.3], dtype='float32')),
+      ]
+      expected_predictions = [
+          PredictionResult(ex, pred) for ex,
+          pred in zip(examples, [tf.math.multiply(n, 2) for n in examples])
+      ]
+
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      predictions = pcoll | RunInference(model_handler)
+      assert_that(
+          predictions,
+          equal_to(
+              expected_predictions,
+              equals_fn=_compare_tensor_prediction_result))
+
+  def test_predict_numpy_with_batch_size(self):
+    model = _create_mult2_model()
+    model_path = os.path.join(self.tmpdir, 'mult2_numpy')
+    tf.keras.models.save_model(model, model_path)
+    with TestPipeline() as pipeline:
+
+      def fake_batching_inference_fn(
+          model: tf.Module,
+          batch: Sequence[numpy.ndarray],
+          inference_args: Dict[str, Any],
+          model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+        if len(batch) != 2:
+          raise Exception(
+              f'Expected batch of size 2, received batch of size {len(batch)}')
+        vectorized_batch = numpy.stack(batch, axis=0)
+        predictions = model.predict(vectorized_batch, **inference_args)
+        return utils._convert_to_result(batch, predictions, model_id)
+
+      model_handler = TFModelHandlerNumpy(
+          model_uri=model_path,
+          inference_fn=fake_batching_inference_fn,
+          min_batch_size=2,
+          max_batch_size=2)
+      examples = [
+          numpy.array([1.1, 2.2, 3.3], dtype='float32'),
+          numpy.array([10.1, 20.2, 30.3], dtype='float32'),
+          numpy.array([100.1, 200.2, 300.3], dtype='float32'),
+          numpy.array([200.1, 300.2, 400.3], dtype='float32'),
+      ]
+      expected_predictions = [
+          PredictionResult(ex, pred) for ex,
+          pred in zip(examples, [numpy.multiply(n, 2) for n in examples])
+      ]
+
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      predictions = pcoll | RunInference(model_handler)
+      assert_that(
+          predictions,
+          equal_to(
+              expected_predictions,
+              equals_fn=_compare_tensor_prediction_result))
+
   def test_predict_tensor_with_args(self):
     fake_model = FakeTFTensorModel()
     inference_runner = TFModelHandlerTensor(
@@ -161,11 +268,20 @@ class TFRunInferenceTest(unittest.TestCase):
     for actual, expected in zip(inferences, expected_predictions):
       self.assertTrue(_compare_tensor_prediction_result(actual[1], expected[1]))
 
+
+@pytest.mark.uses_tf
+class TFRunInferenceTestWithMocks(unittest.TestCase):
+  def setUp(self):
+    self._load_model = tensorflow_inference._load_model
+    tensorflow_inference._load_model = unittest.mock.MagicMock()
+
+  def tearDown(self):
+    tensorflow_inference._load_model = self._load_model
+
   def test_load_model_args(self):
     load_model_args = {compile: False, 'custom_objects': {'optimizer': 1}}
     model_handler = TFModelHandlerNumpy(
         "dummy_model", load_model_args=load_model_args)
-    tensorflow_inference._load_model = unittest.mock.MagicMock()
     model_handler.load_model()
     tensorflow_inference._load_model.assert_called_with(
         "dummy_model", "", load_model_args)
@@ -176,7 +292,6 @@ class TFRunInferenceTest(unittest.TestCase):
         "dummy_model",
         custom_weights="dummy_weights",
         load_model_args=load_model_args)
-    tensorflow_inference._load_model = unittest.mock.MagicMock()
     model_handler.load_model()
     tensorflow_inference._load_model.assert_called_with(
         "dummy_model", "dummy_weights", load_model_args)