Posted to commits@beam.apache.org by da...@apache.org on 2023/02/15 20:15:18 UTC

[beam] branch master updated: [Python] Added Tensorflow Model Handler (#25368)

This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8bf324d7826 [Python] Added Tensorflow Model Handler  (#25368)
8bf324d7826 is described below

commit 8bf324d782651f059f08d575ece309fd9052f3b1
Author: Ritesh Ghorse <ri...@gmail.com>
AuthorDate: Wed Feb 15 15:15:10 2023 -0500

    [Python] Added Tensorflow Model Handler  (#25368)
    
    * go lints
    
    * added tf model handler and tests
    
    * lint and formatting changes
    
    * correct lints
    
    * more lints and formats
    
    * auto formatted with yapf
    
    * rm spare lines
    
    * add readme file
    
    * test requirement file
    
    * add test to gradle
    
    * add test tasks for tf
    
    * unit test
    
    * lints
    
    * updated inferenceFn type
    
    * add tox info for py38
    
    * pylint
    
    * lints
    
    * using tfhub
    
    * added tf model handler and tests
    
    * lint and formatting changes
    
    * correct lints
    
    * more lints and formats
    
    * auto formatted with yapf
    
    * rm spare lines
    
    * merge master
    
    * test requirement file
    
    * add test to gradle
    
    * add test tasks for tf
    
    * unit test
    
    * lints
    
    * updated inferenceFn type
    
    * add tox info for py38
    
    * pylint
    
    * lints
    
    * using tfhub
    
    * tfhub example
    
    * update doc
    
    * sort imports
    
    * resolve pydoc,precommit
    
    * fix import
    
    * fix lint
    
    * address comments
    
    * fix optional inference args
    
    * change to ml bucket
    
    * fix doc
---
 .../apache_beam/examples/inference/README.md       | 107 ++++++++-
 .../inference/tensorflow_imagenet_segmentation.py  | 128 +++++++++++
 .../inference/tensorflow_mnist_classification.py   | 118 ++++++++++
 .../ml/inference/tensorflow_inference.py           | 246 +++++++++++++++++++++
 .../ml/inference/tensorflow_inference_it_test.py   | 114 ++++++++++
 .../ml/inference/tensorflow_inference_test.py      | 146 ++++++++++++
 .../ml/inference/tensorflow_tests_requirements.txt |  21 ++
 sdks/python/pytest.ini                             |   1 +
 sdks/python/test-suites/direct/common.gradle       |  30 ++-
 sdks/python/test-suites/tox/py38/build.gradle      |  12 +
 sdks/python/tox.ini                                |  17 ++
 11 files changed, 938 insertions(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/README.md b/sdks/python/apache_beam/examples/inference/README.md
index 7d71b1d2826..69cd773593b 100644
--- a/sdks/python/apache_beam/examples/inference/README.md
+++ b/sdks/python/apache_beam/examples/inference/README.md
@@ -32,6 +32,15 @@ because the `apache_beam.examples.inference` module was added in that release.
 pip install apache-beam==2.40.0
 ```
 
+### Tensorflow dependencies
+
+The following installation requirement is for the Tensorflow model handler examples.
+
+The RunInference API supports the Tensorflow framework. To use Tensorflow locally, first install `tensorflow`.
+```
+pip install tensorflow==2.11.0
+```
+
 ### PyTorch dependencies
 
 The following installation requirements are for the files used in these examples.
@@ -417,4 +426,100 @@ python -m apache_beam.examples.inference.onnx_sentiment_classification.py \
 This writes the output to the output file path with contents like:
 ```
 A comedy-drama of nearly epic proportions rooted in a sincere performance by the title character undergoing midlife crisis .;1
-```
\ No newline at end of file
+```
+
+---
+## MNIST digit classification with Tensorflow
+[`tensorflow_mnist_classification.py`](./tensorflow_mnist_classification.py) contains an implementation for a RunInference pipeline that performs image classification on handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) database.
+
+The pipeline reads rows of pixels corresponding to a digit, performs basic preprocessing (reshapes the input to 28x28), passes the pixels to the trained Tensorflow model with RunInference, and then writes the predictions to a text file.
+
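+The core of the pipeline looks roughly like the following sketch, condensed from [`tensorflow_mnist_classification.py`](./tensorflow_mnist_classification.py) (the full example uses a `PostProcessor` DoFn for formatting; `INPUT.csv`, `MODEL_DIR`, and `predictions.txt` are placeholder paths):
+```
+import apache_beam as beam
+import numpy
+
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import ModelType
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy
+
+
+def process_input(row: str):
+  # The first comma-separated value is the label; the remaining values are
+  # pixels, reshaped to the 28x28x1 input expected by the trained model.
+  data = row.split(',')
+  label, pixels = int(data[0]), [int(pixel) for pixel in data[1:]]
+  return label, numpy.array(pixels).reshape((28, 28, 1))
+
+
+# Inputs are keyed by their true label, so the numpy model handler is
+# wrapped in a KeyedModelHandler.
+model_handler = KeyedModelHandler(
+    TFModelHandlerNumpy(
+        model_uri='MODEL_DIR', model_type=ModelType.SAVED_MODEL))
+
+with beam.Pipeline() as pipeline:
+  _ = (
+      pipeline
+      | 'ReadFromInput' >> beam.io.ReadFromText('INPUT.csv')
+      | 'PreProcessInputs' >> beam.Map(process_input)
+      | 'RunInference' >> RunInference(model_handler)
+      | 'PostProcessOutputs' >> beam.MapTuple(
+          lambda label, result: '{},{}'.format(
+              label, numpy.argmax(result.inference, axis=0)))
+      | 'WriteOutput' >> beam.io.WriteToText('predictions.txt'))
+```
+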
+### Dataset and model for MNIST digit classification
+
+To use this transform, you need a dataset and a trained model for MNIST digit classification.
+
+1. Create a file named [`INPUT.csv`](gs://apache-beam-ml/testing/inputs/it_mnist_data.csv) that contains labels and pixels to feed into the model. Each row should have comma-separated elements. The first element is the label. All other elements are pixel values. The csv should not have column headers. The content of the file should be similar to the following example:
+```
+1,0,0,0...
+0,0,0,0...
+1,0,0,0...
+4,0,0,0...
+...
+```
+2. Save the trained Tensorflow model to a directory named `MODEL_DIR`.
+
+
+### Running `tensorflow_mnist_classification.py`
+
+To run the MNIST classification pipeline locally, use the following command:
+```sh
+python -m apache_beam.examples.inference.tensorflow_mnist_classification \
+  --input INPUT \
+  --output OUTPUT \
+  --model_path MODEL_DIR
+```
+For example:
+```sh
+python -m apache_beam.examples.inference.tensorflow_mnist_classification \
+  --input INPUT.csv \
+  --output predictions.txt \
+  --model_path MODEL_DIR
+```
+
+This writes the output to `predictions.txt` with contents like:
+```
+1,1
+4,4
+0,0
+7,7
+3,3
+5,5
+...
+```
+Each line contains two comma-separated values: the actual label of the digit followed by the predicted label.
+
+---
+## Image segmentation with Tensorflow and TensorflowHub
+
+[`tensorflow_imagenet_segmentation.py`](./tensorflow_imagenet_segmentation.py) contains an implementation for a RunInference pipeline that performs image segmentation using the [`mobilenet_v2`](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4) architecture from TensorFlow Hub.
+
+The pipeline reads images, performs basic preprocessing, passes the images to RunInference with a Tensorflow model handler, and then writes the predictions to a text file.
+
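+A condensed sketch of the pipeline, abridged from [`tensorflow_imagenet_segmentation.py`](./tensorflow_imagenet_segmentation.py) (the full example additionally maps the predicted index to a human-readable ImageNet label; `IMAGE_FILE_NAMES.txt` and `predictions.txt` are placeholder paths):
+```
+import apache_beam as beam
+import numpy
+import tensorflow as tf
+from PIL import Image
+
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
+
+IMAGE_DIR = 'https://storage.googleapis.com/download.tensorflow.org/example_images/'
+
+
+def read_image(image_name, image_dir):
+  # Download the image, resize it to the 224x224 input expected by the
+  # model, and scale the pixel values to [0, 1].
+  img = tf.keras.utils.get_file(image_name, image_dir + image_name)
+  img = numpy.array(Image.open(img).resize((224, 224))) / 255.0
+  return tf.cast(tf.convert_to_tensor(img), dtype=tf.float32)
+
+
+# A TensorFlow Hub URL can be passed directly as the model_uri; the handler
+# resolves it and loads the saved model.
+model_handler = TFModelHandlerTensor(
+    model_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4')
+
+with beam.Pipeline() as pipeline:
+  _ = (
+      pipeline
+      | 'ReadImageNames' >> beam.io.ReadFromText('IMAGE_FILE_NAMES.txt')
+      | 'PreProcessInputs' >> beam.Map(
+          lambda name: read_image(name, IMAGE_DIR))
+      | 'RunInference' >> RunInference(model_handler)
+      | 'PostProcessOutputs' >> beam.Map(
+          lambda result: str(numpy.argmax(result.inference[0], axis=-1)))
+      | 'WriteOutput' >> beam.io.WriteToText('predictions.txt'))
+```
+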
+### Dataset and model for image segmentation
+
+To use this transform, you need a dataset and model for image segmentation.
+
+1. Create a directory named `IMAGE_DIR`. Create or download images and put them in this directory. We
+will use the [example images](https://storage.googleapis.com/download.tensorflow.org/example_images/) hosted by TensorFlow.
+2. Create a file named `IMAGE_FILE_NAMES.txt` that contains the names of each of the images in `IMAGE_DIR` that you want to use to run image segmentation. For example:
+```
+grace_hopper.jpg
+```
+3. Choose a Tensorflow model to use as `MODEL_PATH`. We will use the [mobilenet_v2](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4) model from TensorFlow Hub.
+4. Note the path to the `OUTPUT` file. This file is used by the pipeline to write the predictions.
+
+### Running `tensorflow_imagenet_segmentation.py`
+
+To run the image segmentation pipeline locally, use the following command:
+```sh
+python -m apache_beam.examples.inference.tensorflow_imagenet_segmentation \
+  --input IMAGE_FILE_NAMES \
+  --image_dir IMAGES_DIR \
+  --output OUTPUT \
+  --model_path MODEL_PATH
+```
+
+For example, if you've followed the naming conventions recommended above:
+```sh
+python -m apache_beam.examples.inference.tensorflow_imagenet_segmentation \
+  --input IMAGE_FILE_NAMES.txt \
+  --image_dir "https://storage.googleapis.com/download.tensorflow.org/example_images/" \
+  --output predictions.txt \
+  --model_path "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
+```
+This writes the output to `predictions.txt` with contents like:
+```
+background
+...
+```
+Each line contains a predicted label.
diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
new file mode 100644
index 00000000000..bfdaefe861e
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import logging
+from typing import Iterable
+from typing import Iterator
+
+import numpy
+
+import apache_beam as beam
+import tensorflow as tf
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+from PIL import Image
+
+
+class PostProcessor(beam.DoFn):
+  """Process the PredictionResult to get the predicted label.
+  Returns predicted label.
+  """
+  def process(self, element: PredictionResult) -> Iterable[str]:
+    predicted_class = numpy.argmax(element.inference[0], axis=-1)
+    labels_path = tf.keras.utils.get_file(
+        'ImageNetLabels.txt',
+        'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'  # pylint: disable=line-too-long
+    )
+    imagenet_labels = numpy.array(open(labels_path).read().splitlines())
+    predicted_class_name = imagenet_labels[predicted_class]
+    return predicted_class_name.title()
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input',
+      dest='input',
+      required=True,
+      help='Path to the text file containing image names.')
+  parser.add_argument(
+      '--output',
+      dest='output',
+      required=True,
+      help='Path to save output predictions.')
+  parser.add_argument(
+      '--model_path',
+      dest='model_path',
+      required=True,
+      help='Path to load the Tensorflow model for Inference.')
+  parser.add_argument(
+      '--image_dir', help='Path to the directory where images are stored.')
+  return parser.parse_known_args(argv)
+
+
+def filter_empty_lines(text: str) -> Iterator[str]:
+  if len(text.strip()) > 0:
+    yield text
+
+
+def read_image(image_name, image_dir):
+  img = tf.keras.utils.get_file(image_name, image_dir + image_name)
+  img = Image.open(img).resize((224, 224))
+  img = numpy.array(img) / 255.0
+  img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
+  return img_tensor
+
+
+def run(
+    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
+  """
+  Args:
+    argv: Command line arguments defined for this example.
+    save_main_session: Used for internal testing.
+    test_pipeline: Used for internal testing.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  # In this example we will use the TensorflowHub model URL.
+  model_loader = TFModelHandlerTensor(model_uri=known_args.model_path)
+
+  pipeline = test_pipeline
+  if not test_pipeline:
+    pipeline = beam.Pipeline(options=pipeline_options)
+
+  image = (
+      pipeline
+      | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
+      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
+      | "PreProcessInputs" >>
+      beam.Map(lambda image_name: read_image(image_name, known_args.image_dir)))
+
+  predictions = (
+      image
+      | "RunInference" >> RunInference(model_loader)
+      | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))
+
+  _ = predictions | "WriteOutput" >> beam.io.WriteToText(
+      known_args.output, shard_name_template='', append_trailing_newlines=False)
+
+  result = pipeline.run()
+  result.wait_until_finish()
+  return result
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py
new file mode 100644
index 00000000000..174d21b26af
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py
@@ -0,0 +1,118 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import logging
+from typing import Iterable
+from typing import Tuple
+
+import numpy
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import ModelType
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+
+
+def process_input(row: str) -> Tuple[int, numpy.ndarray]:
+  data = row.split(',')
+  label, pixels = int(data[0]), data[1:]
+  pixels = [int(pixel) for pixel in pixels]
+  # the trained model accepts the input of shape 28x28
+  pixels = numpy.array(pixels).reshape((28, 28, 1))
+  return label, pixels
+
+
+class PostProcessor(beam.DoFn):
+  """Process the PredictionResult to get the predicted label.
+  Returns a comma separated string with true label and predicted label.
+  """
+  def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
+    label, prediction_result = element
+    prediction = numpy.argmax(prediction_result.inference, axis=0)
+    yield '{},{}'.format(label, prediction)
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input',
+      dest='input',
+      required=True,
+      help='text file with comma separated int values.')
+  parser.add_argument(
+      '--output',
+      dest='output',
+      required=True,
+      help='Path to save output predictions.')
+  parser.add_argument(
+      '--model_path',
+      dest='model_path',
+      required=True,
+      help='Path to load the Tensorflow model for Inference.')
+  return parser.parse_known_args(argv)
+
+
+def run(
+    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
+  """
+  Args:
+    argv: Command line arguments defined for this example.
+    save_main_session: Used for internal testing.
+    test_pipeline: Used for internal testing.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  # In this example we pass keyed inputs to RunInference transform.
+  # Therefore, we use KeyedModelHandler wrapper over TFModelHandlerNumpy.
+  model_loader = KeyedModelHandler(
+      TFModelHandlerNumpy(
+          model_uri=known_args.model_path, model_type=ModelType.SAVED_MODEL))
+
+  pipeline = test_pipeline
+  if not test_pipeline:
+    pipeline = beam.Pipeline(options=pipeline_options)
+
+  label_pixel_tuple = (
+      pipeline
+      | "ReadFromInput" >> beam.io.ReadFromText(known_args.input)
+      | "PreProcessInputs" >> beam.Map(process_input))
+
+  predictions = (
+      label_pixel_tuple
+      | "RunInference" >> RunInference(model_loader)
+      | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))
+
+  _ = predictions | "WriteOutput" >> beam.io.WriteToText(
+      known_args.output, shard_name_template='', append_trailing_newlines=True)
+
+  result = pipeline.run()
+  result.wait_until_finish()
+  return result
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
new file mode 100644
index 00000000000..ee33c53cadb
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -0,0 +1,246 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import enum
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+
+import tensorflow as tf
+import tensorflow_hub as hub
+from apache_beam.ml.inference import utils
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+
+__all__ = [
+    'TFModelHandlerNumpy',
+    'TFModelHandlerTensor',
+]
+
+TensorInferenceFn = Callable[[
+    tf.Module,
+    Sequence[Union[numpy.ndarray, tf.Tensor]],
+    Dict[str, Any],
+    Optional[str]
+],
+                             Iterable[PredictionResult]]
+
+
+class ModelType(enum.Enum):
+  """Defines how a model file should be loaded."""
+  SAVED_MODEL = 1
+
+
+def _load_model(model_uri, model_type):
+  if model_type == ModelType.SAVED_MODEL:
+    return tf.keras.models.load_model(hub.resolve(model_uri))
+  else:
+    raise AssertionError('Unsupported model type for loading.')
+
+
+def default_numpy_inference_fn(
+    model: tf.Module,
+    batch: Sequence[numpy.ndarray],
+    inference_args: Dict[str, Any],
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  vectorized_batch = numpy.stack(batch, axis=0)
+  predictions = model(vectorized_batch, **inference_args)
+  return utils._convert_to_result(batch, predictions, model_id)
+
+
+def default_tensor_inference_fn(
+    model: tf.Module,
+    batch: Sequence[tf.Tensor],
+    inference_args: Dict[str, Any],
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  vectorized_batch = tf.stack(batch, axis=0)
+  predictions = model(vectorized_batch, **inference_args)
+  return utils._convert_to_result(batch, predictions, model_id)
+
+
+class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
+                                       PredictionResult,
+                                       tf.Module]):
+  def __init__(
+      self,
+      model_uri: str,
+      model_type: ModelType = ModelType.SAVED_MODEL,
+      *,
+      inference_fn: TensorInferenceFn = default_numpy_inference_fn):
+    """Implementation of the ModelHandler interface for Tensorflow.
+
+    Example Usage::
+
+      pcoll | RunInference(TFModelHandlerNumpy(model_uri="my_uri"))
+
+    See https://www.tensorflow.org/tutorials/keras/save_and_load for details.
+
+    Args:
+        model_uri (str): path to the trained model.
+        model_type: type of model to be loaded. Defaults to SAVED_MODEL.
+        inference_fn: inference function to use during RunInference.
+          Defaults to default_numpy_inference_fn.
+
+    **Supported Versions:** RunInference APIs in Apache Beam have been tested
+    with Tensorflow 2.9, 2.10, 2.11.
+    """
+    self._model_uri = model_uri
+    self._model_type = model_type
+    self._inference_fn = inference_fn
+
+  def load_model(self) -> tf.Module:
+    """Loads and initializes a Tensorflow model for processing."""
+    return _load_model(self._model_uri, self._model_type)
+
+  def update_model_path(self, model_path: Optional[str] = None):
+    self._model_uri = model_path if model_path else self._model_uri
+
+  def run_inference(
+      self,
+      batch: Sequence[numpy.ndarray],
+      model: tf.Module,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """
+    Runs inferences on a batch of numpy array and returns an Iterable of
+    numpy array Predictions.
+
+    This method stacks the n-dimensional numpy array in a vectorized format to
+    optimize the inference call.
+
+    Args:
+      batch: A sequence of numpy nd-array. These should be batchable, as this
+        method will call `numpy.stack()` and pass in batched numpy nd-array
+        with dimensions (batch_size, n_features, etc.) into the model's
+        predict() function.
+      model: A Tensorflow model.
+      inference_args: any additional arguments for an inference.
+
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    inference_args = {} if not inference_args else inference_args
+    return self._inference_fn(model, batch, inference_args, self._model_uri)
+
+  def get_num_bytes(self, batch: Sequence[numpy.ndarray]) -> int:
+    """
+    Returns:
+      The number of bytes of data for a batch of numpy arrays.
+    """
+    return sum(sys.getsizeof(element) for element in batch)
+
+  def get_metrics_namespace(self) -> str:
+    """
+    Returns:
+       A namespace for metrics collected by the RunInference transform.
+    """
+    return 'BeamML_TF_Numpy'
+
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    pass
+
+
+class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
+                                        tf.Module]):
+  def __init__(
+      self,
+      model_uri: str,
+      model_type: ModelType = ModelType.SAVED_MODEL,
+      *,
+      inference_fn: TensorInferenceFn = default_tensor_inference_fn):
+    """Implementation of the ModelHandler interface for Tensorflow.
+
+    Example Usage::
+
+      pcoll | RunInference(TFModelHandlerTensor(model_uri="my_uri"))
+
+    See https://www.tensorflow.org/tutorials/keras/save_and_load for details.
+
+    Args:
+        model_uri (str): path to the trained model.
+        model_type: type of model to be loaded.
+          Defaults to SAVED_MODEL.
+        inference_fn: inference function to use during RunInference.
+          Defaults to default_numpy_inference_fn.
+
+    **Supported Versions:** RunInference APIs in Apache Beam have been tested
+    with Tensorflow 2.11.
+    """
+    self._model_uri = model_uri
+    self._model_type = model_type
+    self._inference_fn = inference_fn
+
+  def load_model(self) -> tf.Module:
+    """Loads and initializes a tensorflow model for processing."""
+    return _load_model(self._model_uri, self._model_type)
+
+  def update_model_path(self, model_path: Optional[str] = None):
+    self._model_uri = model_path if model_path else self._model_uri
+
+  def run_inference(
+      self,
+      batch: Sequence[tf.Tensor],
+      model: tf.Module,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """
+    Runs inferences on a batch of tf.Tensor and returns an Iterable of
+    Tensor Predictions.
+
+    This method stacks the list of Tensors in a vectorized format to optimize
+    the inference call.
+
+    Args:
+      batch: A sequence of Tensors. These Tensors should be batchable, as this
+        method will call `tf.stack()` and pass in batched Tensors with
+        dimensions (batch_size, n_features, etc.) into the model's predict()
+        function.
+      model: A Tensorflow model.
+      inference_args: Non-batchable arguments required as inputs to the model's
+        forward() function. Unlike Tensors in `batch`, these parameters will
+        not be dynamically batched
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    inference_args = {} if not inference_args else inference_args
+    return self._inference_fn(model, batch, inference_args, self._model_uri)
+
+  def get_num_bytes(self, batch: Sequence[tf.Tensor]) -> int:
+    """
+    Returns:
+      The number of bytes of data for a batch of Tensors.
+    """
+    return sum(sys.getsizeof(element) for element in batch)
+
+  def get_metrics_namespace(self) -> str:
+    """
+    Returns:
+       A namespace for metrics collected by the RunInference transform.
+    """
+    return 'BeamML_TF_Tensor'
+
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    pass
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
new file mode 100644
index 00000000000..7b4b13ce2e1
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""End-to-End test for Tensorflow Inference"""
+
+import logging
+import unittest
+import uuid
+
+import pytest
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=ungrouped-imports
+try:
+  import tensorflow as tf
+  from apache_beam.examples.inference import tensorflow_imagenet_segmentation
+  from apache_beam.examples.inference import tensorflow_mnist_classification
+except ImportError as e:
+  tf = None
+
+
+def process_outputs(filepath):
+  with FileSystems().open(filepath) as f:
+    lines = f.readlines()
+  lines = [l.decode('utf-8').strip('\n') for l in lines]
+  return lines
+
+
+@unittest.skipIf(
+    tf is None, 'Missing dependencies. '
+    'Test depends on tensorflow')
+@pytest.mark.uses_tf
+@pytest.mark.it_postcommit
+class TensorflowInference(unittest.TestCase):
+  def test_tf_mnist_classification(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    input_file = 'gs://apache-beam-ml/testing/inputs/it_mnist_data.csv'
+    output_file_dir = 'apache-beam-ml/testing/outputs'
+    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
+    model_path = 'apache-beam-ml/models/tensorflow/mnist/'
+    extra_opts = {
+        'input': input_file,
+        'output': output_file,
+        'model_path': model_path,
+    }
+    tensorflow_mnist_classification.run(
+        test_pipeline.get_full_options_as_args(**extra_opts),
+        save_main_session=False)
+    self.assertEqual(FileSystems().exists(output_file), True)
+
+    expected_output_filepath = 'gs://apache-beam-ml/testing/expected_outputs/test_sklearn_mnist_classification_actuals.txt'  # pylint: disable=line-too-long
+    expected_outputs = process_outputs(expected_output_filepath)
+    predicted_outputs = process_outputs(output_file)
+    self.assertEqual(len(expected_outputs), len(predicted_outputs))
+
+    predictions_dict = {}
+    for i in range(len(predicted_outputs)):
+      true_label, prediction = predicted_outputs[i].split(',')
+      predictions_dict[true_label] = prediction
+
+    for i in range(len(expected_outputs)):
+      true_label, expected_prediction = expected_outputs[i].split(',')
+      self.assertEqual(predictions_dict[true_label], expected_prediction)
+
+  def test_tf_imagenet_image_segmentation(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    input_file = (
+        'gs://apache-beam-ml/testing/inputs/it_imagenet_input_labels.txt')
+    image_dir = (
+        'https://storage.googleapis.com/download.tensorflow.org/example_images/'
+    )
+    output_file_dir = 'apache-beam-ml/testing/outputs'
+    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
+    model_path = (
+        'https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4')
+    extra_opts = {
+        'input': input_file,
+        'output': output_file,
+        'model_path': model_path,
+        'image_dir': image_dir
+    }
+    tensorflow_imagenet_segmentation.run(
+        test_pipeline.get_full_options_as_args(**extra_opts),
+        save_main_session=False)
+    self.assertEqual(FileSystems().exists(output_file), True)
+
+    expected_output_filepath = 'gs://apache-beam-ml/testing/expected_outputs/test_tf_imagenet_image_segmentation.txt'  # pylint: disable=line-too-long
+    expected_outputs = process_outputs(expected_output_filepath)
+    predicted_outputs = process_outputs(output_file)
+    self.assertEqual(len(expected_outputs), len(predicted_outputs))
+
+    for true_label, predicted_label in zip(expected_outputs, predicted_outputs):
+      self.assertEqual(true_label, predicted_label)
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.DEBUG)
+  unittest.main()
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
new file mode 100644
index 00000000000..842de7fe611
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import unittest
+
+import numpy
+import pytest
+
+try:
+  import tensorflow as tf
+  from apache_beam.ml.inference.sklearn_inference_test import _compare_prediction_result
+  from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult
+  from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy, TFModelHandlerTensor
+except ImportError:
+  raise unittest.SkipTest('Tensorflow dependencies are not installed')
+
+
+class FakeTFNumpyModel:
+  def predict(self, input: numpy.ndarray):
+    return numpy.multiply(input, 10)
+
+
+class FakeTFTensorModel:
+  def predict(self, input: tf.Tensor, add=False):
+    if add:
+      return tf.math.add(tf.math.multiply(input, 10), 10)
+    return tf.math.multiply(input, 10)
+
+
+def _compare_tensor_prediction_result(x, y):
+  return tf.math.equal(x.inference, y.inference)
+
+
+@pytest.mark.uses_tf
+class TFRunInferenceTest(unittest.TestCase):
+  def test_predict_numpy(self):
+    fake_model = FakeTFNumpyModel()
+    inference_runner = TFModelHandlerNumpy(model_uri='unused')
+    batched_examples = [numpy.array([1]), numpy.array([10]), numpy.array([100])]
+    expected_predictions = [
+        PredictionResult(numpy.array([1]), 10),
+        PredictionResult(numpy.array([10]), 100),
+        PredictionResult(numpy.array([100]), 1000)
+    ]
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_prediction_result(actual, expected))
+
+  def test_predict_tensor(self):
+    fake_model = FakeTFTensorModel()
+    inference_runner = TFModelHandlerTensor(model_uri='unused')
+    batched_examples = [
+        tf.convert_to_tensor(numpy.array([1])),
+        tf.convert_to_tensor(numpy.array([10])),
+        tf.convert_to_tensor(numpy.array([100])),
+    ]
+    expected_predictions = [
+        PredictionResult(ex, pred) for ex,
+        pred in zip(
+            batched_examples,
+            [tf.math.multiply(n, 10) for n in batched_examples])
+    ]
+
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_tensor_prediction_result(actual, expected))
+
+  def test_predict_tensor_with_args(self):
+    fake_model = FakeTFTensorModel()
+    inference_runner = TFModelHandlerTensor(model_uri='unused')
+    batched_examples = [
+        tf.convert_to_tensor(numpy.array([1])),
+        tf.convert_to_tensor(numpy.array([10])),
+        tf.convert_to_tensor(numpy.array([100])),
+    ]
+    expected_predictions = [
+        PredictionResult(ex, pred) for ex,
+        pred in zip(
+            batched_examples, [
+                tf.math.add(tf.math.multiply(n, 10), 10)
+                for n in batched_examples
+            ])
+    ]
+
+    inferences = inference_runner.run_inference(
+        batched_examples, fake_model, inference_args={"add": True})
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_tensor_prediction_result(actual, expected))
+
+  def test_predict_keyed_numpy(self):
+    fake_model = FakeTFNumpyModel()
+    inference_runner = KeyedModelHandler(
+        TFModelHandlerNumpy(model_uri='unused'))
+    batched_examples = [
+        ('k1', numpy.array([1], dtype=numpy.int64)),
+        ('k2', numpy.array([10], dtype=numpy.int64)),
+        ('k3', numpy.array([100], dtype=numpy.int64)),
+    ]
+    expected_predictions = [
+        (ex[0], PredictionResult(ex[1], pred)) for ex,
+        pred in zip(
+            batched_examples,
+            [numpy.multiply(n[1], 10) for n in batched_examples])
+    ]
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_prediction_result(actual[1], expected[1]))
+
+  def test_predict_keyed_tensor(self):
+    fake_model = FakeTFTensorModel()
+    inference_runner = KeyedModelHandler(
+        TFModelHandlerTensor(model_uri='unused'))
+    batched_examples = [
+        ('k1', tf.convert_to_tensor(numpy.array([1]))),
+        ('k2', tf.convert_to_tensor(numpy.array([10]))),
+        ('k3', tf.convert_to_tensor(numpy.array([100]))),
+    ]
+    expected_predictions = [
+        (ex[0], PredictionResult(ex[1], pred)) for ex,
+        pred in zip(
+            batched_examples,
+            [tf.math.multiply(n[1], 10) for n in batched_examples])
+    ]
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_tensor_prediction_result(actual[1], expected[1]))
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/tensorflow_tests_requirements.txt
new file mode 100644
index 00000000000..8a9deba61dd
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_tests_requirements.txt
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+tensorflow>=1.0.0
+tensorflow_hub>=0.10.0
+Pillow>=9.0.0
+
diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini
index 560ef2f187a..2733b2511d0 100644
--- a/sdks/python/pytest.ini
+++ b/sdks/python/pytest.ini
@@ -53,6 +53,7 @@ markers =
     uses_tensorflow: tests that utilize tensorflow in some way
     uses_tft: tests that utilizes tensorflow transforms in some way.
     uses_onnx: tests that utilizes onnx in some way.
+    uses_tf: tests that utilize tensorflow
 
 # Default timeout intended for unit tests.
 # If certain tests need a different value, please see the docs on how to
diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle
index 9281355ad65..80e0bf052e5 100644
--- a/sdks/python/test-suites/direct/common.gradle
+++ b/sdks/python/test-suites/direct/common.gradle
@@ -281,11 +281,39 @@ task tfxInferenceTest {
     }
 }
 
+// TensorFlow RunInference IT tests
+task tensorflowInferenceTest {
+  dependsOn 'installGcpTest'
+  dependsOn ':sdks:python:sdist'
+  def requirementsFile = "${rootDir}/sdks/python/apache_beam/ml/inference/tensorflow_tests_requirements.txt"
+  doFirst {
+      exec {
+        executable 'sh'
+        args '-c', ". ${envdir}/bin/activate && pip install -r $requirementsFile"
+      }
+    }
+  doLast {
+      def testOpts = basicTestOpts
+      def argMap = [
+          "test_opts": testOpts,
+          "suite": "postCommitIT-direct-py${pythonVersionSuffix}",
+          "collect": "uses_tf and it_postcommit" ,
+          "runner": "TestDirectRunner"
+      ]
+      def cmdArgs = mapToArgString(argMap)
+      exec {
+        executable 'sh'
+        args '-c', ". ${envdir}/bin/activate && ${runScriptsDir}/run_integration_test.sh $cmdArgs"
+      }
+    }
+}
+
 // Add all the RunInference framework IT tests to this gradle task that runs on Direct Runner Post commit suite.
 project.tasks.register("inferencePostCommitIT") {
   dependsOn = [
   'torchInferenceTest',
   'sklearnInferenceTest',
-  'tfxInferenceTest'
+  'tfxInferenceTest',
+  'tensorflowInferenceTest'
   ]
 }
diff --git a/sdks/python/test-suites/tox/py38/build.gradle b/sdks/python/test-suites/tox/py38/build.gradle
index ea803faabc5..7d582bd89c1 100644
--- a/sdks/python/test-suites/tox/py38/build.gradle
+++ b/sdks/python/test-suites/tox/py38/build.gradle
@@ -106,6 +106,18 @@ preCommitPyCoverage.dependsOn "testPy38pytorch-113"
 toxTask "testPy38onnx-113", "py38-onnx-113", "${posargs}"
 test.dependsOn "testPy38onnx-113"
 preCommitPyCoverage.dependsOn "testPy38onnx-113"
+// Create a test task for each minor version of tensorflow
+toxTask "testPy38tensorflow-29", "py38-tensorflow-29", "${posargs}"
+test.dependsOn "testPy38tensorflow-29"
+preCommitPyCoverage.dependsOn "testPy38tensorflow-29"
+
+toxTask "testPy38tensorflow-210", "py38-tensorflow-210", "${posargs}"
+test.dependsOn "testPy38tensorflow-210"
+preCommitPyCoverage.dependsOn "testPy38tensorflow-210"
+
+toxTask "testPy38tensorflow-211", "py38-tensorflow-211", "${posargs}"
+test.dependsOn "testPy38tensorflow-211"
+preCommitPyCoverage.dependsOn "testPy38tensorflow-211"
 
 toxTask "whitespacelint", "whitespacelint", "${posargs}"
 
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index 5504430823f..5b7e10bf12a 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -153,6 +153,8 @@ deps =
   torch
   onnxruntime
   onnx
+  tensorflow
+  tensorflow_hub
 commands =
   time {toxinidir}/scripts/generate_pydoc.sh
 
@@ -341,3 +343,18 @@ commands =
   /bin/sh -c "pip freeze | grep -E onnx"
   # Run all ONNX unit tests
   pytest -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 -m uses_onnx {posargs}
+  
+[testenv:py{37,38,39,310}-tf-{211}]
+[testenv:py{37,38,39,310}-tensorflow-{29,210,211}]
+deps =
+  -r build-requirements.txt
+  29: tensorflow>=2.9.0,<2.10.0
+  210: tensorflow>=2.10.0,<2.11.0
+  211: tensorflow>=2.11.0,<2.12.0
+extras = test,gcp
+commands =
+  # Log tensorflow version for debugging
+  /bin/sh -c "pip freeze | grep -E tensorflow"
+  # Run all Tensorflow unit tests
+  # Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories.
+  /bin/sh -c 'pytest -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 -m uses_tf {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret'