Posted to commits@beam.apache.org by ri...@apache.org on 2023/07/24 19:28:41 UTC

[beam] branch master updated: [Python] Implemented Hugging Face Model Handler (#26632)

This is an automated email from the ASF dual-hosted git repository.

riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 193b72057ea [Python] Implemented Hugging Face Model Handler (#26632)
193b72057ea is described below

commit 193b72057ea12ce3c7e1c4b8eb5046c871a78750
Author: Ritesh Ghorse <ri...@gmail.com>
AuthorDate: Mon Jul 24 15:28:32 2023 -0400

    [Python] Implemented Hugging Face Model Handler (#26632)
    
    * automodel first pass
    
    * new model
    
    * updated model handler api
    
    * add model_class param
    
    * update doc comments
    
    * updated integration test and example
    
    * unit test, modified params
    
    * add test setup for hugging face tests
    
    * fix lints
    
    * fix import order
    
    * refactor, doc, lints
    
    * refactor, doc comments
    
    * change test file
    
    * update types
    
    * update tox, doc, lints
    
    * fix lints
    
    * pr type
    
    * update gpu warnings
    
    * fix pydoc
    
    * update typos, refactor
    
    * fix docstrings
    
    * refactor, doc, lints
    
    * pydoc
    
    * fix pydoc
    
    * updates to keyed model handler
    
    * pylints
---
 .../apache_beam/examples/inference/README.md       |  73 +++
 .../inference/huggingface_language_modeling.py     | 177 ++++++++
 .../ml/inference/huggingface_inference.py          | 488 +++++++++++++++++++++
 .../ml/inference/huggingface_inference_it_test.py  |  80 ++++
 .../ml/inference/huggingface_inference_test.py     | 136 ++++++
 .../inference/huggingface_tests_requirements.txt   |  20 +
 sdks/python/pytest.ini                             |  13 +-
 sdks/python/test-suites/direct/common.gradle       |  30 +-
 sdks/python/test-suites/tox/py38/build.gradle      |  13 +
 sdks/python/tox.ini                                |  19 +
 10 files changed, 1042 insertions(+), 7 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/README.md b/sdks/python/apache_beam/examples/inference/README.md
index 326ec4b4a09..1653f3a9699 100644
--- a/sdks/python/apache_beam/examples/inference/README.md
+++ b/sdks/python/apache_beam/examples/inference/README.md
@@ -42,6 +42,7 @@ The RunInference API supports the Tensorflow framework. To use Tensorflow locall
 pip install tensorflow==2.12.0
 ```
 
+
 ### PyTorch dependencies
 
 The following installation requirements are for the files used in these examples.
@@ -65,6 +66,21 @@ For installation of the `torch` dependency on a distributed runner such as Dataf
 [PyPI dependency instructions](https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/#pypi-dependencies).
 
 
+### Transformers dependencies
+
+The following installation requirement is for the Hugging Face model handler examples.
+
+The RunInference API supports loading models from the Hugging Face Hub. To use it, first install `transformers`.
+```
+pip install transformers==4.30.0
+```
+Additional dependencies for PyTorch and TensorFlow may need to be installed separately:
+```
+pip install tensorflow==2.12.0
+pip install torch==1.10.0
+```
+
+
 ### TensorRT dependencies
 
 The RunInference API supports TensorRT SDK for high-performance deep learning inference with NVIDIA GPUs.
@@ -687,3 +703,60 @@ MilkQualityAggregation(bad_quality_measurements=6, medium_quality_measurements=4
 MilkQualityAggregation(bad_quality_measurements=3, medium_quality_measurements=3, high_quality_measurements=3)
 MilkQualityAggregation(bad_quality_measurements=1, medium_quality_measurements=2, high_quality_measurements=1)
 ```
+
+---
+## Language modeling with Hugging Face Hub
+
+[`huggingface_language_modeling.py`](./huggingface_language_modeling.py) contains an implementation for a RunInference pipeline that performs masked language modeling (that is, decoding a masked token in a sentence) using the `AutoModelForMaskedLM` architecture from Hugging Face.
+
+The pipeline reads sentences, performs basic preprocessing to convert the last word into a `<mask>` token, passes the masked sentence to the Hugging Face implementation of RunInference, and then writes the predictions to a text file.
+
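+Under the hood, the example wires a `HuggingFaceModelHandlerKeyedTensor` into `RunInference`. The following is a minimal sketch of that wiring; it mirrors the example file and assumes the `transformers` and `torch` installs described above:
+```
+import apache_beam as beam
+import torch
+from apache_beam.ml.inference.base import KeyedModelHandler, RunInference
+from apache_beam.ml.inference.huggingface_inference import HuggingFaceModelHandlerKeyedTensor
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+
+model_name = "stevhliu/my_awesome_eli5_mlm_model"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model_handler = HuggingFaceModelHandlerKeyedTensor(
+    model_uri=model_name,
+    model_class=AutoModelForMaskedLM,
+    framework='pt',
+    max_batch_size=1)
+
+def tokenize(sentence):
+  # Key each example by the original sentence and drop the batch dimension.
+  tokens = tokenizer.encode_plus(sentence, return_tensors="pt")
+  return sentence, {k: torch.squeeze(v) for k, v in dict(tokens).items()}
+
+with beam.Pipeline() as p:
+  _ = (
+      p
+      | beam.Create(["The capital of France is <mask> ."])
+      | beam.Map(tokenize)
+      | RunInference(KeyedModelHandler(model_handler)))
+```
+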
+### Dataset and model for language modeling
+
+To use this transform, you need a dataset and model for language modeling.
+
+1. Choose a checkpoint to load from Hugging Face Hub, for example [MaskedLanguageModel](https://huggingface.co/stevhliu/my_awesome_eli5_mlm_model).
+2. (Optional) Create a file named `SENTENCES.txt` that contains sentences to feed into the model. The content of the file should be similar to the following example:
+```
+The capital of France is Paris .
+He looked up and saw the sun and stars .
+...
+```
+
+### Running `huggingface_language_modeling.py`
+
+To run the language modeling pipeline locally, use the following command:
+```sh
+python -m apache_beam.examples.inference.huggingface_language_modeling \
+  --input SENTENCES \
+  --output OUTPUT \
+  --model_name REPOSITORY_ID
+```
+The `input` argument is optional. If it is not provided, the pipeline runs with a
+default set of example sentences.
+
+For example, if you've followed the naming conventions recommended above:
+```sh
+python -m apache_beam.examples.inference.huggingface_language_modeling \
+  --input SENTENCES.txt \
+  --output predictions.csv \
+  --model_name "stevhliu/my_awesome_eli5_mlm_model"
+```
+Or, using the default example sentences:
+```sh
+python -m apache_beam.examples.inference.huggingface_language_modeling \
+  --output predictions.csv \
+  --model_name "stevhliu/my_awesome_eli5_mlm_model"
+```
+
+This writes the output to `predictions.csv` with contents like:
+```
+The capital of France is Paris .;paris
+He looked up and saw the sun and stars .;moon
+...
+```
+Each line contains two items separated by a semicolon ";".
+The first item is the input sentence; the pipeline masks its last word and the model
+tries to predict it. The second item is the word that the model predicts for the mask.
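+
+To post-process the results, each line can be split on the semicolon. A minimal sketch, assuming the sentences themselves contain no semicolons:
+```
+with open("predictions.csv") as f:
+  for line in f:
+    sentence, predicted_word = line.strip().split(';')
+    print(sentence, "->", predicted_word)
+```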
+
+---
\ No newline at end of file
diff --git a/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py b/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py
new file mode 100644
index 00000000000..f6cb3de72b7
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py
@@ -0,0 +1,177 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A pipeline that uses RunInference to perform Language Modeling with
+a masked language model from Hugging Face.
+
+This pipeline takes sentences from a custom text file, converts the last word
+of the sentence into a <mask> token, and then uses the AutoModelForMaskedLM from
+Hugging Face to predict the best word for the masked token given all the words
+already in the sentence. The pipeline then writes the prediction to an output
+file, which users can then compare against the original sentence.
+"""
+
+import argparse
+import logging
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import Tuple
+
+import apache_beam as beam
+import torch
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.huggingface_inference import HuggingFaceModelHandlerKeyedTensor
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+from transformers import AutoModelForMaskedLM
+from transformers import AutoTokenizer
+
+
+def add_mask_to_last_word(text: str) -> Tuple[str, str]:
+  text_list = text.split()
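+  # The example sentences end with " .", so the last token is the period;
+  # mask the second-to-last token (the final word) and keep the period.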
+  return text, ' '.join(text_list[:-2] + ['<mask>', text_list[-1]])
+
+
+def tokenize_sentence(
+    text_and_mask: Tuple[str, str],
+    tokenizer: AutoTokenizer) -> Tuple[str, Dict[str, torch.Tensor]]:
+  text, masked_text = text_and_mask
+  tokenized_sentence = tokenizer.encode_plus(masked_text, return_tensors="pt")
+
+  # Workaround to manually remove batch dim until we have the feature to
+  # add optional batching flag.
+  # TODO(https://github.com/apache/beam/issues/21863): Remove once optional
+  # batching flag added
+  return text, {
+      k: torch.squeeze(v)
+      for k, v in dict(tokenized_sentence).items()
+  }
+
+
+def filter_empty_lines(text: str) -> Iterator[str]:
+  if len(text.strip()) > 0:
+    yield text
+
+
+class PostProcessor(beam.DoFn):
+  """Processes the PredictionResult to get the predicted word.
+
+  The logits are the output of the Model. We can get the word with the highest
+  probability of being a candidate replacement word by taking the argmax.
+  """
+  def __init__(self, tokenizer: AutoTokenizer):
+    super().__init__()
+    self.tokenizer = tokenizer
+
+  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
+    text, prediction_result = element
+    inputs = prediction_result.example
+    logits = prediction_result.inference['logits']
+    mask_token_index = torch.where(
+        inputs["input_ids"] == self.tokenizer.mask_token_id)[0]
+    predicted_token_id = logits[mask_token_index].argmax(axis=-1)
+    decoded_word = self.tokenizer.decode(predicted_token_id)
+    yield text + ';' + decoded_word
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input',
+      dest='input',
+      help='Path to the text file containing sentences.')
+  parser.add_argument(
+      '--output',
+      dest='output',
+      required=True,
+      help='Path of file in which to save the output predictions.')
+  parser.add_argument(
+      '--model_name',
+      dest='model_name',
+      required=True,
+      help='Repository id of the masked language model on Hugging Face Hub.')
+  parser.add_argument(
+      '--model_class',
+      dest='model_class',
+      default=AutoModelForMaskedLM,
+      help="Model class used to load the model, e.g. AutoModelForMaskedLM.")
+  return parser.parse_known_args(argv)
+
+
+def run(
+    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
+  """
+  Args:
+    argv: Command line arguments defined for this example.
+    save_main_session: Used for internal testing.
+    test_pipeline: Used for internal testing.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  pipeline = test_pipeline
+  if not test_pipeline:
+    pipeline = beam.Pipeline(options=pipeline_options)
+
+  tokenizer = AutoTokenizer.from_pretrained(known_args.model_name)
+
+  model_handler = HuggingFaceModelHandlerKeyedTensor(
+      model_uri=known_args.model_name,
+      model_class=known_args.model_class,
+      framework='pt',
+      max_batch_size=1)
+  if not known_args.input:
+    text = (
+        pipeline | 'CreateSentences' >> beam.Create([
+            'The capital of France is Paris .',
+            'It is raining cats and dogs .',
+            'Today is Monday and tomorrow is Tuesday .',
+            'There are 5 coconuts on this palm tree .',
+            'The strongest person in the world is not famous .',
+            'The secret ingredient to his wonderful life was gratitude .',
+            'The biggest animal in the world is the whale .',
+        ]))
+  else:
+    text = (
+        pipeline | 'ReadSentences' >> beam.io.ReadFromText(known_args.input))
+  text_and_tokenized_text_tuple = (
+      text
+      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
+      | 'AddMask' >> beam.Map(add_mask_to_last_word)
+      |
+      'TokenizeSentence' >> beam.Map(lambda x: tokenize_sentence(x, tokenizer)))
+  output = (
+      text_and_tokenized_text_tuple
+      | 'RunInference' >> RunInference(KeyedModelHandler(model_handler))
+      | 'ProcessOutput' >> beam.ParDo(PostProcessor(tokenizer=tokenizer)))
+  _ = output | "WriteOutput" >> beam.io.WriteToText(
+      known_args.output, shard_name_template='', append_trailing_newlines=True)
+
+  result = pipeline.run()
+  result.wait_until_finish()
+  return result
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
new file mode 100644
index 00000000000..35c3a1686c7
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
@@ -0,0 +1,488 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import logging
+import sys
+from collections import defaultdict
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import tensorflow as tf
+import torch
+from apache_beam.ml.inference import utils
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.pytorch_inference import _convert_to_device
+from transformers import AutoModel
+from transformers import TFAutoModel
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = [
+    "HuggingFaceModelHandlerTensor",
+    "HuggingFaceModelHandlerKeyedTensor",
+]
+
+TensorInferenceFn = Callable[[
+    Sequence[Union[torch.Tensor, tf.Tensor]],
+    Union[AutoModel, TFAutoModel],
+    str,
+    Optional[Dict[str, Any]],
+    Optional[str],
+],
+                             Iterable[PredictionResult],
+                             ]
+
+KeyedTensorInferenceFn = Callable[[
+    Sequence[Dict[str, Union[torch.Tensor, tf.Tensor]]],
+    Union[AutoModel, TFAutoModel],
+    str,
+    Optional[Dict[str, Any]],
+    Optional[str],
+],
+                                  Iterable[PredictionResult],
+                                  ]
+
+
+def _validate_constructor_args(model_uri, model_class):
+  message = (
+      "Please provide both model class and model uri to load the model. "
+      "Got params as model_uri={model_uri} and "
+      "model_class={model_class}.")
+  if not model_uri or not model_class:
+    raise RuntimeError(
+        message.format(model_uri=model_uri, model_class=model_class))
+
+
+def no_gpu_available_warning():
+  _LOGGER.warning(
+      "HuggingFaceModelHandler specified a 'GPU' device, "
+      "but GPUs are not available. Switching to CPU.")
+
+
+def is_gpu_available_torch():
+  if torch.cuda.is_available():
+    return True
+  else:
+    no_gpu_available_warning()
+    return False
+
+
+def get_device_torch(device):
+  if device == "GPU" and is_gpu_available_torch():
+    return torch.device("cuda")
+  return torch.device("cpu")
+
+
+def is_gpu_available_tensorflow(device):
+  gpu_devices = tf.config.list_physical_devices(device)
+  if len(gpu_devices) == 0:
+    no_gpu_available_warning()
+    return False
+  return True
+
+
+def _run_inference_torch_keyed_tensor(
+    batch: Sequence[Dict[str, torch.Tensor]],
+    model: AutoModel,
+    device,
+    inference_args: Dict[str, Any],
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  device = get_device_torch(device)
+  key_to_tensor_list = defaultdict(list)
+  # torch.no_grad() mitigates GPU memory issues
+  # https://github.com/apache/beam/issues/22811
+  with torch.no_grad():
+    for example in batch:
+      for key, tensor in example.items():
+        key_to_tensor_list[key].append(tensor)
+    key_to_batched_tensors = {}
+    for key in key_to_tensor_list:
+      batched_tensors = torch.stack(key_to_tensor_list[key])
+      batched_tensors = _convert_to_device(batched_tensors, device)
+      key_to_batched_tensors[key] = batched_tensors
+    predictions = model(**key_to_batched_tensors, **inference_args)
+    return utils._convert_to_result(batch, predictions, model_id)
+
+
+def _run_inference_tensorflow_keyed_tensor(
+    batch: Sequence[Dict[str, tf.Tensor]],
+    model: TFAutoModel,
+    device,
+    inference_args: Dict[str, Any],
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
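+  # If a GPU was requested but none is visible, this check only logs a
+  # warning; TensorFlow falls back to running the model on CPU.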
+  if device == "GPU":
+    is_gpu_available_tensorflow(device)
+  key_to_tensor_list = defaultdict(list)
+  for example in batch:
+    for key, tensor in example.items():
+      key_to_tensor_list[key].append(tensor)
+  key_to_batched_tensors = {}
+  for key in key_to_tensor_list:
+    batched_tensors = tf.stack(key_to_tensor_list[key], axis=0)
+    key_to_batched_tensors[key] = batched_tensors
+  predictions = model(**key_to_batched_tensors, **inference_args)
+  return utils._convert_to_result(batch, predictions, model_id)
+
+
+class HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
+                                                           Union[tf.Tensor,
+                                                                 torch.Tensor]],
+                                                      PredictionResult,
+                                                      Union[AutoModel,
+                                                            TFAutoModel]]):
+  def __init__(
+      self,
+      model_uri: str,
+      model_class: Union[AutoModel, TFAutoModel],
+      framework: str,
+      device: str = "CPU",
+      *,
+      inference_fn: Optional[Callable[..., Iterable[PredictionResult]]] = None,
+      load_model_args: Optional[Dict[str, Any]] = None,
+      inference_args: Optional[Dict[str, Any]] = None,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      large_model: bool = False,
+      **kwargs):
+    """
+    Implementation of the ModelHandler interface for HuggingFace with
+    Keyed Tensors for PyTorch/Tensorflow backend.
+
+    Example Usage::
+      pcoll | RunInference(HuggingFaceModelHandlerKeyedTensor(
+        model_uri="bert-base-uncased", model_class=AutoModelForMaskedLM,
+        framework='pt'))
+
+    Args:
+      model_uri (str): path to the pretrained model on the Hugging Face
+        Hub.
+      model_class: model class used to load the model from model_uri.
+      framework (str): Framework to use for the model. 'tf' for TensorFlow and
+        'pt' for PyTorch.
+      device: For torch tensors, specify device on which you wish to
+        run the model. Defaults to CPU.
+      inference_fn: the inference function to use during RunInference.
+        Default is _run_inference_torch_keyed_tensor or
+        _run_inference_tensorflow_keyed_tensor depending on the framework.
+      load_model_args (Dict[str, Any]): (Optional) Keyword arguments to provide
+        load options while loading models from Hugging Face Hub.
+        Defaults to None.
+      inference_args (Dict[str, Any]): (Optional) Non-batchable arguments
+        required as inputs to the model's inference function. Unlike Tensors
+        in `batch`, these parameters will not be dynamically batched.
+        Defaults to None.
+      min_batch_size: the minimum batch size to use when batching inputs.
+      max_batch_size: the maximum batch size to use when batching inputs.
+      large_model: set to true if your model is large enough to run into
+        memory pressure if you load multiple copies. Given a model that
+        consumes N memory and a machine with W cores and M memory, you should
+        set this to True if N*W > M.
+      kwargs: 'env_vars' can be used to set environment variables
+        before loading the model.
+
+    **Supported Versions:** HuggingFaceModelHandler supports
+    transformers>=4.18.0.
+    """
+    self._model_uri = model_uri
+    self._model_class = model_class
+    self._device = device
+    self._inference_fn = inference_fn
+    self._model_config_args = load_model_args if load_model_args else {}
+    self._inference_args = inference_args if inference_args else {}
+    self._batching_kwargs = {}
+    self._env_vars = kwargs.get("env_vars", {})
+    if min_batch_size is not None:
+      self._batching_kwargs["min_batch_size"] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs["max_batch_size"] = max_batch_size
+    self._large_model = large_model
+    self._framework = framework
+
+    _validate_constructor_args(
+        model_uri=self._model_uri, model_class=self._model_class)
+
+  def load_model(self):
+    """Loads and initializes the model for processing."""
+    model = self._model_class.from_pretrained(
+        self._model_uri, **self._model_config_args)
+    if self._framework == 'pt':
+      if self._device == "GPU" and is_gpu_available_torch():
+        model.to(torch.device("cuda"))
+    return model
+
+  def run_inference(
+      self,
+      batch: Sequence[Dict[str, Union[tf.Tensor, torch.Tensor]]],
+      model: Union[AutoModel, TFAutoModel],
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """
+    Runs inferences on a batch of Keyed Tensors and returns an Iterable of
+    PredictionResults.
+
+    This method stacks the list of Tensors in a vectorized format to optimize
+    the inference call.
+
+    Args:
+      batch: A sequence of Keyed Tensors. These Tensors should be batchable,
+        as this method will call `tf.stack()`/`torch.stack()` and pass in
+        batched Tensors with dimensions (batch_size, n_features, etc.) into
+        the model's predict() function.
+      model: A Tensorflow/PyTorch model.
+      inference_args: Non-batchable arguments required as inputs to the
+        model's inference function. Unlike Tensors in `batch`,
+        these parameters will not be dynamically batched.
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    inference_args = {} if not inference_args else inference_args
+
+    if self._inference_fn:
+      return self._inference_fn(
+          batch, model, self._device, inference_args, self._model_uri)
+
+    if self._framework == "tf":
+      return _run_inference_tensorflow_keyed_tensor(
+          batch, model, self._device, inference_args, self._model_uri)
+    else:
+      return _run_inference_torch_keyed_tensor(
+          batch, model, self._device, inference_args, self._model_uri)
+
+  def update_model_path(self, model_path: Optional[str] = None):
+    self._model_uri = model_path if model_path else self._model_uri
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[tf.Tensor, torch.Tensor]]) -> int:
+    """
+    Returns:
+      The number of bytes of data for the Tensors batch.
+    """
+    if self._framework == "tf":
+      return sum(sys.getsizeof(element) for element in batch)
+    else:
+      return sum(
+          (el.element_size() for tensor in batch for el in tensor.values()))
+
+  def batch_elements_kwargs(self):
+    return self._batching_kwargs
+
+  def share_model_across_processes(self) -> bool:
+    return self._large_model
+
+  def get_metrics_namespace(self) -> str:
+    """
+    Returns:
+        A namespace for metrics collected by the RunInference transform.
+    """
+    return "BeamML_HuggingFaceModelHandler_KeyedTensor"
+
+
+def _default_inference_fn_torch(
+    batch: Sequence[Union[tf.Tensor, torch.Tensor]],
+    model: Union[AutoModel, TFAutoModel],
+    device,
+    inference_args: Dict[str, Any],
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  device = get_device_torch(device)
+  # torch.no_grad() mitigates GPU memory issues
+  # https://github.com/apache/beam/issues/22811
+  with torch.no_grad():
+    batched_tensors = torch.stack(batch)
+    batched_tensors = _convert_to_device(batched_tensors, device)
+    predictions = model(batched_tensors, **inference_args)
+  return utils._convert_to_result(batch, predictions, model_id)
+
+
+def _default_inference_fn_tensorflow(
+    batch: Sequence[Union[tf.Tensor, torch.Tensor]],
+    model: Union[AutoModel, TFAutoModel],
+    device,
+    inference_args: Dict[str, Any],
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  if device == "GPU":
+    is_gpu_available_tensorflow(device)
+  batched_tensors = tf.stack(batch, axis=0)
+  predictions = model(batched_tensors, **inference_args)
+  return utils._convert_to_result(batch, predictions, model_id)
+
+
+class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
+                                                 PredictionResult,
+                                                 Union[AutoModel,
+                                                       TFAutoModel]]):
+  def __init__(
+      self,
+      model_uri: str,
+      model_class: Union[AutoModel, TFAutoModel],
+      device: str = "CPU",
+      *,
+      inference_fn: Optional[Callable[..., Iterable[PredictionResult]]] = None,
+      load_model_args: Optional[Dict[str, Any]] = None,
+      inference_args: Optional[Dict[str, Any]] = None,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      large_model: bool = False,
+      **kwargs):
+    """
+    Implementation of the ModelHandler interface for HuggingFace with
+    Tensors for PyTorch/Tensorflow backend.
+
+    Depending on the type of tensors, the model framework is determined
+    automatically.
+
+    Example Usage::
+      pcoll | RunInference(HuggingFaceModelHandlerTensor(
+        model_uri="bert-base-uncased", model_class=AutoModelForMaskedLM))
+
+    Args:
+      model_uri (str): path to the pretrained model on the Hugging Face
+        Hub.
+      model_class: model class used to load the model from model_uri.
+      device: For torch tensors, specify device on which you wish to
+        run the model. Defaults to CPU.
+      inference_fn: the inference function to use during RunInference.
+        Default is _default_inference_fn_torch or
+        _default_inference_fn_tensorflow depending on the input type.
+      load_model_args (Dict[str, Any]): (Optional) keyword arguments to provide
+        load options while loading models from Hugging Face Hub.
+        Defaults to None.
+      inference_args (Dict[str, Any]): (Optional) Non-batchable arguments
+        required as inputs to the model's inference function. Unlike Tensors
+        in `batch`, these parameters will not be dynamically batched.
+        Defaults to None.
+      min_batch_size: the minimum batch size to use when batching inputs.
+      max_batch_size: the maximum batch size to use when batching inputs.
+      large_model: set to true if your model is large enough to run into
+        memory pressure if you load multiple copies. Given a model that
+        consumes N memory and a machine with W cores and M memory, you should
+        set this to True if N*W > M.
+      kwargs: 'env_vars' can be used to set environment variables
+        before loading the model.
+
+    **Supported Versions:** HuggingFaceModelHandler supports
+    transformers>=4.18.0.
+    """
+    self._model_uri = model_uri
+    self._model_class = model_class
+    self._device = device
+    self._inference_fn = inference_fn
+    self._model_config_args = load_model_args if load_model_args else {}
+    self._inference_args = inference_args if inference_args else {}
+    self._batching_kwargs = {}
+    self._env_vars = kwargs.get("env_vars", {})
+    if min_batch_size is not None:
+      self._batching_kwargs["min_batch_size"] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs["max_batch_size"] = max_batch_size
+    self._large_model = large_model
+    self._framework = ""
+
+    _validate_constructor_args(
+        model_uri=self._model_uri, model_class=self._model_class)
+
+  def load_model(self):
+    """Loads and initializes the model for processing."""
+    model = self._model_class.from_pretrained(
+        self._model_uri, **self._model_config_args)
+    return model
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[tf.Tensor, torch.Tensor]],
+      model: Union[AutoModel, TFAutoModel],
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """
+    Runs inferences on a batch of Tensors and returns an Iterable of
+    PredictionResults.
+
+    This method stacks the list of Tensors in a vectorized format to optimize
+    the inference call.
+
+    Args:
+      batch: A sequence of Tensors. These Tensors should be batchable, as
+        this method will call `tf.stack()`/`torch.stack()` and pass in
+        batched Tensors with dimensions (batch_size, n_features, etc.)
+        into the model's predict() function.
+      model: A Tensorflow/PyTorch model.
+      inference_args (Dict[str, Any]): Non-batchable arguments required as
+        inputs to the model's inference function. Unlike Tensors in `batch`,
+        these parameters will not be dynamically batched.
+
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    inference_args = {} if not inference_args else inference_args
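+    # Infer the framework from the first element of the batch:
+    # a tf.Tensor selects TensorFlow, anything else is treated as PyTorch.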
+    if not self._framework:
+      if isinstance(batch[0], tf.Tensor):
+        self._framework = "tf"
+      else:
+        self._framework = "pt"
+
+    if (self._framework == 'pt' and self._device == "GPU" and
+        is_gpu_available_torch()):
+      model.to(torch.device("cuda"))
+
+    if self._inference_fn:
+      return self._inference_fn(
+          batch, model, self._device, inference_args, self._model_uri)
+
+    if self._framework == "tf":
+      return _default_inference_fn_tensorflow(
+          batch, model, self._device, inference_args, self._model_uri)
+    else:
+      return _default_inference_fn_torch(
+          batch, model, self._device, inference_args, self._model_uri)
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[tf.Tensor, torch.Tensor]]) -> int:
+    """
+    Returns:
+      The number of bytes of data for the Tensors batch.
+    """
+    if self._framework == "tf":
+      return sum(sys.getsizeof(element) for element in batch)
+    else:
+      return sum(el.element_size() for el in batch)
+
+  def batch_elements_kwargs(self):
+    return self._batching_kwargs
+
+  def share_model_across_processes(self) -> bool:
+    return self._large_model
+
+  def get_metrics_namespace(self) -> str:
+    """
+    Returns:
+        A namespace for metrics collected by the RunInference transform.
+    """
+    return "BeamML_HuggingFaceModelHandler_Tensor"
diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py b/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py
new file mode 100644
index 00000000000..ed442a4b801
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""End-to-End test for Hugging Face Inference"""
+
+import logging
+import unittest
+import uuid
+
+import pytest
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.testing.test_pipeline import TestPipeline
+
+try:
+  from apache_beam.examples.inference import huggingface_language_modeling
+  from apache_beam.ml.inference import pytorch_inference_it_test
+except ImportError:
+  raise unittest.SkipTest(
+      "transformers dependencies are not installed. "
+      "Check if transformers, torch, and tensorflow "
+      "are installed.")
+
+
+@pytest.mark.uses_transformers
+@pytest.mark.it_postcommit
+class HuggingFaceInference(unittest.TestCase):
+  @pytest.mark.timeout(1800)
+  def test_hf_language_modeling(self):
+    test_pipeline = TestPipeline(is_integration_test=True)
+    # Path to text file containing some sentences
+    file_of_sentences = 'gs://apache-beam-ml/datasets/custom/hf_sentences.txt'
+    output_file_dir = 'gs://apache-beam-ml/testing/predictions'
+    output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
+
+    model_name = 'stevhliu/my_awesome_eli5_mlm_model'
+
+    extra_opts = {
+        'input': file_of_sentences,
+        'output': output_file,
+        'model_name': model_name,
+    }
+    huggingface_language_modeling.run(
+        test_pipeline.get_full_options_as_args(**extra_opts),
+        save_main_session=False)
+
+    self.assertEqual(FileSystems().exists(output_file), True)
+    predictions = pytorch_inference_it_test.process_outputs(
+        filepath=output_file)
+    actuals_file = 'gs://apache-beam-ml/testing/expected_outputs/test_hf_run_inference_for_masked_lm_actuals.txt'  # pylint: disable=line-too-long
+    actuals = pytorch_inference_it_test.process_outputs(filepath=actuals_file)
+
+    predictions_dict = {}
+    for prediction in predictions:
+      text, predicted_text = prediction.split(';')
+      predictions_dict[text] = predicted_text.strip().lower()
+
+    for actual in actuals:
+      text, actual_predicted_text = actual.split(';')
+      predicted_predicted_text = predictions_dict[text]
+      self.assertEqual(actual_predicted_text, predicted_predicted_text)
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.DEBUG)
+  unittest.main()
diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference_test.py b/sdks/python/apache_beam/ml/inference/huggingface_inference_test.py
new file mode 100644
index 00000000000..763d5ee8d36
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/huggingface_inference_test.py
@@ -0,0 +1,136 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import shutil
+import tempfile
+import unittest
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import pytest
+
+from apache_beam.ml.inference import utils
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.tensorflow_inference_test import FakeTFTensorModel
+from apache_beam.ml.inference.tensorflow_inference_test import _compare_tensor_prediction_result
+
+# pylint: disable=ungrouped-imports
+try:
+  import tensorflow as tf
+  import torch
+  from transformers import AutoModel
+  from transformers import TFAutoModel
+  from apache_beam.ml.inference.huggingface_inference import HuggingFaceModelHandlerTensor
+except ImportError:
+  raise unittest.SkipTest('Transformers dependencies are not installed.')
+
+
+def fake_inference_fn_tensor(
+    batch: Sequence[Union[tf.Tensor, torch.Tensor]],
+    model: Union[AutoModel, TFAutoModel],
+    device,
+    inference_args: Dict[str, Any],
+    model_id: Optional[str] = None) -> Iterable[PredictionResult]:
+  predictions = model.predict(batch, **inference_args)
+  return utils._convert_to_result(batch, predictions, model_id)
+
+
+class FakeTorchModel:
+  def predict(self, input: torch.Tensor):
+    return input
+
+
+@pytest.mark.uses_transformers
+class HuggingFaceInferenceTest(unittest.TestCase):
+  def setUp(self) -> None:
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self) -> None:
+    shutil.rmtree(self.tmpdir)
+
+  def test_predict_tensor(self):
+    fake_model = FakeTFTensorModel()
+    inference_runner = HuggingFaceModelHandlerTensor(
+        model_uri='unused',
+        model_class=TFAutoModel,
+        inference_fn=fake_inference_fn_tensor)
+    batched_examples = [tf.constant([1]), tf.constant([10]), tf.constant([100])]
+    expected_predictions = [
+        PredictionResult(ex, pred) for ex,
+        pred in zip(
+            batched_examples,
+            [tf.math.multiply(n, 10) for n in batched_examples])
+    ]
+
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_tensor_prediction_result(actual, expected))
+
+  def test_predict_tensor_with_inference_args(self):
+    fake_model = FakeTFTensorModel()
+    inference_runner = HuggingFaceModelHandlerTensor(
+        model_uri='unused',
+        model_class=TFAutoModel,
+        inference_fn=fake_inference_fn_tensor,
+        inference_args={"add": True})
+    batched_examples = [tf.constant([1]), tf.constant([10]), tf.constant([100])]
+    expected_predictions = [
+        PredictionResult(ex, pred) for ex,
+        pred in zip(
+            batched_examples, [
+                tf.math.add(tf.math.multiply(n, 10), 10)
+                for n in batched_examples
+            ])
+    ]
+
+    inferences = inference_runner.run_inference(
+        batched_examples, fake_model, inference_args={"add": True})
+
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_tensor_prediction_result(actual, expected))
+
+  def test_framework_detection_torch(self):
+    fake_model = FakeTorchModel()
+    inference_runner = HuggingFaceModelHandlerTensor(
+        model_uri='unused',
+        model_class=TFAutoModel,
+        inference_fn=fake_inference_fn_tensor)
+    batched_examples = [torch.tensor(1), torch.tensor(10), torch.tensor(100)]
+    inference_runner.run_inference(batched_examples, fake_model)
+    self.assertEqual(inference_runner._framework, "pt")
+
+  def test_framework_detection_tensorflow(self):
+    fake_model = FakeTFTensorModel()
+    inference_runner = HuggingFaceModelHandlerTensor(
+        model_uri='unused',
+        model_class=TFAutoModel,
+        inference_fn=fake_inference_fn_tensor,
+        inference_args={"add": True})
+    batched_examples = [tf.constant([1]), tf.constant([10]), tf.constant([100])]
+    inference_runner.run_inference(
+        batched_examples, fake_model, inference_args={"add": True})
+    self.assertEqual(inference_runner._framework, "tf")
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/ml/inference/huggingface_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/huggingface_tests_requirements.txt
new file mode 100644
index 00000000000..09c1fa8ca90
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/huggingface_tests_requirements.txt
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+torch>=1.7.1
+transformers>=4.18.0
+tensorflow>=2.12.0
\ No newline at end of file
diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini
index 6e93c5f96e7..7c564235c58 100644
--- a/sdks/python/pytest.ini
+++ b/sdks/python/pytest.ini
@@ -47,15 +47,16 @@ markers =
     # as enabling save_main_session.
     no_xdist: run without pytest-xdist plugin
     # We run these tests with multiple major pyarrow versions (BEAM-11211)
-    uses_pyarrow: tests that utilize pyarrow in some way
+    uses_pyarrow: tests that utilize pyarrow in some way.
     # ML tests
-    uses_pytorch: tests that utilize pytorch in some way
-    uses_sklearn: tests that utilize scikit-learn in some way
-    uses_tensorflow: tests that utilize tensorflow in some way
+    uses_pytorch: tests that utilize pytorch in some way.
+    uses_sklearn: tests that utilize scikit-learn in some way.
+    uses_tensorflow: tests that utilize tensorflow in some way.
     uses_tft: tests that utilizes tensorflow transforms in some way.
-    uses_xgboost: tests that utilize xgboost in some way
+    uses_xgboost: tests that utilize xgboost in some way.
     uses_onnx: tests that utilizes onnx in some way.
-    uses_tf: tests that utilize tensorflow
+    uses_tf: tests that utilize tensorflow.
+    uses_transformers: tests that utilize transformers in some way.
 
 # Default timeout intended for unit tests.
 # If certain tests need a different value, please see the docs on how to
diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle
index 27e91b4733d..aebdb4cfa00 100644
--- a/sdks/python/test-suites/direct/common.gradle
+++ b/sdks/python/test-suites/direct/common.gradle
@@ -337,13 +337,41 @@ task xgboostInferenceTest {
 
 }
 
+// Transformers RunInference IT tests
+task transformersInferenceTest {
+  dependsOn 'installGcpTest'
+  dependsOn ':sdks:python:sdist'
+  def requirementsFile = "${rootDir}/sdks/python/apache_beam/ml/inference/huggingface_tests_requirements.txt"
+  doFirst {
+      exec {
+        executable 'sh'
+        args '-c', ". ${envdir}/bin/activate && pip install -r $requirementsFile"
+      }
+    }
+  doLast {
+      def testOpts = basicTestOpts
+      def argMap = [
+          "test_opts": testOpts,
+          "suite": "postCommitIT-direct-py${pythonVersionSuffix}",
+          "collect": "uses_transformers and it_postcommit",
+          "runner": "TestDirectRunner"
+      ]
+      def cmdArgs = mapToArgString(argMap)
+      exec {
+        executable 'sh'
+        args '-c', ". ${envdir}/bin/activate && ${runScriptsDir}/run_integration_test.sh $cmdArgs"
+      }
+    }
+}
+
 // Add all the RunInference framework IT tests to this gradle task that runs on Direct Runner Post commit suite.
 project.tasks.register("inferencePostCommitIT") {
   dependsOn = [
   'torchInferenceTest',
   'sklearnInferenceTest',
   'tensorflowInferenceTest',
-  'xgboostInferenceTest'
+  'xgboostInferenceTest',
+  'transformersInferenceTest'
   // (TODO) https://github.com/apache/beam/issues/25799
    // uncomment tfx bsl tests once tfx supports protobuf 4.x
   // 'tfxInferenceTest',
diff --git a/sdks/python/test-suites/tox/py38/build.gradle b/sdks/python/test-suites/tox/py38/build.gradle
index 7243a0188ed..208a1d9d39c 100644
--- a/sdks/python/test-suites/tox/py38/build.gradle
+++ b/sdks/python/test-suites/tox/py38/build.gradle
@@ -130,6 +130,19 @@ toxTask "testPy38tensorflow-212", "py38-tensorflow-212", "${posargs}"
 test.dependsOn "testPy38tensorflow-212"
 preCommitPyCoverage.dependsOn "testPy38tensorflow-212"
 
+// Create a test task for each minor version of transformers
+toxTask "testPy38transformers-428", "py38-transformers-428", "${posargs}"
+test.dependsOn "testPy38transformers-428"
+preCommitPyCoverage.dependsOn "testPy38transformers-428"
+
+toxTask "testPy38transformers-429", "py38-transformers-429", "${posargs}"
+test.dependsOn "testPy38transformers-429"
+preCommitPyCoverage.dependsOn "testPy38transformers-429"
+
+toxTask "testPy38transformers-430", "py38-transformers-430", "${posargs}"
+test.dependsOn "testPy38transformers-430"
+preCommitPyCoverage.dependsOn "testPy38transformers-430"
+
 toxTask "whitespacelint", "whitespacelint", "${posargs}"
 
 task archiveFilesToLint(type: Zip) {
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index 3f1b32a20d2..a5c1db416e5 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -163,6 +163,7 @@ deps =
   torch
   xgboost
   datatable==1.0.0
+  transformers
 commands =
   time {toxinidir}/scripts/generate_pydoc.sh
 
@@ -406,3 +407,21 @@ commands =
   # Run all XGBoost unit tests
   # Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories.
   /bin/sh -c 'pytest -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 -m uses_xgboost {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret'
+
+[testenv:py{38,39,310,311}-transformers-{428,429,430}]
+deps =
+  -r build-requirements.txt
+  428: transformers>=4.28.0,<4.29.0
+  429: transformers>=4.29.0,<4.30.0
+  430: transformers>=4.30.0,<4.31.0
+  torch>=1.9.0,<1.14.0
+  tensorflow==2.12.0
+extras = test,gcp
+commands =
+  # Log transformers and its dependencies version for debugging
+  /bin/sh -c "pip freeze | grep -E transformers"
+  /bin/sh -c "pip freeze | grep -E torch"
+  /bin/sh -c "pip freeze | grep -E tensorflow"
+  # Run all Transformers unit tests
+  # Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories.
+  /bin/sh -c 'pytest -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 -m uses_transformers {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret'
\ No newline at end of file