Posted to commits@beam.apache.org by da...@apache.org on 2022/10/06 17:14:05 UTC

[beam] branch master updated: Content/multi model pipelines (#23498)

This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9b2f87d0201 Content/multi model pipelines (#23498)
9b2f87d0201 is described below

commit 9b2f87d0201e8923a021c6bedbe0f64e37704014
Author: Philippe Moussalli <ph...@gmail.com>
AuthorDate: Thu Oct 6 19:13:57 2022 +0200

    Content/multi model pipelines (#23498)
    
    * Created using Colaboratory
    
    * add new website content for multi-model pipelines
    
    * add example notebook with blip/clip
    
    * add multi-model website content
    
    * add multi-model example notebook
    
    * fix notebook typo and link issues
    
    * add reference to multimodel page in beam-ml master page
    
    * remove whitespaces
    
    * fix wrong resource link
    
    * change references of ensemble models to cascade models
    
    Co-authored-by: Philippe Moussalli <ph...@ml6.eu>
---
 .../beam-ml/run_inference_multi_model.ipynb        | 994 +++++++++++++++++++++
 .../en/documentation/ml/multi-model-pipelines.md   | 101 +++
 .../site/content/en/documentation/ml/overview.md   |   1 +
 .../documentation/sdks/python-machine-learning.md  |   6 +-
 .../partials/section-menu/en/documentation.html    |   2 +
 5 files changed, 1101 insertions(+), 3 deletions(-)
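
The notebook added below builds a cascade. BLIP proposes candidate captions, every element stays keyed by its image URL, the captions are joined back to their source image with `CoGroupByKey`, and CLIP's logits are turned into a ranking with a softmax. The following framework-free Python sketch illustrates only that join-and-rank data flow, under stated assumptions: `co_group_by_key`, `rank_captions`, the image URLs, and the logits are hypothetical stand-ins (no Beam, BLIP, or CLIP code is involved). The per-key `{tag: [values]}` shape mirrors how the notebook unpacks `image_captions_pair['image'][0]` and `image_captions_pair['captions'][0]`.

```python
import math
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def co_group_by_key(**tagged: Iterable[Tuple[str, object]]) -> Dict[str, Dict[str, list]]:
  """Group (key, value) pairs from several named inputs by key.

  Hypothetical stand-in that mimics the per-key {tag: [values]} shape
  the notebook unpacks after CoGroupByKey; it is not Beam code.
  """
  grouped = defaultdict(lambda: {tag: [] for tag in tagged})
  for tag, pairs in tagged.items():
    for key, value in pairs:
      grouped[key][tag].append(value)
  return dict(grouped)


def softmax(logits: List[float]) -> List[float]:
  # Subtract the max logit for numerical stability before exponentiating.
  top = max(logits)
  exps = [math.exp(x - top) for x in logits]
  total = sum(exps)
  return [e / total for e in exps]


def rank_captions(captions: List[str], logits: List[float]) -> List[Tuple[str, float]]:
  # Same idea as RankCLIPOutput: softmax over the logits, then sort descending.
  probs = softmax(logits)
  order = sorted(range(len(captions)), key=lambda i: probs[i], reverse=True)
  return [(captions[i], probs[i]) for i in order]


# Toy inputs keyed by image URL; the logits are made-up values standing in for CLIP output.
images = [('img_a.jpg', 'IMAGE_A'), ('img_b.jpg', 'IMAGE_B')]
captions = [('img_a.jpg', ['a dog on grass', 'a cat indoors']),
            ('img_b.jpg', ['a red car', 'a blue boat'])]
toy_logits = {'img_a.jpg': [2.0, 0.5], 'img_b.jpg': [0.1, 1.2]}

joined = co_group_by_key(image=images, captions=captions)
for url, pair in joined.items():
  candidate_captions = pair['captions'][0]
  best_caption, best_prob = rank_captions(candidate_captions, toy_logits[url])[0]
  print(f'{url}: {best_caption} ({best_prob:.2f})')
```

In the real pipeline, Beam performs the grouping across workers and the logits come from the CLIP model handler; the sketch only makes the shape of the intermediate data concrete.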

diff --git a/examples/notebooks/beam-ml/run_inference_multi_model.ipynb b/examples/notebooks/beam-ml/run_inference_multi_model.ipynb
new file mode 100644
index 00000000000..65624eb74b1
--- /dev/null
+++ b/examples/notebooks/beam-ml/run_inference_multi_model.ipynb
@@ -0,0 +1,994 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "collapsed_sections": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "source": [
+        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+        "\n",
+        "# Licensed to the Apache Software Foundation (ASF) under one\n",
+        "# or more contributor license agreements. See the NOTICE file\n",
+        "# distributed with this work for additional information\n",
+        "# regarding copyright ownership. The ASF licenses this file\n",
+        "# to you under the Apache License, Version 2.0 (the\n",
+        "# \"License\"); you may not use this file except in compliance\n",
+        "# with the License. You may obtain a copy of the License at\n",
+        "#\n",
+        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing,\n",
+        "# software distributed under the License is distributed on an\n",
+        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+        "# KIND, either express or implied. See the License for the\n",
+        "# specific language governing permissions and limitations\n",
+        "# under the License."
+      ],
+      "metadata": {
+        "cellView": "form",
+        "id": "L_6L5GU7jyR_"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Ensemble model using an image captioning and ranking example"
+      ],
+      "metadata": {
+        "id": "gPCMXWgOMt_0"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "A single machine learning model might not be the best solution for a given task. Often, machine learning tasks combine multiple models to produce one predictive model with better performance.\n",
+        " \n",
+        "\n",
+        "This notebook shows you how to implement a cascade model in Beam using the [RunInference API](https://beam.apache.org/documentation/sdks/python-machine-learning/). The RunInference API lets you run machine learning inference as a transform within your Beam pipeline.\n",
+        "\n",
+        "Make sure to check out this [notebook](https://colab.research.google.com/drive/111USL4VhUa0xt_mKJxl5nC1YLOC8_yF4?usp=sharing#scrollTo=746b67a7-3562-467f-bea3-d8cd18c14927) to get familiar with the RunInference API."
+      ],
+      "metadata": {
+        "id": "6vZWSLyuM_P4"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Use case: Image captioning with cascade models "
+      ],
+      "metadata": {
+        "id": "i1uyzlj3s3e_"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Image captioning has many applications, such as image indexing for information retrieval, virtual assistants, and other natural language processing tasks.\n",
+        "\n",
+        "This example uses Beam to generate captions for a large set of images. Beam is well suited to this task because it can parallelize the processing across many workers. The example uses two models:\n",
+        "\n",
+        "* [BLIP](https://github.com/salesforce/BLIP): Used to generate a set of candidate captions for a given image.\n",
+        "* [CLIP](https://github.com/openai/CLIP): Used to rank the generated captions by how well they represent the given image."
+      ],
+      "metadata": {
+        "id": "cP1sBhNacS8b"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "![image_captioning_example.png]( [...]
+      ],
+      "metadata": {
+        "id": "PT8L9hbTcytZ"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "The steps needed to build this pipeline can be summarized as follows:\n",
+        "* Read the images.\n",
+        "* Preprocess the images for caption generation with the BLIP model.\n",
+        "* Run inference with BLIP to generate a list of candidate captions.\n",
+        "* Aggregate the generated captions with their source image.\n",
+        "* Preprocess the aggregated image-caption pairs to rank them with CLIP.\n",
+        "* Run inference with CLIP to generate the caption ranking.\n",
+        "* Print the image names and the captions sorted according to their ranking.\n",
+        "\n",
+        "\n",
+        "The following diagram illustrates the steps of the inference pipeline in more detail:"
+      ],
+      "metadata": {
+        "id": "lBPfy-bYgLuD"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title Diagram\n",
+        "from IPython.display import Image\n",
+        "Image(url='https://storage.googleapis.com/apache-beam-samples/image_captioning/beam_ensemble_diagram.png', width=2000)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 440
+        },
+        "id": "3suC5woJLW_N",
+        "outputId": "d2f9f67b-361b-4ae9-f9db-ce2ff9abd509",
+        "cellView": "form"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/html": [
+              "<img src=\"https://storage.googleapis.com/apache-beam-samples/image_captioning/beam_ensemble_diagram.png\" width=\"2000\"/>"
+            ],
+            "text/plain": [
+              "<IPython.core.display.Image object>"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 3
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Dependencies"
+      ],
+      "metadata": {
+        "id": "GULu36WYx5MB"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "The RunInference library is available in Apache Beam version **2.40** or later. "
+      ],
+      "metadata": {
+        "id": "E0uy4-nWNdBa"
+      }
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "tTUZpG9_q-OW",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "14ffb9db-b67a-40bb-bf39-3b472b188898"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "\u001b[K     |████████████████████████████████| 2.0 MB 4.7 MB/s \n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m34.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m18.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m880.6/880.6 kB\u001b[0m \u001b[31m62.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.3/3.3 MB\u001b[0m \u001b[31m75.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[?25h  Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+            "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m377.0/377.0 kB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+            "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m235.4/235.4 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[?25h  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+            "  Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n",
+            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+            "  Building wheel for fairscale (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+            "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+            "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+            "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+            "tensorflow 2.8.2+zzzcolab20220719082949 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.2 which is incompatible.\u001b[0m\u001b[31m\n",
+            "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+            "\u001b[0m"
+          ]
+        }
+      ],
+      "source": [
+        "!pip install --upgrade pip --quiet\n",
+        "!pip install transformers==4.15.0 --quiet\n",
+        "!pip install timm==0.4.12 --quiet\n",
+        "!pip install ftfy==6.1.1 --quiet\n",
+        "!pip install spacy==3.4.1 --quiet\n",
+        "!pip install fairscale==0.4.4 --quiet\n",
+        "!pip install 'apache_beam[gcp]>=2.40.0'\n",
+        "\n",
+        "# restart the runtime in order to use newly installed versions\n",
+        "exit() "
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import requests\n",
+        "import os \n",
+        "import io\n",
+        "from io import BytesIO\n",
+        "from typing import Iterator\n",
+        "from typing import Iterable\n",
+        "from typing import Tuple\n",
+        "from typing import Optional\n",
+        "from typing import Dict\n",
+        "from typing import List\n",
+        "from typing import Any\n",
+        "\n",
+        "import apache_beam as beam\n",
+        "from apache_beam.ml.inference.base import PredictionResult\n",
+        "from apache_beam.options.pipeline_options import PipelineOptions\n",
+        "from apache_beam.options.pipeline_options import SetupOptions\n",
+        "from apache_beam.ml.inference.base import KeyedModelHandler\n",
+        "from apache_beam.ml.inference.base import RunInference\n",
+        "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n",
+        "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor\n",
+        "from transformers import CLIPProcessor\n",
+        "from transformers import CLIPTokenizer\n",
+        "from transformers import CLIPModel\n",
+        "from transformers import CLIPConfig\n",
+        "from transformers import CLIPFeatureExtractor\n",
+        "import torch\n",
+        "from torchvision import transforms\n",
+        "from torchvision.transforms.functional import InterpolationMode\n",
+        "import numpy as np\n",
+        "import matplotlib.pyplot as plt\n",
+        "from PIL import Image"
+      ],
+      "metadata": {
+        "id": "6JbQqfYuvx1f"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### CLIP"
+      ],
+      "metadata": {
+        "id": "iMsN4vUXilTg"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "! git lfs install\n",
+        "! git clone https://huggingface.co/openai/clip-vit-base-patch32"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "Ud4sUXV2x8LO",
+        "outputId": "aa7d593c-107d-40ab-9212-2aef61fa64a7"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Error: Failed to call git rev-parse --git-dir --show-toplevel: \"fatal: not a git repository (or any of the parent directories): .git\\n\"\n",
+            "Git LFS initialized.\n",
+            "Cloning into 'clip-vit-base-patch32'...\n",
+            "remote: Enumerating objects: 48, done.\u001b[K\n",
+            "remote: Counting objects: 100% (48/48), done.\u001b[K\n",
+            "remote: Compressing objects: 100% (28/28), done.\u001b[K\n",
+            "remote: Total 48 (delta 22), reused 42 (delta 19), pack-reused 0\u001b[K\n",
+            "Unpacking objects: 100% (48/48), done.\n",
+            "Filtering content: 100% (3/3), 1.69 GiB | 91.86 MiB/s, done.\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# CLIP model and component configs paths\n",
+        "clip_feature_extractor_config_path = '/content/clip-vit-base-patch32/preprocessor_config.json'\n",
+        "clip_tokenizer_vocab_config_path = '/content/clip-vit-base-patch32/vocab.json'\n",
+        "clip_merges_config_path = '/content/clip-vit-base-patch32/merges.txt'\n",
+        "clip_model_config_path = '/content/clip-vit-base-patch32/config.json'\n",
+        "clip_state_dict_path = '/content/clip-vit-base-patch32/pytorch_model.bin'\n"
+      ],
+      "metadata": {
+        "id": "cDbmVBKuZoWE"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### BLIP"
+      ],
+      "metadata": {
+        "id": "Rg9mKAWnR8s4"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "!git clone https://github.com/salesforce/BLIP\n",
+        "%cd /content/BLIP"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "g4-6WwqUtxea",
+        "outputId": "9a1f2ebf-9fb6-4621-a4cd-7470b5d24964"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Cloning into 'BLIP'...\n",
+            "remote: Enumerating objects: 274, done.\u001b[K\n",
+            "remote: Counting objects: 100% (109/109), done.\u001b[K\n",
+            "remote: Compressing objects: 100% (29/29), done.\u001b[K\n",
+            "remote: Total 274 (delta 89), reused 80 (delta 80), pack-reused 165\u001b[K\n",
+            "Receiving objects: 100% (274/274), 7.67 MiB | 26.26 MiB/s, done.\n",
+            "Resolving deltas: 100% (146/146), done.\n",
+            "/content/BLIP\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from BLIP.models.blip import blip_decoder\n",
+        "\n",
+        "!gdown 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'\n",
+        "# The BLIP model is saved as a checkpoint. Load it and save it as a state dict,\n",
+        "# because RunInference requires a state dict to instantiate the model.\n",
+        "blip_state_dict_path = '/content/BLIP/blip_state_dict.pth'\n",
+        "torch.save(torch.load('/content/BLIP/model*_base_caption.pth')['model'], blip_state_dict_path)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "GCvOP_iZh41c",
+        "outputId": "1e4779dc-ed0c-4450-8590-4d9caeab1083"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Downloading...\n",
+            "From: https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth\n",
+            "To: /content/BLIP/model*_base_caption.pth\n",
+            "100% 896M/896M [00:04<00:00, 193MB/s] \n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## I/O helper functions "
+      ],
+      "metadata": {
+        "id": "FGHgvycOyicj"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class ReadImagesFromUrl(beam.DoFn):\n",
+        "  \"\"\"\n",
+        "  Read an image from a given URL and return a tuple of the image URL\n",
+        "  and image data.\n",
+        "  \"\"\"\n",
+        "  def process(self, element: str) -> Iterable[Tuple[str, Image.Image]]:\n",
+        "    response = requests.get(element)\n",
+        "    image = Image.open(BytesIO(response.content)).convert('RGB')\n",
+        "    return [(element, image)]\n",
+        "\n",
+        "\n",
+        "class FormatCaptions(beam.DoFn):\n",
+        "  \"\"\"\n",
+        "  Print the image name and its most relevant captions after CLIP ranking.\n",
+        "  \"\"\"\n",
+        "  def __init__(self, number_of_top_captions: int):\n",
+        "    self._number_of_top_captions = number_of_top_captions\n",
+        "\n",
+        "  def process(self, element: Tuple[str, List[str]]):\n",
+        "    image_url, caption_list = element\n",
+        "    caption_list = caption_list[:self._number_of_top_captions]\n",
+        "    img_name = os.path.splitext(os.path.basename(image_url))[0]\n",
+        "    print(f'Image: {img_name}')\n",
+        "    print(f'\\tTop {self._number_of_top_captions} captions ranked by CLIP:')\n",
+        "    for caption_rank, caption_prob_pair in enumerate(caption_list):\n",
+        "      print(f'\\t\\t{caption_rank+1}: {caption_prob_pair[0]}. (Caption probability: {caption_prob_pair[1]:.2f})')\n",
+        "    print('\\n')"
+      ],
+      "metadata": {
+        "id": "1Lz3yGuqlAJ_"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Intermediate Processing functions"
+      ],
+      "metadata": {
+        "id": "ogJBy2kfWo6i"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Here we define the preprocessing and postprocessing functions for each of the models.\n",
+        "\n",
+        "> ℹ️ We use `DoFn.setup()` to initialize and cache the resources that the processing transforms need, once per `DoFn` instance. This avoids re-initializing them on every call to the `process()` method."
+      ],
+      "metadata": {
+        "id": "wEViP715fes4"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### BLIP"
+      ],
+      "metadata": {
+        "id": "X1UGv6bbyNxY"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class PreprocessBLIPInput(beam.DoFn):\n",
+        "\n",
+        "  \"\"\"\n",
+        "  Process the raw image input into a format suitable for BLIP inference. The processed\n",
+        "  image is repeated once for each desired caption, so BLIP produces that many captions per image.\n",
+        "\n",
+        "  Preprocessing transformation taken from: \n",
+        "  https://github.com/salesforce/BLIP/blob/d10be550b2974e17ea72e74edc7948c9e5eab884/predict.py\n",
+        "  \"\"\"\n",
+        "\n",
+        "  def __init__(self, captions_per_image: int):\n",
+        "    self._captions_per_image = captions_per_image\n",
+        "\n",
+        "  def setup(self):\n",
+        "    \n",
+        "    # Initialize image transformer\n",
+        "    self._transform = transforms.Compose([\n",
+        "      transforms.Resize((384, 384),interpolation=InterpolationMode.BICUBIC),\n",
+        "      transforms.ToTensor(),\n",
+        "      transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
+        "    ])\n",
+        "\n",
+        "  def process(self, element):\n",
+        "    image_url, image = element \n",
+        "    # This should be changed when this ticket is resolved https://github.com/apache/beam/issues/21863\n",
+        "    preprocessed_img = self._transform(image).unsqueeze(0)\n",
+        "    preprocessed_img = preprocessed_img.repeat(self._captions_per_image, 1, 1, 1)\n",
+        "    # Wrap the processed input in a dictionary whose key matches the `inputs` argument of the model's forward method\n",
+        "    preprocessed_dict = {'inputs': preprocessed_img}\n",
+        "\n",
+        "    return [(image_url, preprocessed_dict)]\n",
+        "\n",
+        "class PostprocessBLIPOutput(beam.DoFn):\n",
+        "  \"\"\"\n",
+        "  Process the PredictionResult to get the generated image captions\n",
+        "  \"\"\"\n",
+        "  def process(self, element: Tuple[str, PredictionResult]):\n",
+        "    image_url, prediction = element \n",
+        "\n",
+        "    return [(image_url, prediction.inference)]"
+      ],
+      "metadata": {
+        "id": "A1s_QQoUctkc"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### CLIP"
+      ],
+      "metadata": {
+        "id": "EZHfa1KzWWDI"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class PreprocessCLIPInput(beam.DoFn):\n",
+        "\n",
+        "  \"\"\"\n",
+        "  Process the image-caption pair to a format suitable for CLIP inference. \n",
+        "\n",
+        "  After grouping the raw images with the generated captions, we need to \n",
+        "  preprocess them before passing them to the ranking stage (CLIP model).\n",
+        "  \"\"\"\n",
+        "\n",
+        "  def __init__(self,\n",
+        "               feature_extractor_config_path: str,\n",
+        "               tokenizer_vocab_config_path: str,\n",
+        "               merges_file_config_path: str):\n",
+        "\n",
+        "    self._feature_extractor_config_path = feature_extractor_config_path\n",
+        "    self._tokenizer_vocab_config_path = tokenizer_vocab_config_path \n",
+        "    self._merges_file_config_path = merges_file_config_path\n",
+        "\n",
+        "\n",
+        "  def setup(self):\n",
+        "    \n",
+        "    # Initialize the CLIP feature extractor from its preprocessor config file\n",
+        "    feature_extractor = CLIPFeatureExtractor.from_pretrained(self._feature_extractor_config_path)\n",
+        "\n",
+        "    # Initialize the CLIP tokenizer\n",
+        "    tokenizer = CLIPTokenizer(self._tokenizer_vocab_config_path,\n",
+        "                              self._merges_file_config_path)\n",
+        "    \n",
+        "    # Initialize the CLIP processor used to process the image-caption pair \n",
+        "    self._processor = CLIPProcessor(feature_extractor=feature_extractor,\n",
+        "                                    tokenizer=tokenizer)\n",
+        "\n",
+        "  def process(self, element: Tuple[str, Dict[str, List[Any]]]):\n",
+        "\n",
+        "    image_url, image_captions_pair = element \n",
+        "    # Unpack the image and captions after grouping them with 'CoGroupByKey()' \n",
+        "    image = image_captions_pair['image'][0]\n",
+        "    captions = image_captions_pair['captions'][0]\n",
+        "    preprocessed_clip_input = self._processor(images = image,\n",
+        "                                              text = captions,\n",
+        "                                              return_tensors=\"pt\",\n",
+        "                                              padding=True)\n",
+        "    \n",
+        "    image_url_caption_pair = (image_url, captions)\n",
+        "    return [(image_url_caption_pair, preprocessed_clip_input)]\n",
+        "\n",
+        "\n",
+        "class RankCLIPOutput(beam.DoFn):\n",
+        "  \"\"\"\n",
+        "  Process the output of CLIP to get the captions sorted by ranking order.\n",
+        "\n",
+        "  The logits are the output of the CLIP model. Here, we apply a softmax activation\n",
+        "  function to the logits to get the probability distribution of the relevance\n",
+        "  of each caption to the target image. After that, we sort the captions in descending\n",
+        "  order with respect to the probabilities as a caption-probability pair. \n",
+        "  \"\"\"\n",
+        "\n",
+        "  def process(self, element: Tuple[Tuple[str, List[str]], PredictionResult]):\n",
+        "    (image_url, captions), prediction = element\n",
+        "    prediction_results = prediction.inference\n",
+        "    prediction_probs = prediction_results.softmax(dim=-1).cpu().detach().numpy()\n",
+        "    ranking = np.argsort(-prediction_probs)\n",
+        "    sorted_caption_prob_pair = [(captions[idx], prediction_probs[idx]) for idx in ranking]\n",
+        "\n",
+        "    return [(image_url, sorted_caption_prob_pair)]"
+      ],
+      "metadata": {
+        "id": "vS2D6VRqBH28"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Model handlers"
+      ],
+      "metadata": {
+        "id": "W1pwQk6ozzZ0"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "> ℹ️ Note that we will use a `KeyedModelHandler` for both models to attach a key to the general `ModelHandler`. The key is used to keep track of which image each inference is associated with, and it is used in our postprocessing steps. In our case, we're using the `image_url` as the key."
+      ],
+      "metadata": {
+        "id": "BTmSPnjj8M2m"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class PytorchNoBatchModelHandlerKeyedTensor(PytorchModelHandlerKeyedTensor):\n",
+        "  \"\"\"Wrapper to PytorchModelHandler to limit batch size to 1.\n",
+        "\n",
+        "  The caption strings generated from the BLIP tokenizer may have different\n",
+        "  lengths, which doesn't work with torch.stack() in the current RunInference\n",
+        "  implementation, because stack() requires tensors to be the same size.\n",
+        "  Restricting max_batch_size to 1 means there is only one example per batch\n",
+        "  in the run_inference() call.\n",
+        "  \"\"\"\n",
+        "  # This should be changed when this ticket is resolved https://github.com/apache/beam/issues/21863\n",
+        "  def batch_elements_kwargs(self):\n",
+        "    return {'max_batch_size': 1}"
+      ],
+      "metadata": {
+        "id": "OaR02_wxTMpc"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "> ℹ️ The key attached by the `KeyedModelHandler` is also what allows us to aggregate the outputs of different transforms, for example when joining each image with its generated captions."
+      ],
+      "metadata": {
+        "id": "gNLRO0EwvcGP"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## BLIP"
+      ],
+      "metadata": {
+        "id": "OXz7TuK4W_ZN"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "MAX_CAPTION_LENGTH = 80\n",
+        "MIN_CAPTION_LENGTH = 10\n",
+        "# Increasing beam search can improve the quality of the captions but results in\n",
+        "# more compute time\n",
+        "NUM_BEAMS = 1\n"
+      ],
+      "metadata": {
+        "id": "0npmJ7uSZN7w"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class BLIPWrapper(torch.nn.Module):\n",
+        "  \"\"\"\n",
+        "  Wrapper around the BLIP model that overrides the default \"forward\" method to call\n",
+        "  \"generate\", because BLIP uses the \"generate\" method to produce the image captions.\n",
+        "  \"\"\"\n",
+        "  \n",
+        "  def __init__(self, base_model: blip_decoder, num_beams: int, max_length: int,\n",
+        "                min_length: int):\n",
+        "    super().__init__()\n",
+        "    self._model = base_model()\n",
+        "    self._num_beams = num_beams\n",
+        "    self._max_length = max_length\n",
+        "    self._min_length = min_length\n",
+        "\n",
+        "  def forward(self, inputs: torch.Tensor):\n",
+        "    # Squeeze the extra batch dimension of size 1 that RunInference adds.\n",
+        "    # This should be changed when this ticket is resolved https://github.com/apache/beam/issues/21863\n",
+        "    inputs = inputs.squeeze(0)\n",
+        "    captions = self._model.generate(inputs,\n",
+        "                                    sample=True,\n",
+        "                                    num_beams=self._num_beams,\n",
+        "                                    max_length=self._max_length,\n",
+        "                                    min_length=self._min_length)\n",
+        "    return [captions]\n",
+        "\n",
+        "  def load_state_dict(self, state_dict: dict):\n",
+        "    self._model.load_state_dict(state_dict)\n",
+        "\n",
+        "\n",
+        "BLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(\n",
+        "    state_dict_path=blip_state_dict_path,\n",
+        "    model_class=BLIPWrapper,\n",
+        "    model_params={'base_model': blip_decoder, 'num_beams': NUM_BEAMS,\n",
+        "                  'max_length': MAX_CAPTION_LENGTH, 'min_length': MIN_CAPTION_LENGTH},\n",
+        "    device='GPU')\n",
+        "\n",
+        "BLIP_keyed_model_handler = KeyedModelHandler(BLIP_model_handler)"
+      ],
+      "metadata": {
+        "id": "RCKBJjujVw4q"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## CLIP"
+      ],
+      "metadata": {
+        "id": "-8PG_0txMiYA"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class CLIPWrapper(CLIPModel):\n",
+        "\n",
+        "  def forward(self, **kwargs: Dict[str, torch.Tensor]):\n",
+        "    # Squeeze the extra batch dimension of size 1 that RunInference adds.\n",
+        "    # This should be changed when this ticket is resolved https://github.com/apache/beam/issues/21863\n",
+        "    kwargs = {key: tensor.squeeze(0) for key, tensor in kwargs.items()}\n",
+        "    output = super().forward(**kwargs)\n",
+        "    logits = output.logits_per_image\n",
+        "    return logits\n",
+        "\n",
+        "\n",
+        "CLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(\n",
+        "    state_dict_path=clip_state_dict_path,\n",
+        "    model_class=CLIPWrapper,\n",
+        "    model_params={'config': CLIPConfig.from_pretrained(clip_model_config_path)},\n",
+        "    device='GPU')\n",
+        "\n",
+        "CLIP_keyed_model_handler = KeyedModelHandler(CLIP_model_handler)\n"
+      ],
+      "metadata": {
+        "id": "EJw_OnZ1ZfuH"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Specify the images to display"
+      ],
+      "metadata": {
+        "id": "azC12uqDn0bq"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "images_url = ['https://storage.googleapis.com/apache-beam-samples/image_captioning/hamster_tea.jpg',\n",
+        "              'https://storage.googleapis.com/apache-beam-samples/image_captioning/potato_field.jpg',\n",
+        "              'https://storage.googleapis.com/apache-beam-samples/image_captioning/eiffel_tower_ballet_dancer.jpg']"
+      ],
+      "metadata": {
+        "id": "VJwE0bquoXOf"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Let's visualize the images that we will use for captioning."
+      ],
+      "metadata": {
+        "id": "c3fpgc15hzcq"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "for image_url in images_url:\n",
+        "  response = requests.get(image_url)\n",
+        "  image = Image.open(BytesIO(response.content)).convert('RGB')\n",
+        "  fig = plt.figure()\n",
+        "  title = os.path.basename(image_url).rsplit('.')[0]\n",
+        "  fig.suptitle(title, fontsize=12)\n",
+        "  plt.axis('off')\n",
+        "  plt.imshow(image)\n"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 797
+        },
+        "id": "qTDcWfND9r1Y",
+        "outputId": "f0b096f8-a228-4b74-d335-c9fdb89746f3"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 432x288 with 1 Axes>"
+            ],
+            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAEECAYAAAB3HMxCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9aawkyXbf9zsRkVlV996+3T3rm5k3b+G+iYsMigZsUYBt2LQsw98sAV7gD4Y/GYbgFSBgmDZkGbBhyxZkCJAsSzakD4IWCKBtElpoAxREgSYEk5RNkXzvccR5y7zZevpuVZUZEccfIrKWzLyVVXfp7nkv/4PpW5WVGREZGXninP85cUJUlREjRowYcXuY592AESNGjPhWwShQR4wYMeKOMArUESNGjLgjjAJ1xIgRI+4Io0AdMWLEiDvCKFBHjBgx4o4wCtQRI0 [...]
+          },
+          "metadata": {
+            "needs_background": "light"
+          }
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 432x288 with 1 Axes>"
+            ],
+            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAEECAYAAAB3HMxCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9W6wtS3ff9RtV1T3nXGvvfc7Z57vZjv194CRI8IJ4ASICEQJxSSIegISrEFJASEi8ICIShPBDAjyAFMAvCARBsYgABSFkRRAiYRmBQJBIji0H29/9nO/c9tmXdZuX7qoaPIxR3T3nXmuf73KE7XgOae09Lz27q6tGjfEf1xZV5UxnOtOZzvSjU/jNHsCZznSmM/2NQmeBeqYznelMnxOdBeqZznSmM31OdBaoZzrTmc70OdFZoJ7pTGc60+dEZ4F6pjOd6UyfE5 [...]
+          },
+          "metadata": {
+            "needs_background": "light"
+          }
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 432x288 with 1 Axes>"
+            ],
+            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAAEECAYAAADTUyO4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9e7BvyXXX91nd+/c75z5mrjQPyRppJGNZlmzZPIyEy4aAeYWExElcFIQEkkCgKn+ESlUqFBQpAq6EAKmCokioxASncAFJUeZVBSmCEcFWGRuwsbELHMuWrJE0mpFG87pzn+f8frt75Y+1VnfvfX7n3DuSXbqY0zPnnvP77b1792M9vmv16tWiqlyWy3JZHr2SvtwNuCyX5bIcLpfMeVkuyyNaLpnzslyWR7RcMudluSyPaLlkzstyWR7Rcsmcl+WyPKLlkjkvy2 [...]
+          },
+          "metadata": {
+            "needs_background": "light"
+          }
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Initialize pipeline run parameters "
+      ],
+      "metadata": {
+        "id": "m8S8VQHvoEZf"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Number of generated captions per image\n",
+        "NUM_CAPTIONS_PER_IMAGE = 10\n",
+        "\n",
+        "# Number of top-ranked captions to display per image\n",
+        "NUM_TOP_CAPTIONS_TO_DISPLAY = 3\n"
+      ],
+      "metadata": {
+        "id": "Dcz_M9GW0Kan"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Run pipeline"
+      ],
+      "metadata": {
+        "id": "5T9Pcdp7oNb8"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "> ℹ️ Note that we use the raw images from the `read_images` step as input to both models, because each model preprocesses the raw images differently: BLIP and CLIP require different input representations for image captioning and for image-caption pair ranking, respectively.\n",
+        "\n",
+        "> ℹ️ We use `CoGroupByKey` to join the raw images with the generated captions by their key (the image URL). For each image, this produces the image together with its candidate captions, which are then passed to the CLIP transform for ranking."
+      ],
+      "metadata": {
+        "id": "G4a2ACIYeJyj"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "with beam.Pipeline() as pipeline:\n",
+        "\n",
+        "  read_images = (\n",
+        "            pipeline \n",
+        "            | \"ReadUrl\" >> beam.Create(images_url)\n",
+        "            | \"ReadImages\" >> beam.ParDo(ReadImagesFromUrl()))\n",
+        "\n",
+        "  blip_caption_generation = (\n",
+        "            read_images\n",
+        "            | \"PreprocessBlipInput\" >> beam.ParDo(PreprocessBLIPInput(NUM_CAPTIONS_PER_IMAGE)) \n",
+        "            | \"GenerateCaptions\" >> RunInference(BLIP_keyed_model_handler)\n",
+        "            | \"PostprocessCaptions\" >> beam.ParDo(PostprocessBLIPOutput()))\n",
+        "\n",
+        "  clip_captions_ranking = (\n",
+        "            ({'image' : read_images, 'captions': blip_caption_generation})\n",
+        "            | \"CreateImageCaptionPair\" >> beam.CoGroupByKey()\n",
+        "            | \"PreprocessClipInput\" >> beam.ParDo(\n",
+        "                PreprocessCLIPInput(\n",
+        "                    clip_feature_extractor_config_path,\n",
+        "                    clip_tokenizer_vocab_config_path,\n",
+        "                    clip_merges_config_path))\n",
+        "            | \"GetRankingLogits\" >> RunInference(CLIP_keyed_model_handler)\n",
+        "            | \"RankClipOutput\" >> beam.ParDo(RankCLIPOutput()))\n",
+        "\n",
+        "  clip_captions_ranking | \"FormatCaptions\" >> beam.ParDo(FormatCaptions(NUM_TOP_CAPTIONS_TO_DISPLAY))\n",
+        "  "
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "002e-FNbmuB8",
+        "outputId": "43e260f0-8419-447e-888b-c61e500ef391"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Image: hamster_tea\n",
+            "\tTop 3 captions ranked by CLIP:\n",
+            "\t\t1: a small white hamster inside of a mug. (Caption probability: 0.31)\n",
+            "\t\t2: the hamster is in a white mug on a blue surface. (Caption probability: 0.24)\n",
+            "\t\t3: a small hamster in a coffee cup. (Caption probability: 0.22)\n",
+            "\n",
+            "\n",
+            "Image: potato_field\n",
+            "\tTop 3 captions ranked by CLIP:\n",
+            "\t\t1: several potatoes are in the ground with a blue sky. (Caption probability: 0.64)\n",
+            "\t\t2: a bunch of potatoes are sitting in the dirt. (Caption probability: 0.09)\n",
+            "\t\t3: potato plants sprouts on muddy ground near forest. (Caption probability: 0.09)\n",
+            "\n",
+            "\n",
+            "Image: eiffel_tower_ballet_dancer\n",
+            "\tTop 3 captions ranked by CLIP:\n",
+            "\t\t1: a woman practices ballet in front of the eiffel tower. (Caption probability: 0.48)\n",
+            "\t\t2: a dancer is practicing in front of the eiffel tower. (Caption probability: 0.26)\n",
+            "\t\t3: a woman is doing ballet next to the eiffel tower. (Caption probability: 0.17)\n",
+            "\n",
+            "\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# References\n",
+        "\n",
+        "* [RunInference API](https://beam.apache.org/documentation/sdks/python-machine-learning/) -- an official guide to the RunInference API.\n",
+        "* [RunInference Demo](https://colab.research.google.com/drive/10iPQTCmaLJL4_OohS00R9Wmor6d57JkS#scrollTo=ZVtBsKDgW1dl) -- a demo of a cascade model in Colab.\n",
+        "* [The advantages of having a DAG and what it unlocks for you](https://beam.apache.org/documentation/dsls/dataframes/differences-from-pandas) -- a guide to the advantages of using a Beam DAG for ML workflow orchestration and inference."
+      ],
+      "metadata": {
+        "id": "HMH_ldJsrJoz"
+      }
+    }
+  ]
+}
\ No newline at end of file
diff --git a/website/www/site/content/en/documentation/ml/multi-model-pipelines.md b/website/www/site/content/en/documentation/ml/multi-model-pipelines.md
new file mode 100644
index 00000000000..be614e4b500
--- /dev/null
+++ b/website/www/site/content/en/documentation/ml/multi-model-pipelines.md
@@ -0,0 +1,101 @@
+---
+title: "Multi-model pipelines"
+---
+<!--
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+# Multi-model pipelines
+
+Apache Beam allows you to develop multi-model pipelines. For example, you can ingest and transform
+some input data, run it through a model, and then pass the output of the first model into a second
+model. This page explains how multi-model pipelines work and gives an overview of what you need to
+know to build one.
+
+Before reading this section, we recommend that you become familiar with the information in
+the [Pipeline development lifecycle](https://beam.apache.org/documentation/pipelines/design-your-pipeline/).
+
+## How to build a multi-model pipeline with Beam
+
+A typical machine learning workflow involves a series of data transformation steps such as data
+ingestion, data processing tasks, inference, and post-processing. Beam enables you to orchestrate
+all of those steps together by encapsulating them in a single Beam DAG. This allows you to build
+resilient and scalable end-to-end machine learning systems.
+
+To deploy your machine learning model in a Beam pipeline, you can use
+the [`RunInference` API](https://beam.apache.org/documentation/sdks/python-machine-learning/), which
+integrates your model into your DAG as a `PTransform` step. Composing
+multiple `RunInference` transforms within a single DAG allows you to build a pipeline that consists
+of multiple ML models. In this way, Beam supports the development of complex ML systems.
+
+There are different patterns that can be used to build multi-model pipelines in Beam. Let’s have a
+look at a few of them.
+
+### A/B Pattern
+
+The A/B pattern describes a framework where multiple ML models run in parallel. One
+application for this pattern is to test the performance of different machine learning models and
+decide whether a new model is an improvement over an existing one. This is also known as the
+“Champion/Challenger” method. Here, we typically define a business metric to compare the performance
+of the current (control) model with that of the new (challenger) model.
+
+For example, consider a recommendation engine where an existing model recommends ads based on
+the user’s preferences and activity history. When deciding whether to deploy a new model, you
+could split the incoming user traffic into two branches, where half of the users are exposed to the
+new model and the other half to the current one.
+
+Afterwards, you could measure the average click-through rate (CTR) of ads for both sets of
+users over a defined period of time to determine whether the new model performs better than the
+existing one.
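
To make the CTR comparison concrete, here is a minimal plain-Python sketch (not Beam code) that computes the CTR of each branch from logged ad impressions. The `(model_name, clicked)` event format and the `click_through_rates` helper are hypothetical, chosen only for illustration.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple


def click_through_rates(events: Iterable[Tuple[str, bool]]) -> Dict[str, float]:
  """Returns the click-through rate (clicks / impressions) for each model.

  Each event is a hypothetical (model_name, clicked) tuple for one ad impression.
  """
  impressions = defaultdict(int)
  clicks = defaultdict(int)
  for model_name, clicked in events:
    impressions[model_name] += 1
    if clicked:
      clicks[model_name] += 1
  return {name: clicks[name] / impressions[name] for name in impressions}


events = [
    ('model_a', True), ('model_a', False), ('model_a', False), ('model_a', False),
    ('model_b', True), ('model_b', True), ('model_b', False), ('model_b', False),
]
ctr = click_through_rates(events)
print(ctr)  # {'model_a': 0.25, 'model_b': 0.5}
```

In practice, you would also check that the difference between the two rates is statistically significant before replacing the existing model.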
+
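+```
+import apache_beam as beam
+from apache_beam.ml.inference.base import RunInference
+
+with beam.Pipeline() as pipeline:
+  # split_dataset is a user-defined function, split_dataset(element, num_partitions, ratio),
+  # that returns the partition index (0 or 1) for each element.
+  userset_a_traffic, userset_b_traffic = (
+      pipeline
+      # Read from any streaming source; Pub/Sub (in streaming mode) is one example.
+      | 'ReadFromStream' >> beam.io.ReadFromPubSub(topic='<stream_source>')
+      | 'Partition' >> beam.Partition(split_dataset, 2, ratio=[5, 5]))
+
+  model_a_predictions = userset_a_traffic | 'RunModelA' >> RunInference(<model_handler_A>)
+  model_b_predictions = userset_b_traffic | 'RunModelB' >> RunInference(<model_handler_B>)
+```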
+
+Here, `beam.Partition` splits the incoming traffic into two partitions of roughly equal size (a 50/50
+split). For more information on data partitioning,
+see [Partition](https://beam.apache.org/documentation/transforms/python/elementwise/partition/).
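
The A/B snippet passes a `split_dataset` function to the partition transform without defining it. `beam.Partition` calls its partition function as `fn(element, num_partitions, *args, **kwargs)` and expects an index in `[0, num_partitions)`, so one possible implementation is sketched below. It assumes that each element is a dict with a `user_id` field (a hypothetical schema) and hashes that ID, so that a given user is always routed to the same model. Because it is plain Python, you can test it without running a pipeline.

```python
import hashlib
from typing import Any, Dict, List


def split_dataset(element: Dict[str, Any], num_partitions: int, ratio: List[int]) -> int:
  """Assigns an element to a partition index according to the weights in `ratio`.

  Hashing the (hypothetical) `user_id` field makes the assignment deterministic,
  so the same user always sees the same model.
  """
  if len(ratio) != num_partitions:
    raise ValueError('ratio must have one weight per partition')
  digest = hashlib.md5(str(element['user_id']).encode('utf-8')).hexdigest()
  bucket = int(digest, 16) % sum(ratio)
  # Walk the cumulative weights until the bucket falls inside a partition's range.
  for index, weight in enumerate(ratio):
    if bucket < weight:
      return index
    bucket -= weight
  raise ValueError('ratio must contain non-negative weights')


# With ratio=[5, 5], roughly half of the users go to each partition.
counts = [0, 0]
for user_id in range(1000):
  counts[split_dataset({'user_id': user_id}, 2, ratio=[5, 5])] += 1
print(counts)
```

Hashing a stable ID rather than picking a random partition keeps each user's experience consistent for the whole test period, which makes the per-model CTR easier to interpret.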
+
+### Cascade Pattern
+
+The Cascade pattern is used to solve use cases where the solution involves a series of ML models. In
+this scenario, the output of one model is typically transformed into a suitable format using
+a `PTransform` before it is passed to the next model.
+
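+```
+import apache_beam as beam
+from apache_beam.ml.inference.base import RunInference
+
+with beam.Pipeline() as p:
+  data = p | 'Read' >> beam.io.ReadFromText('<input_path>')
+  model_a_predictions = data | 'RunModelA' >> RunInference(<model_handler_A>)
+  # post_processing is a user-defined DoFn that converts model A's output into model B's input.
+  model_b_predictions = (model_a_predictions
+                         | 'PostProcess' >> beam.ParDo(post_processing())
+                         | 'RunModelB' >> RunInference(<model_handler_B>))
+```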
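
The cascade's data flow can be illustrated without Beam. In this plain-Python sketch, every element is a `(key, example)` pair, and `keyed_inference` keeps each key attached to its prediction, similar to what wrapping a model handler in `KeyedModelHandler` does for `RunInference` in the accompanying notebook. `keyed_inference`, `model_a`, `reshape_for_model_b`, and `model_b` are hypothetical toy functions, not real ML models.

```python
from typing import Callable, Iterable, List, Tuple, TypeVar

K = TypeVar('K')
X = TypeVar('X')
Y = TypeVar('Y')


def keyed_inference(model: Callable[[X], Y],
                    elements: Iterable[Tuple[K, X]]) -> List[Tuple[K, Y]]:
  """Runs `model` on each example and keeps the example's key with its prediction."""
  return [(key, model(example)) for key, example in elements]


# Hypothetical stand-in models: model A scores a text, model B labels the score.
def model_a(text: str) -> float:
  return len(text.split()) / 10


def reshape_for_model_b(score: float) -> List[float]:
  # Post-processing step: convert model A's output into the input model B expects.
  return [score, 1 - score]


def model_b(features: List[float]) -> str:
  return 'long' if features[0] > features[1] else 'short'


data = [('doc1', 'a short sentence'),
        ('doc2', 'a much longer sentence with many more words in it')]
model_a_predictions = keyed_inference(model_a, data)
model_b_inputs = [(key, reshape_for_model_b(pred)) for key, pred in model_a_predictions]
model_b_predictions = keyed_inference(model_b, model_b_inputs)
print(model_b_predictions)  # [('doc1', 'short'), ('doc2', 'long')]
```

Keeping the key on every element is what lets a later step join a model's output back to its original input, as the notebook does when it pairs each image with its BLIP captions before ranking them with CLIP.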
+
+In this [notebook](https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_multi_model.ipynb),
+we show an end-to-end example of a cascade pipeline used for generating and ranking image
+captions. The solution consists of two open-source models:
+
+1. **A caption generation model ([BLIP](https://github.com/salesforce/BLIP))** that generates
+   candidate image captions from an input image.
+2. **A caption ranking model ([CLIP](https://github.com/openai/CLIP))** that uses the image and
+   candidate captions to rank the captions in the order in which they best describe the image.
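
In the notebook, the CLIP wrapper returns `logits_per_image`, one score per candidate caption, and the output reports a probability for each of the top captions. The sketch below shows one way to turn those scores into a ranking, assuming that the probabilities come from a softmax over an image's caption logits; `rank_captions` and the logit values are hypothetical.

```python
import math
from typing import List, Tuple


def rank_captions(captions: List[str], logits: List[float],
                  top_k: int) -> List[Tuple[str, float]]:
  """Converts one image's caption logits into probabilities and returns the top_k captions."""
  # Subtract the max logit before exponentiating for numerical stability.
  max_logit = max(logits)
  exps = [math.exp(logit - max_logit) for logit in logits]
  total = sum(exps)
  probs = [e / total for e in exps]
  ranked = sorted(zip(captions, probs), key=lambda pair: pair[1], reverse=True)
  return ranked[:top_k]


captions = ['a hamster in a mug', 'a cat on a sofa', 'a small hamster in a cup']
logits = [25.0, 18.0, 24.0]  # hypothetical logits_per_image values for one image
for rank, (caption, prob) in enumerate(rank_captions(captions, logits, top_k=2), start=1):
  print(f'{rank}: {caption} (Caption probability: {prob:.2f})')
```

Because the softmax is computed per image, the probabilities of all candidate captions for that image sum to 1, so they express how CLIP's preference is distributed across the candidates rather than an absolute quality score.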
+
diff --git a/website/www/site/content/en/documentation/ml/overview.md b/website/www/site/content/en/documentation/ml/overview.md
index 1c6e4e03fd0..0a0a83961b9 100755
--- a/website/www/site/content/en/documentation/ml/overview.md
+++ b/website/www/site/content/en/documentation/ml/overview.md
@@ -58,4 +58,5 @@ In order to automate and track the AI/ML workflows throughout your project, you
 ## Examples
 
 You can find examples of end-to-end AI/ML pipelines for several use cases:
+* [Multi-model pipelines in Beam](/documentation/ml/multi-model-pipelines)
 * [Online Clustering in Beam](/documentation/ml/online-clustering)
diff --git a/website/www/site/content/en/documentation/sdks/python-machine-learning.md b/website/www/site/content/en/documentation/sdks/python-machine-learning.md
index cce9853990e..bcd430d0072 100644
--- a/website/www/site/content/en/documentation/sdks/python-machine-learning.md
+++ b/website/www/site/content/en/documentation/sdks/python-machine-learning.md
@@ -41,7 +41,7 @@ Using the `Shared` class within the RunInference implementation makes it possibl
 
 ### Multi-model pipelines
 
-The RunInference API can be composed into multi-model pipelines. Multi-model pipelines can be useful for A/B testing or for building out ensembles made up of models that perform tokenization, sentence segmentation, part-of-speech tagging, named entity extraction, language detection, coreference resolution, and more.
+The RunInference API can be composed into multi-model pipelines. Multi-model pipelines can be useful for A/B testing or for building cascades of models that perform tokenization, sentence segmentation, part-of-speech tagging, named entity extraction, language detection, coreference resolution, and more.
 
 ## Modify a pipeline to use an ML model
 
@@ -107,7 +107,7 @@ with pipeline as p:
 
 Where `model_handler_A` and `model_handler_B` are the model handler setup code.
 
-#### Ensemble Pattern
+#### Cascade Pattern
 
 ```
 with pipeline as p:
@@ -124,7 +124,7 @@ When using multiple models in a single pipeline, different models may have diffe
 Resource hints allow you to provide information to a runner about the compute resource requirements for each step in your
 pipeline.
 
-For example, the following snippet extends the previous ensemble pattern with hints for each RunInference call
+For example, the following snippet extends the previous cascade pattern with hints for each RunInference call
 to specify RAM and hardware accelerator requirements:
 
 ```
diff --git a/website/www/site/layouts/partials/section-menu/en/documentation.html b/website/www/site/layouts/partials/section-menu/en/documentation.html
index f55e30f6c4d..ebfc97f3c12 100644
--- a/website/www/site/layouts/partials/section-menu/en/documentation.html
+++ b/website/www/site/layouts/partials/section-menu/en/documentation.html
@@ -211,8 +211,10 @@
 </li>
 <li class="section-nav-item--collapsible">
   <span class="section-nav-list-title">AI/ML pipelines</span>
+
   <ul class="section-nav-list">
     <li><a href="/documentation/ml/overview/">Overview</a></li>
+    <li><a href="/documentation/ml/multi-model-pipelines/">Multi-model pipelines</a></li>
     <li><a href="/documentation/ml/online-clustering/">Online Clustering</a></li>
   </ul>
 </li>