You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tv...@apache.org on 2022/07/15 23:15:52 UTC

[beam] branch master updated: Update RunInference documentation (#22250)

This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new fa028d3dad8 Update RunInference documentation (#22250)
fa028d3dad8 is described below

commit fa028d3dad843d53a56828d998ffcbc7895b56c0
Author: Rebecca Szper <98...@users.noreply.github.com>
AuthorDate: Fri Jul 15 16:15:44 2022 -0700

    Update RunInference documentation (#22250)
    
    Co-authored-by: Anand Inguva <34...@users.noreply.github.com>
    Co-authored-by: Andy Ye <an...@gmail.com>
    Co-authored-by: Anand Inguva <an...@gmail.com>
    Co-authored-by: Anand Inguva <an...@google.com>
---
 .../transforms/elementwise/runinference.py         | 153 ++++++++++++++++
 .../transforms/elementwise/runinference_test.py    | 125 +++++++++++++
 .../documentation/sdks/python-machine-learning.md  | 201 +++++++++++++++++++++
 .../site/content/en/documentation/sdks/python.md   |   8 +-
 .../transforms/python/elementwise/runinference.md  | 105 +++++++++++
 .../en/documentation/transforms/python/overview.md |   1 +
 .../partials/section-menu/en/documentation.html    |   1 +
 .../layouts/partials/section-menu/en/sdks.html     |   1 +
 8 files changed, 594 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py
new file mode 100644
index 00000000000..162cd995a98
--- /dev/null
+++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py
@@ -0,0 +1,153 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+# pylint: disable=reimported
+
+import torch
+
+
+class LinearRegression(torch.nn.Module):
+  def __init__(self, input_dim=1, output_dim=1):
+    super().__init__()
+    self.linear = torch.nn.Linear(input_dim, output_dim)
+
+  def forward(self, x):
+    out = self.linear(x)
+    return out
+
+
+def torch_unkeyed_model_handler(test=None):
+  # [START torch_unkeyed_model_handler]
+  import apache_beam as beam
+  import numpy
+  import torch
+  from apache_beam.ml.inference.base import RunInference
+  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
+
+  model_state_dict_path = 'gs://apache-beam-samples/run_inference/five_times_table_torch.pt'  # pylint: disable=line-too-long
+  model_class = LinearRegression
+  model_params = {'input_dim': 1, 'output_dim': 1}
+  model_handler = PytorchModelHandlerTensor(
+      model_class=model_class,
+      model_params=model_params,
+      state_dict_path=model_state_dict_path)
+
+  unkeyed_data = numpy.array([10, 40, 60, 90],
+                             dtype=numpy.float32).reshape(-1, 1)
+
+  with beam.Pipeline() as p:
+    predictions = (
+        p
+        | 'InputData' >> beam.Create(unkeyed_data)
+        | 'ConvertNumpyToTensor' >> beam.Map(torch.Tensor)
+        | 'PytorchRunInference' >> RunInference(model_handler=model_handler)
+        | beam.Map(print))
+    # [END torch_unkeyed_model_handler]
+    if test:
+      test(predictions)
+
+
+def torch_keyed_model_handler(test=None):
+  # [START torch_keyed_model_handler]
+  import apache_beam as beam
+  import torch
+  from apache_beam.ml.inference.base import KeyedModelHandler
+  from apache_beam.ml.inference.base import RunInference
+  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
+
+  model_state_dict_path = 'gs://apache-beam-samples/run_inference/five_times_table_torch.pt'  # pylint: disable=line-too-long
+  model_class = LinearRegression
+  model_params = {'input_dim': 1, 'output_dim': 1}
+  keyed_model_handler = KeyedModelHandler(
+      PytorchModelHandlerTensor(
+          model_class=model_class,
+          model_params=model_params,
+          state_dict_path=model_state_dict_path))
+
+  keyed_data = [("first_question", 105.00), ("second_question", 108.00),
+                ("third_question", 1000.00), ("fourth_question", 1013.00)]
+
+  with beam.Pipeline() as p:
+    predictions = (
+        p
+        | 'KeyedInputData' >> beam.Create(keyed_data)
+        | "ConvertIntToTensor" >>
+        beam.Map(lambda x: (x[0], torch.Tensor([x[1]])))
+        | 'PytorchRunInference' >>
+        RunInference(model_handler=keyed_model_handler)
+        | beam.Map(print))
+    # [END torch_keyed_model_handler]
+    if test:
+      test(predictions)
+
+
+def sklearn_unkeyed_model_handler(test=None):
+  # [START sklearn_unkeyed_model_handler]
+  import apache_beam as beam
+  import numpy
+  from apache_beam.ml.inference.base import RunInference
+  from apache_beam.ml.inference.sklearn_inference import ModelFileType
+  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
+
+  sklearn_model_filename = 'gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl'  # pylint: disable=line-too-long
+  sklearn_model_handler = SklearnModelHandlerNumpy(
+      model_uri=sklearn_model_filename, model_file_type=ModelFileType.PICKLE)
+
+  unkeyed_data = numpy.array([20, 40, 60, 90],
+                             dtype=numpy.float32).reshape(-1, 1)
+  with beam.Pipeline() as p:
+    predictions = (
+        p
+        | "ReadInputs" >> beam.Create(unkeyed_data)
+        | "RunInferenceSklearn" >>
+        RunInference(model_handler=sklearn_model_handler)
+        | beam.Map(print))
+    # [END sklearn_unkeyed_model_handler]
+    if test:
+      test(predictions)
+
+
+def sklearn_keyed_model_handler(test=None):
+  # [START sklearn_keyed_model_handler]
+  import apache_beam as beam
+  from apache_beam.ml.inference.base import KeyedModelHandler
+  from apache_beam.ml.inference.base import RunInference
+  from apache_beam.ml.inference.sklearn_inference import ModelFileType
+  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
+
+  sklearn_model_filename = 'gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl'  # pylint: disable=line-too-long
+  sklearn_model_handler = KeyedModelHandler(
+      SklearnModelHandlerNumpy(
+          model_uri=sklearn_model_filename,
+          model_file_type=ModelFileType.PICKLE))
+
+  keyed_data = [("first_question", 105.00), ("second_question", 108.00),
+                ("third_question", 1000.00), ("fourth_question", 1013.00)]
+
+  with beam.Pipeline() as p:
+    predictions = (
+        p
+        | "ReadInputs" >> beam.Create(keyed_data)
+        | "ConvertDataToList" >> beam.Map(lambda x: (x[0], [x[1]]))
+        | "RunInferenceSklearn" >>
+        RunInference(model_handler=sklearn_model_handler)
+        | beam.Map(print))
+    # [END sklearn_keyed_model_handler]
+    if test:
+      test(predictions)
diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py
new file mode 100644
index 00000000000..177c08cdfc9
--- /dev/null
+++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py
@@ -0,0 +1,125 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import re
+import unittest
+from io import StringIO
+
+import mock
+import pytest
+
+from apache_beam.examples.snippets.util import assert_matches_stdout
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports, unused-import
+try:
+  import torch
+  from . import runinference
+except ImportError:
+  raise unittest.SkipTest('PyTorch dependencies are not installed')
+
+# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports, unused-import
+try:
+  from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
+except ImportError:
+  raise unittest.SkipTest('GCP dependencies are not installed')
+
+
+def check_torch_keyed_model_handler():
+  expected = '''[START torch_keyed_model_handler]
+('first_question', PredictionResult(example=tensor([105.]), inference=tensor([523.6982], grad_fn=<UnbindBackward>)))
+('second_question', PredictionResult(example=tensor([108.]), inference=tensor([538.5867], grad_fn=<UnbindBackward>)))
+('third_question', PredictionResult(example=tensor([1000.]), inference=tensor([4965.4019], grad_fn=<UnbindBackward>)))
+('fourth_question', PredictionResult(example=tensor([1013.]), inference=tensor([5029.9180], grad_fn=<UnbindBackward>)))
+[END torch_keyed_model_handler] '''.splitlines()[1:-1]
+  return expected
+
+
+def check_sklearn_keyed_model_handler(actual):
+  expected = '''[START sklearn_keyed_model_handler]
+('first_question', PredictionResult(example=[105.0], inference=array([525.])))
+('second_question', PredictionResult(example=[108.0], inference=array([540.])))
+('third_question', PredictionResult(example=[1000.0], inference=array([5000.])))
+('fourth_question', PredictionResult(example=[1013.0], inference=array([5065.])))
+[END sklearn_keyed_model_handler] '''.splitlines()[1:-1]
+  assert_matches_stdout(actual, expected)
+
+
+def check_torch_unkeyed_model_handler():
+  expected = '''[START torch_unkeyed_model_handler]
+PredictionResult(example=tensor([10.]), inference=tensor([52.2325], grad_fn=<UnbindBackward>))
+PredictionResult(example=tensor([40.]), inference=tensor([201.1165], grad_fn=<UnbindBackward>))
+PredictionResult(example=tensor([60.]), inference=tensor([300.3724], grad_fn=<UnbindBackward>))
+PredictionResult(example=tensor([90.]), inference=tensor([449.2563], grad_fn=<UnbindBackward>))
+[END torch_unkeyed_model_handler] '''.splitlines()[1:-1]
+  return expected
+
+
+def check_sklearn_unkeyed_model_handler(actual):
+  expected = '''[START sklearn_unkeyed_model_handler]
+PredictionResult(example=array([20.], dtype=float32), inference=array([100.], dtype=float32))
+PredictionResult(example=array([40.], dtype=float32), inference=array([200.], dtype=float32))
+PredictionResult(example=array([60.], dtype=float32), inference=array([300.], dtype=float32))
+PredictionResult(example=array([90.], dtype=float32), inference=array([450.], dtype=float32))
+[END sklearn_unkeyed_model_handler]  '''.splitlines()[1:-1]
+  assert_matches_stdout(actual, expected)
+
+
+@mock.patch('apache_beam.Pipeline', TestPipeline)
+@mock.patch(
+    'apache_beam.examples.snippets.transforms.elementwise.runinference.print',
+    str)
+class RunInferenceTest(unittest.TestCase):
+  def test_sklearn_unkeyed_model_handler(self):
+    runinference.sklearn_unkeyed_model_handler(
+        check_sklearn_unkeyed_model_handler)
+
+  def test_sklearn_keyed_model_handler(self):
+    runinference.sklearn_keyed_model_handler(check_sklearn_keyed_model_handler)
+
+
+@mock.patch('apache_beam.Pipeline', TestPipeline)
+@mock.patch('sys.stdout', new_callable=StringIO)
+class RunInferenceStdoutTest(unittest.TestCase):
+  @pytest.mark.uses_pytorch
+  def test_check_torch_keyed_model_handler(self, mock_stdout):
+    runinference.torch_keyed_model_handler()
+    predicted = mock_stdout.getvalue().splitlines()
+    expected = check_torch_keyed_model_handler()
+    actual_stdout = [line.split(':')[0] for line in predicted]
+    replace_fn = lambda x: re.sub(r"<UnbindBackward\d*>", "<UnbindBackward>", x)
+    actual_stdout = [replace_fn(x) for x in actual_stdout]
+    expected_stdout = [line.split(':')[0] for line in expected]
+    self.assertEqual(actual_stdout, expected_stdout)
+
+  @pytest.mark.uses_pytorch
+  def test_check_torch_unkeyed_model_handler(self, mock_stdout):
+    runinference.torch_unkeyed_model_handler()
+    predicted = mock_stdout.getvalue().splitlines()
+    expected = check_torch_unkeyed_model_handler()
+    actual_stdout = [line.split(':')[0] for line in predicted]
+    replace_fn = lambda x: re.sub(r"<UnbindBackward\d*>", "<UnbindBackward>", x)
+    actual_stdout = [replace_fn(x) for x in actual_stdout]
+    expected_stdout = [line.split(':')[0] for line in expected]
+    self.assertEqual(actual_stdout, expected_stdout)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/website/www/site/content/en/documentation/sdks/python-machine-learning.md b/website/www/site/content/en/documentation/sdks/python-machine-learning.md
new file mode 100644
index 00000000000..1235a71debc
--- /dev/null
+++ b/website/www/site/content/en/documentation/sdks/python-machine-learning.md
@@ -0,0 +1,201 @@
+---
+type: languages
+title: "Apache Beam Python Machine Learning"
+---
+<!--
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+# Machine Learning
+
+You can use Apache Beam with the RunInference API to use machine learning (ML) models to do local and remote inference with batch and streaming pipelines. Starting with Apache Beam 2.40.0, PyTorch and Scikit-learn frameworks are supported. You can create multiple types of transforms using the RunInference API: the API takes multiple types of setup parameters from model handlers, and the parameter type determines the model implementation.
+
+## Why use the RunInference API?
+
+RunInference takes advantage of existing Apache Beam concepts, such as the `BatchElements` transform and the `Shared` class, to enable you to use models in your pipelines to create transforms optimized for machine learning inferences. The ability to create arbitrarily complex workflow graphs also allows you to build multi-model pipelines.
+
+### BatchElements PTransform
+
+To take advantage of the optimizations of vectorized inference that many models implement, we added the `BatchElements` transform as an intermediate step before making the prediction for the model. This transform batches elements together. The batched elements are then applied with a transformation for the particular framework of RunInference. For example, for numpy `ndarrays`, we call `numpy.stack()`,  and for torch `Tensor` elements, we call `torch.stack()`.
+
+To customize the settings for `beam.BatchElements`, in `ModelHandler`, override the `batch_elements_kwargs` function. For example, use `min_batch_size` to set the lowest number of elements per batch or `max_batch_size` to set the highest number of elements per batch.
+
+For more information, see the [`BatchElements` transform documentation](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.util.html#apache_beam.transforms.util.BatchElements).
+
+### Shared helper class
+
+Using the `Shared` class within the RunInference implementation allows us to load the model only once per process and share it with all DoFn instances created in that process. This feature reduces memory consumption and model loading time. For more information, see the
+[`Shared` class documentation](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/utils/shared.py#L20).
+
+### Multi-model pipelines
+
+The RunInference API can be composed into multi-model pipelines. Multi-model pipelines can be useful for A/B testing or for building out ensembles that are comprised of models that perform tokenization, sentence segmentation, part-of-speech tagging, named entity extraction, language detection, coreference resolution, and more.
+
+## Modify a pipeline to use an ML model
+
+To use the RunInference transform, add the following code to your pipeline:
+
+```
+from apache_beam.ml.inference.base import RunInference
+with pipeline as p:
+   predictions = ( p |  'Read' >> beam.ReadFromSource('a_source')
+                     | 'RunInference' >> RunInference(<model_handler>))
+```
+Where `model_handler` is the model handler setup code.
+
+To import models, you need to configure a `ModelHandler` object that wraps the underlying model. Which `ModelHandler` you import depends on the framework and type of data structure that contains the inputs. The following examples show some ModelHandlers that you might want to import.
+
+```
+from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
+from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerPandas
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
+```
+### Use pre-trained models
+
+This section provides requirements for using pre-trained models with PyTorch and Scikit-learn.
+
+#### PyTorch
+
+You need to provide a path to a file that contains the model saved weights. This path must be accessible by the pipeline. To use pre-trained models with the RunInference API and the PyTorch framework, complete the following steps:
+
+1. Download the pre-trained weights and host them in a location that the pipeline can access.
+2. Pass the path of the model weights to the PyTorch `ModelHandler` by using the following code: `state_dict_path=<path_to_weights>`.
+
+#### Scikit-learn
+
+You need to provide a path to a file that contains the pickled Scikit-learn model. This path must be accessible by the pipeline. To use pre-trained models with the RunInference API and the Scikit-learn framework, complete the following steps:
+
+1. Download the pickled model class and host it in a location that the pipeline can access.
+2. Pass the path of the model to the Sklearn `ModelHandler` by using the following code:
+   `model_uri=<path_to_pickled_file>` and `model_file_type=<ModelFileType>`, where you can specify
+   `ModelFileType.PICKLE` or `ModelFileType.JOBLIB`, depending on how the model was serialized.
+
+### Use multiple models
+
+You can also use the RunInference transform to add multiple inference models to your pipeline.
+
+#### A/B Pattern
+
+```
+with pipeline as p:
+   data = p | 'Read' >> beam.ReadFromSource('a_source')
+   model_a_predictions = data | RunInference(<model_handler_A>)
+   model_b_predictions = data | RunInference(<model_handler_B>)
+```
+
+Where `model_handler_A` and `model_handler_B` are the model handler setup code.
+
+#### Ensemble Pattern
+
+```
+with pipeline as p:
+   data = p | 'Read' >> beam.ReadFromSource('a_source')
+   model_a_predictions = data | RunInference(<model_handler_A>)
+   model_b_predictions = model_a_predictions | beam.Map(some_post_processing) | RunInference(<model_handler_B>)
+```
+
+Where `model_handler_A` and `model_handler_B` are the model handler setup code.
+
+### Use a keyed ModelHandler
+
+If a key is attached to the examples, wrap the `KeyedModelHandler` around the `ModelHandler` object:
+
+```
+from apache_beam.ml.inference.base import KeyedModelHandler
+keyed_model_handler = KeyedModelHandler(PytorchModelHandlerTensor(...))
+with pipeline as p:
+   data = p | beam.Create([
+      ('img1', torch.tensor([[1,2,3],[4,5,6],...])),
+      ('img2', torch.tensor([[1,2,3],[4,5,6],...])),
+      ('img3', torch.tensor([[1,2,3],[4,5,6],...])),
+   ])
+   predictions = data | RunInference(keyed_model_handler)
+```
+
+### Use the PredictionResult object
+
+When doing a prediction in Apache Beam, the output `PCollection` includes both the keys of the input examples and the inferences. Including both these items in the output allows you to find the input that determined the predictions.
+
+The `PredictionResult` is a `NamedTuple` object that contains both the input and the inferences, named  `example` and  `inference`, respectively. When keys are passed with the input data to the RunInference transform, the output `PCollection` returns a `Tuple[str, PredictionResult]`, which is the key and the `PredictionResult` object. Your pipeline interacts with a `PredictionResult` object in steps after the RunInference transform.
+
+```
+class PostProcessor(beam.DoFn):
+    def process(self, element: Tuple[str, PredictionResult]):
+       key, prediction_result = element
+       inputs = prediction_result.example
+       predictions = prediction_result.inference
+
+       # Post-processing logic
+       result = ...
+
+       yield (key, result)
+
+with pipeline as p:
+    output = (
+        p | 'Read' >> beam.ReadFromSource('a_source')
+                | 'PyTorchRunInference' >> RunInference(<keyed_model_handler>)
+                | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
+```
+
+If you need to use this object explicitly, include the following line in your pipeline to import the object:
+
+```
+from apache_beam.ml.inference.base import PredictionResult
+```
+
+For more information, see the [`PredictionResult` documentation](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/inference/base.py#L65).
+
+## Run a machine learning pipeline
+
+For detailed instructions explaining how to build and run a pipeline that uses ML models, see the
+[Example RunInference API pipelines](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/inference) on GitHub.
+
+## Troubleshooting
+
+If you run into problems with your pipeline or job, this section lists issues that you might encounter and provides suggestions for how to fix them.
+
+### Incorrect inferences in the PredictionResult object
+
+In some cases, the `PredictionResult` output might not include the correct predictions in the `inference` field. This issue occurs when you use a model whose inferences return a dictionary that maps keys to predictions and other metadata. An example return type is `Dict[str, Tensor]`.
+
+The RunInference API currently expects outputs to be an `Iterable[Any]`. Example return types are `Iterable[Tensor]` or `Iterable[Dict[str, Tensor]]`. When RunInference zips the inputs with the predictions, the predictions iterate over the dictionary keys instead of the batch elements. The result is that the key name is preserved but the prediction tensors are discarded. For more information, see the [Pytorch RunInference PredictionResult is a Dict](https://github.com/apache/beam/issues/ [...]
+
+To work with the current RunInference implementation, you can create a wrapper class that overrides the `model(input)` call. In PyTorch, for example, your wrapper would override the `forward()` function and return an output with the appropriate format of `List[Dict[str, torch.Tensor]]`. For more information, see our [HuggingFace language modeling example](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py#L49).
+
+### Unable to batch tensor elements
+
+RunInference uses dynamic batching. However, the RunInference API cannot batch tensor elements of different sizes, so samples passed to the RunInference transform must be the same dimension or length. If you provide images of different sizes or word embeddings of different lengths, the following error might occur:
+
+```
+File "/beam/sdks/python/apache_beam/ml/inference/pytorch_inference.py", line 232, in run_inference
+batched_tensors = torch.stack(key_to_tensor_list[key])
+RuntimeError: stack expects each tensor to be equal size, but got [12] at entry 0 and [10] at entry 1 [while running 'PyTorchRunInference/ParDo(_RunInferenceDoFn)']
+```
+
+To avoid this issue, either use elements of the same size, or disable batching.
+
+**Option 1: Use elements of the same size**
+
+Use elements of the same size or resize the inputs. For computer vision applications, resize image inputs so that they have the same dimensions. For natural language processing (NLP) applications that have text of varying length, resize the text or word embeddings to make them the same length. When working with texts of varying length, resizing might not be possible. In this scenario, you could disable batching (see option 2).
+
+**Option 2: Disable batching**
+
+Disable batching by overriding the `batch_elements_kwargs` function in your ModelHandler and setting the maximum batch size (`max_batch_size`) to one: `max_batch_size=1`. For more information, see
+[BatchElements PTransforms](/documentation/sdks/python-machine-learning/#batchelements-ptransform). For an example, see our [language modeling example](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py).
+
+## Related links
+
+* [RunInference transforms](/documentation/transforms/python/elementwise/runinference)
+* [RunInference API pipeline examples](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/inference)
+* [apache_beam.ml.inference package](/releases/pydoc/current/apache_beam.ml.inference.html#apache_beam.ml.inference.RunInference)
\ No newline at end of file
diff --git a/website/www/site/content/en/documentation/sdks/python.md b/website/www/site/content/en/documentation/sdks/python.md
index 17e67fc0451..80c26c258d5 100644
--- a/website/www/site/content/en/documentation/sdks/python.md
+++ b/website/www/site/content/en/documentation/sdks/python.md
@@ -46,7 +46,13 @@ new I/O connectors. See the [Developing I/O connectors overview](/documentation/
 for information about developing new I/O connectors and links to
 language-specific implementation guidance.
 
-## Using Beam Python SDK in your ML pipelines
+## Making machine learning inferences with Python
+
+To integrate machine learning models into your pipelines for making inferences, use the RunInference API for PyTorch and Scikit-learn models. If you are using TensorFlow models, you can make use of the
+[library from `tfx_bsl`](https://github.com/tensorflow/tfx-bsl/tree/master/tfx_bsl/beam).
+
+You can create multiple types of transforms using the RunInference API: the API takes multiple types of setup parameters from model handlers, and the parameter type determines the model implementation. For more information,
+see [Machine Learning](/documentation/sdks/python-machine-learning).
 
 [TensorFlow Extended (TFX)](https://www.tensorflow.org/tfx) is an end-to-end platform for deploying production ML pipelines. TFX is integrated with Beam. For more information, see [TFX user guide](https://www.tensorflow.org/tfx/guide).
 
diff --git a/website/www/site/content/en/documentation/transforms/python/elementwise/runinference.md b/website/www/site/content/en/documentation/transforms/python/elementwise/runinference.md
new file mode 100644
index 00000000000..42455aa073a
--- /dev/null
+++ b/website/www/site/content/en/documentation/transforms/python/elementwise/runinference.md
@@ -0,0 +1,105 @@
+---
+title: "RunInference"
+---
+<!--
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+# RunInference
+
+{{< localstorage language language-py >}}
+
+{{< button-pydoc path="apache_beam.ml.inference" class="RunInference" >}}
+
+Uses models to do local and remote inference. A `RunInference` transform performs inference on a `PCollection` of examples using a machine learning (ML) model. The transform outputs a `PCollection` that contains the input examples and output predictions.
+
+You must have Apache Beam 2.40.0 or later installed to run these pipelines.
+
+See more [RunInference API pipeline examples](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/inference).
+
+## Examples
+
+In the following examples, we explore how to create pipelines that use the Beam RunInference API to make predictions based on models.
+
+### Example 1: PyTorch unkeyed model
+
+In this example, we create a pipeline that uses a PyTorch RunInference transform on unkeyed data.
+
+{{< highlight language="py" file="sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py"
+  class="notebook-skip" >}}
+{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py" torch_unkeyed_model_handler >}}
+{{</ highlight >}}
+
+{{< paragraph class="notebook-skip" >}}
+Output:
+{{< /paragraph >}}
+{{< highlight class="notebook-skip" >}}
+{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py" torch_unkeyed_model_handler >}}
+{{< /highlight >}}
+
+### Example 2: PyTorch keyed model
+
+In this example, we create a pipeline that uses a PyTorch RunInference transform on keyed data.
+
+{{< highlight language="py" file="sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py"
+  class="notebook-skip" >}}
+{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py" torch_keyed_model_handler >}}
+{{</ highlight >}}
+
+{{< paragraph class="notebook-skip" >}}
+Output:
+{{< /paragraph >}}
+
+{{< highlight class="notebook-skip" >}}
+{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py" torch_keyed_model_handler >}}
+{{< /highlight >}}
+
+### Example 3: Sklearn unkeyed model
+
+In this example, we create a pipeline that uses a Scikit-learn RunInference transform on unkeyed data.
+
+{{< highlight language="py" file="sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py"
+  class="notebook-skip" >}}
+{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py" sklearn_unkeyed_model_handler >}}
+{{</ highlight >}}
+
+{{< paragraph class="notebook-skip" >}}
+Output:
+{{< /paragraph >}}
+
+{{< highlight class="notebook-skip" >}}
+{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py" sklearn_unkeyed_model_handler >}}
+{{< /highlight >}}
+
+### Example 4: Sklearn keyed model
+
+In this example, we create a pipeline that uses a Scikit-learn RunInference transform on keyed data.
+
+{{< highlight language="py" file="sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py"
+  class="notebook-skip" >}}
+{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference.py" sklearn_keyed_model_handler >}}
+{{</ highlight >}}
+
+{{< paragraph class="notebook-skip" >}}
+Output:
+{{< /paragraph >}}
+
+{{< highlight class="notebook-skip" >}}
+{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/runinference_test.py" sklearn_keyed_model_handler >}}
+{{< /highlight >}}
+
+## Related transforms
+
+Not applicable.
+
+{{< button-pydoc path="apache_beam.ml.inference" class="RunInference" >}}
diff --git a/website/www/site/content/en/documentation/transforms/python/overview.md b/website/www/site/content/en/documentation/transforms/python/overview.md
index 71d5f1e0efb..d30af75352b 100644
--- a/website/www/site/content/en/documentation/transforms/python/overview.md
+++ b/website/www/site/content/en/documentation/transforms/python/overview.md
@@ -33,6 +33,7 @@ limitations under the License.
   function.</td></tr>
   <tr><td><a href="/documentation/transforms/python/elementwise/regex">Regex</a></td><td>Filters input string elements based on a regex. May also transform them based on the matching groups.</td></tr>
   <tr><td><a href="/documentation/transforms/python/elementwise/reify">Reify</a></td><td>Transforms for converting between explicit and implicit form of various Beam values.</td></tr>
+  <tr><td><a href="/documentation/transforms/python/elementwise/runinference">RunInference</a></td><td>Uses machine learning (ML) models to do local and remote inference.</td></tr>
  <tr><td><a href="/documentation/transforms/python/elementwise/tostring">ToString</a></td><td>Transforms every element in an input collection into a string.</td></tr>
   <tr><td><a href="/documentation/transforms/python/elementwise/withtimestamps">WithTimestamps</a></td><td>Applies a function to determine a timestamp to each element in the output collection,
   and updates the implicit timestamp associated with each input. Note that it is only
diff --git a/website/www/site/layouts/partials/section-menu/en/documentation.html b/website/www/site/layouts/partials/section-menu/en/documentation.html
index 59a3ebf3315..e6d1d5c742f 100644
--- a/website/www/site/layouts/partials/section-menu/en/documentation.html
+++ b/website/www/site/layouts/partials/section-menu/en/documentation.html
@@ -240,6 +240,7 @@
             <li><a href="/documentation/transforms/python/elementwise/partition/">Partition</a></li>
             <li><a href="/documentation/transforms/python/elementwise/regex/">Regex</a></li>
             <li><a href="/documentation/transforms/python/elementwise/reify/">Reify</a></li>
+            <li><a href="/documentation/transforms/python/elementwise/runinference/">RunInference</a></li>
             <li><a href="/documentation/transforms/python/elementwise/tostring/">ToString</a></li>
             <li><a href="/documentation/transforms/python/elementwise/values/">Values</a></li>
             <li><a href="/documentation/transforms/python/elementwise/withtimestamps/">WithTimestamps</a></li>
diff --git a/website/www/site/layouts/partials/section-menu/en/sdks.html b/website/www/site/layouts/partials/section-menu/en/sdks.html
index d1bd51b7189..9b891705c5c 100644
--- a/website/www/site/layouts/partials/section-menu/en/sdks.html
+++ b/website/www/site/layouts/partials/section-menu/en/sdks.html
@@ -39,6 +39,7 @@
     <li><a href="/documentation/sdks/python-dependencies/">Python SDK dependencies</a></li>
     <li><a href="/documentation/sdks/python-streaming/">Python streaming pipelines</a></li>
     <li><a href="/documentation/sdks/python-type-safety/">Ensuring Python type safety</a></li>
+    <li><a href="/documentation/sdks/python-machine-learning/">Machine Learning</a></li>
     <li><a href="/documentation/sdks/python-pipeline-dependencies/">Managing pipeline dependencies</a></li>
     <li><a href="/documentation/sdks/python-multi-language-pipelines/">Python multi-language pipelines quickstart</a></li>
   </ul>