Posted to commits@beam.apache.org by da...@apache.org on 2022/10/06 17:14:05 UTC
[beam] branch master updated: Content/multi model pipelines (#23498)
This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 9b2f87d0201 Content/multi model pipelines (#23498)
9b2f87d0201 is described below
commit 9b2f87d0201e8923a021c6bedbe0f64e37704014
Author: Philippe Moussalli <ph...@gmail.com>
AuthorDate: Thu Oct 6 19:13:57 2022 +0200
Content/multi model pipelines (#23498)
* Created using Colaboratory
* add new website content for multi-model pipelines
* add example notebook with blip/clip
* add multi-model website content
* add multi-model example notebook
* fix notebook typo and link issues
* add reference to multimodel page in beam-ml master page
* remove whitespaces
* fix wrong resource link
* change references of ensemble models to cascade models
Co-authored-by: Philippe Moussalli <ph...@ml6.eu>
---
.../beam-ml/run_inference_multi_model.ipynb | 994 +++++++++++++++++++++
.../en/documentation/ml/multi-model-pipelines.md | 101 +++
.../site/content/en/documentation/ml/overview.md | 1 +
.../documentation/sdks/python-machine-learning.md | 6 +-
.../partials/section-menu/en/documentation.html | 2 +
5 files changed, 1101 insertions(+), 3 deletions(-)
diff --git a/examples/notebooks/beam-ml/run_inference_multi_model.ipynb b/examples/notebooks/beam-ml/run_inference_multi_model.ipynb
new file mode 100644
index 00000000000..65624eb74b1
--- /dev/null
+++ b/examples/notebooks/beam-ml/run_inference_multi_model.ipynb
@@ -0,0 +1,994 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "collapsed_sections": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+ "\n",
+ "# Licensed to the Apache Software Foundation (ASF) under one\n",
+ "# or more contributor license agreements. See the NOTICE file\n",
+ "# distributed with this work for additional information\n",
+ "# regarding copyright ownership. The ASF licenses this file\n",
+ "# to you under the Apache License, Version 2.0 (the\n",
+ "# \"License\"); you may not use this file except in compliance\n",
+ "# with the License. You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing,\n",
+ "# software distributed under the License is distributed on an\n",
+ "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+ "# KIND, either express or implied. See the License for the\n",
+ "# specific language governing permissions and limitations\n",
+ "# under the License."
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "L_6L5GU7jyR_"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Cascade model using an image captioning and ranking example"
+ ],
+ "metadata": {
+ "id": "gPCMXWgOMt_0"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "A single machine learning model may not always be the best solution for a given task. Machine learning tasks often involve combining multiple models to produce one optimal predictive model and boost performance.\n",
+ "\n",
+ "\n",
+ "This notebook shows you how to implement a cascade model in Beam using the [RunInference API](https://beam.apache.org/documentation/sdks/python-machine-learning/). The RunInference API lets you use machine learning models for inference as a transform in your Beam pipeline.\n",
+ "\n",
+ "Make sure to check out this [notebook](https://colab.research.google.com/drive/111USL4VhUa0xt_mKJxl5nC1YLOC8_yF4?usp=sharing#scrollTo=746b67a7-3562-467f-bea3-d8cd18c14927) to become familiar with the RunInference API."
+ ],
+ "metadata": {
+ "id": "6vZWSLyuM_P4"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Use case: Image captioning with cascade models "
+ ],
+ "metadata": {
+ "id": "i1uyzlj3s3e_"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Image captioning has many applications, such as image indexing for information retrieval, virtual assistants, and other natural language processing applications.\n",
+ "\n",
+ "We want to use Beam to generate captions for a large set of images. Beam is well suited to this because it can parallelize the work across many workers. We use two models for this task:\n",
+ "\n",
+ "* [BLIP](https://github.com/salesforce/BLIP): Used to generate a set of candidate captions for a given image.\n",
+ "* [CLIP](https://github.com/openai/CLIP): Used to rank the generated captions by how well they represent the given image."
+ ],
+ "metadata": {
+ "id": "cP1sBhNacS8b"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "![image_captioning_example.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABNMAAAFqCAYAAAAwbBxSAAAABHNCSVQICAgIfAhkiAAAABl0RVh0U29mdHdhcmUAZ25vbWUtc2NyZWVuc2hvdO8Dvz4AACAASURBVHic7L3LsiRZsiW0dD/M3M/JzFvV1Y/bTUsDPWHCY8CACb/EP9w/4FcYIkyZIExAEEQAkW6EFoGmabnQt25mHDfbL2WwVLe5nziRcSIys7Kqy1QkKuu4m9tz722qS5cuFVVVnHbaaaeddtppp5122mmnnXbaaaeddtppn7Xwa5/Aaaeddtppp5122mmnnXbaaaeddtppp/2p2AmmnXbaaaeddtppp5122mmnnXbaaaeddto77QTTTjvttNNOO+2000477bTTTjvttNNOO+2ddoJpp5122mmnnXbaaaeddtppp5122 [...]
+ ],
+ "metadata": {
+ "id": "PT8L9hbTcytZ"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The steps needed to build this pipeline can be summarized as follows:\n",
+ "* Read the images.\n",
+ "* Preprocess the images for caption generation with the BLIP model.\n",
+ "* Run inference with BLIP to generate a list of candidate captions.\n",
+ "* Aggregate the generated captions with their source image.\n",
+ "* Preprocess the aggregated image-caption pairs to rank them with CLIP.\n",
+ "* Run inference with CLIP to generate the caption ranking.\n",
+ "* Print the image names and the captions sorted according to their ranking.\n",
+ "\n",
+ "\n",
+ "The following diagram illustrates the steps of the inference pipeline in more detail:"
+ ],
+ "metadata": {
+ "id": "lBPfy-bYgLuD"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Diagram\n",
+ "from IPython.display import Image\n",
+ "Image(url='https://storage.googleapis.com/apache-beam-samples/image_captioning/beam_ensemble_diagram.png', width=2000)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 440
+ },
+ "id": "3suC5woJLW_N",
+ "outputId": "d2f9f67b-361b-4ae9-f9db-ce2ff9abd509",
+ "cellView": "form"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "<img src=\"https://storage.googleapis.com/apache-beam-samples/image_captioning/beam_ensemble_diagram.png\" width=\"2000\"/>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.Image object>"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 3
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Dependencies"
+ ],
+ "metadata": {
+ "id": "GULu36WYx5MB"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The RunInference API is available in Apache Beam version **2.40** or later."
+ ],
+ "metadata": {
+ "id": "E0uy4-nWNdBa"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "tTUZpG9_q-OW",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "14ffb9db-b67a-40bb-bf39-3b472b188898"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\u001b[K |████████████████████████████████| 2.0 MB 4.7 MB/s \n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m34.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m18.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m880.6/880.6 kB\u001b[0m \u001b[31m62.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.3/3.3 MB\u001b[0m \u001b[31m75.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m377.0/377.0 kB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m235.4/235.4 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Building wheel for fairscale (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[33mWARNING: google-api-core 2.8.2 does not provide the extra 'grpcgcp'\u001b[0m\u001b[33m\n",
+ "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "tensorflow 2.8.2+zzzcolab20220719082949 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.2 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install --upgrade pip --quiet\n",
+ "!pip install transformers==4.15.0 --quiet\n",
+ "!pip install timm==0.4.12 --quiet\n",
+ "!pip install ftfy==6.1.1 --quiet\n",
+ "!pip install spacy==3.4.1 --quiet\n",
+ "!pip install fairscale==0.4.4 --quiet\n",
+ "!pip install 'apache_beam[gcp]>=2.40.0'\n",
+ "\n",
+ "# restart the runtime in order to use newly installed versions\n",
+ "exit() "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import requests\n",
+ "import os \n",
+ "import io\n",
+ "from io import BytesIO\n",
+ "from typing import Iterator\n",
+ "from typing import Iterable\n",
+ "from typing import Tuple\n",
+ "from typing import Optional\n",
+ "from typing import Dict\n",
+ "from typing import List\n",
+ "from typing import Any\n",
+ "\n",
+ "import apache_beam as beam\n",
+ "from apache_beam.ml.inference.base import PredictionResult\n",
+ "from apache_beam.options.pipeline_options import PipelineOptions\n",
+ "from apache_beam.options.pipeline_options import SetupOptions\n",
+ "from apache_beam.ml.inference.base import KeyedModelHandler\n",
+ "from apache_beam.ml.inference.base import RunInference\n",
+ "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n",
+ "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor\n",
+ "from transformers import CLIPProcessor\n",
+ "from transformers import CLIPTokenizer\n",
+ "from transformers import CLIPModel\n",
+ "from transformers import CLIPConfig\n",
+ "from transformers import CLIPFeatureExtractor\n",
+ "import torch\n",
+ "from torchvision import transforms\n",
+ "from torchvision.transforms.functional import InterpolationMode\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "from PIL import Image"
+ ],
+ "metadata": {
+ "id": "6JbQqfYuvx1f"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### CLIP"
+ ],
+ "metadata": {
+ "id": "iMsN4vUXilTg"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "! git lfs install\n",
+ "! git clone https://huggingface.co/openai/clip-vit-base-patch32"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Ud4sUXV2x8LO",
+ "outputId": "aa7d593c-107d-40ab-9212-2aef61fa64a7"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Error: Failed to call git rev-parse --git-dir --show-toplevel: \"fatal: not a git repository (or any of the parent directories): .git\\n\"\n",
+ "Git LFS initialized.\n",
+ "Cloning into 'clip-vit-base-patch32'...\n",
+ "remote: Enumerating objects: 48, done.\u001b[K\n",
+ "remote: Counting objects: 100% (48/48), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (28/28), done.\u001b[K\n",
+ "remote: Total 48 (delta 22), reused 42 (delta 19), pack-reused 0\u001b[K\n",
+ "Unpacking objects: 100% (48/48), done.\n",
+ "Filtering content: 100% (3/3), 1.69 GiB | 91.86 MiB/s, done.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# CLIP model and component configs paths\n",
+ "clip_feature_extractor_config_path = '/content/clip-vit-base-patch32/preprocessor_config.json'\n",
+ "clip_tokenizer_vocab_config_path = '/content/clip-vit-base-patch32/vocab.json'\n",
+ "clip_merges_config_path = '/content/clip-vit-base-patch32/merges.txt'\n",
+ "clip_model_config_path = '/content/clip-vit-base-patch32/config.json'\n",
+ "clip_state_dict_path = '/content/clip-vit-base-patch32/pytorch_model.bin'\n"
+ ],
+ "metadata": {
+ "id": "cDbmVBKuZoWE"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### BLIP"
+ ],
+ "metadata": {
+ "id": "Rg9mKAWnR8s4"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!git clone https://github.com/salesforce/BLIP\n",
+ "%cd /content/BLIP"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "g4-6WwqUtxea",
+ "outputId": "9a1f2ebf-9fb6-4621-a4cd-7470b5d24964"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Cloning into 'BLIP'...\n",
+ "remote: Enumerating objects: 274, done.\u001b[K\n",
+ "remote: Counting objects: 100% (109/109), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (29/29), done.\u001b[K\n",
+ "remote: Total 274 (delta 89), reused 80 (delta 80), pack-reused 165\u001b[K\n",
+ "Receiving objects: 100% (274/274), 7.67 MiB | 26.26 MiB/s, done.\n",
+ "Resolving deltas: 100% (146/146), done.\n",
+ "/content/BLIP\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from BLIP.models.blip import blip_decoder\n",
+ "\n",
+ "!gdown 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'\n",
+ "# The BLIP model is saved as a checkpoint. Load it and save it as a state dict, since RunInference requires\n",
+ "# a state dict for model instantiation.\n",
+ "blip_state_dict_path = '/content/BLIP/blip_state_dict.pth'\n",
+ "torch.save(torch.load('/content/BLIP/model*_base_caption.pth')['model'], blip_state_dict_path)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "GCvOP_iZh41c",
+ "outputId": "1e4779dc-ed0c-4450-8590-4d9caeab1083"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Downloading...\n",
+ "From: https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth\n",
+ "To: /content/BLIP/model*_base_caption.pth\n",
+ "100% 896M/896M [00:04<00:00, 193MB/s] \n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## I/O helper functions "
+ ],
+ "metadata": {
+ "id": "FGHgvycOyicj"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class ReadImagesFromUrl(beam.DoFn):\n",
+ " \"\"\"\n",
+ " Read an image from a given URL and return a tuple of the image URL\n",
+ " and image data.\n",
+ " \"\"\"\n",
+ " def process(self, element: str) -> Iterable[Tuple[str, Image.Image]]:\n",
+ " response = requests.get(element)\n",
+ " image = Image.open(BytesIO(response.content)).convert('RGB')\n",
+ " return [(element, image)]\n",
+ "\n",
+ "\n",
+ "class FormatCaptions(beam.DoFn):\n",
+ " \"\"\"\n",
+ " Print the image name and its most relevant captions after CLIP ranking.\n",
+ " \"\"\"\n",
+ " def __init__(self, number_of_top_captions: int):\n",
+ " self._number_of_top_captions = number_of_top_captions\n",
+ "\n",
+ " def process(self, element: Tuple[str, List[Tuple[str, float]]]):\n",
+ " image_url, caption_list = element\n",
+ " caption_list = caption_list[:self._number_of_top_captions]\n",
+ " img_name = os.path.basename(image_url).rsplit('.', 1)[0]\n",
+ " print(f'Image: {img_name}')\n",
+ " print(f'\\tTop {self._number_of_top_captions} captions ranked by CLIP:')\n",
+ " for caption_rank, caption_prob_pair in enumerate(caption_list):\n",
+ " print(f'\\t\\t{caption_rank+1}: {caption_prob_pair[0]}. (Caption probability: {caption_prob_pair[1]:.2f})')\n",
+ " print('\\n')"
+ ],
+ "metadata": {
+ "id": "1Lz3yGuqlAJ_"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Intermediate processing functions"
+ ],
+ "metadata": {
+ "id": "ogJBy2kfWo6i"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Here we define the preprocessing and postprocessing functions for each of the models.\n",
+ "\n",
+ "> ℹ️ We use `DoFn.setup()` to prepare each `DoFn` instance for processing bundles of elements by initializing and caching the resources that the processing transforms need. This avoids re-initializing them on every invocation of the `process` method."
+ ],
+ "metadata": {
+ "id": "wEViP715fes4"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### BLIP"
+ ],
+ "metadata": {
+ "id": "X1UGv6bbyNxY"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class PreprocessBLIPInput(beam.DoFn):\n",
+ "\n",
+ " \"\"\"\n",
+ " Process the raw image input into a format suitable for BLIP inference. The processed\n",
+ " image is repeated once for each desired caption.\n",
+ "\n",
+ " Preprocessing transformation taken from: \n",
+ " https://github.com/salesforce/BLIP/blob/d10be550b2974e17ea72e74edc7948c9e5eab884/predict.py\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, captions_per_image: int):\n",
+ " self._captions_per_image = captions_per_image\n",
+ "\n",
+ " def setup(self):\n",
+ " \n",
+ " # Initialize image transformer\n",
+ " self._transform = transforms.Compose([\n",
+ " transforms.Resize((384, 384),interpolation=InterpolationMode.BICUBIC),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
+ " ])\n",
+ "\n",
+ " def process(self, element):\n",
+ " image_url, image = element \n",
+ " # This should be changed when this ticket is resolved https://github.com/apache/beam/issues/21863\n",
+ " preprocessed_img = self._transform(image).unsqueeze(0)\n",
+ " preprocessed_img = preprocessed_img.repeat(self._captions_per_image, 1, 1, 1)\n",
+ " # Wrap the processed input in a dictionary, the format expected by PytorchModelHandlerKeyedTensor\n",
+ " preprocessed_dict = {'inputs': preprocessed_img}\n",
+ "\n",
+ " return [(image_url, preprocessed_dict)]\n",
+ "\n",
+ "class PostprocessBLIPOutput(beam.DoFn):\n",
+ " \"\"\"\n",
+ " Process the PredictionResult to get the generated image captions\n",
+ " \"\"\"\n",
+ " def process(self, element: Tuple[str, PredictionResult]):\n",
+ " image_url, prediction = element \n",
+ "\n",
+ " return [(image_url, prediction.inference)]"
+ ],
+ "metadata": {
+ "id": "A1s_QQoUctkc"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### CLIP"
+ ],
+ "metadata": {
+ "id": "EZHfa1KzWWDI"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class PreprocessCLIPInput(beam.DoFn):\n",
+ "\n",
+ " \"\"\"\n",
+ " Process the image-caption pair into a format suitable for CLIP inference.\n",
+ "\n",
+ " After grouping the raw images with the generated captions, we need to \n",
+ " preprocess them before passing them to the ranking stage (CLIP model).\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self,\n",
+ " feature_extractor_config_path: str,\n",
+ " tokenizer_vocab_config_path: str,\n",
+ " merges_file_config_path: str):\n",
+ "\n",
+ " self._feature_extractor_config_path = feature_extractor_config_path\n",
+ " self._tokenizer_vocab_config_path = tokenizer_vocab_config_path \n",
+ " self._merges_file_config_path = merges_file_config_path\n",
+ "\n",
+ "\n",
+ " def setup(self):\n",
+ " \n",
+ " # Initialize the CLIP feature extractor from its preprocessor config file\n",
+ " feature_extractor = CLIPFeatureExtractor.from_pretrained(self._feature_extractor_config_path)\n",
+ "\n",
+ " # Initialize the CLIP tokenizer\n",
+ " tokenizer = CLIPTokenizer(self._tokenizer_vocab_config_path,\n",
+ " self._merges_file_config_path)\n",
+ " \n",
+ " # Initialize the CLIP processor used to process the image-caption pair \n",
+ " self._processor = CLIPProcessor(feature_extractor=feature_extractor,\n",
+ " tokenizer=tokenizer)\n",
+ "\n",
+ " def process(self, element: Tuple[str, Dict[str, List[Any]]]):\n",
+ "\n",
+ " image_url, image_captions_pair = element \n",
+ " # Unpack the image and captions after grouping them with 'CoGroupByKey()' \n",
+ " image = image_captions_pair['image'][0]\n",
+ " captions = image_captions_pair['captions'][0]\n",
+ " preprocessed_clip_input = self._processor(images = image,\n",
+ " text = captions,\n",
+ " return_tensors=\"pt\",\n",
+ " padding=True)\n",
+ " \n",
+ " image_url_caption_pair = (image_url, captions)\n",
+ " return [(image_url_caption_pair, preprocessed_clip_input)]\n",
+ "\n",
+ "\n",
+ "class RankCLIPOutput(beam.DoFn):\n",
+ " \"\"\"\n",
+ " Process the output of CLIP to get the captions sorted by ranking order.\n",
+ "\n",
+ " The logits are the output of the CLIP model. Here, we apply a softmax activation\n",
+ " function to the logits to get the probability distribution of the relevance\n",
+ " of each caption to the target image. After that, we sort the captions in descending\n",
+ " order with respect to the probabilities as a caption-probability pair. \n",
+ " \"\"\"\n",
+ "\n",
+ " def process(self, element: Tuple[Tuple[str, List[str]], PredictionResult]):\n",
+ " (image_url, captions), prediction = element\n",
+ " prediction_results = prediction.inference\n",
+ " prediction_probs = prediction_results.softmax(dim=-1).cpu().detach().numpy()\n",
+ " ranking = np.argsort(-prediction_probs)\n",
+ " sorted_caption_prob_pair = [(captions[idx], prediction_probs[idx]) for idx in ranking]\n",
+ "\n",
+ " return [(image_url, sorted_caption_prob_pair)]"
+ ],
+ "metadata": {
+ "id": "vS2D6VRqBH28"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Model handlers"
+ ],
+ "metadata": {
+ "id": "W1pwQk6ozzZ0"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "> ℹ️ Note that we use a `KeyedModelHandler` for both models to attach a key to the general `ModelHandler`. The key keeps track of which image each inference is associated with, and it is used in our postprocessing steps. In our case, we use the `image_url` as the key."
+ ],
+ "metadata": {
+ "id": "BTmSPnjj8M2m"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class PytorchNoBatchModelHandlerKeyedTensor(PytorchModelHandlerKeyedTensor):\n",
+ " \"\"\"Wrapper to PytorchModelHandler to limit batch size to 1.\n",
+ " The caption strings generated from BLIP tokenizer may have different\n",
+ " lengths, which doesn't work with torch.stack() in current RunInference\n",
+ " implementation since stack() requires tensors to be the same size.\n",
+ " Restricting max_batch_size to 1 means there is only 1 example per `batch`\n",
+ " in the run_inference() call.\n",
+ " \"\"\"\n",
+ " # This should be changed when this ticket is resolved https://github.com/apache/beam/issues/21863\n",
+ " def batch_elements_kwargs(self):\n",
+ " return {'max_batch_size': 1}"
+ ],
+ "metadata": {
+ "id": "OaR02_wxTMpc"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "> ℹ️ The key is also used to aggregate the outputs of different transforms, for example when the generated BLIP captions are grouped with their source image using `CoGroupByKey`."
+ ],
+ "metadata": {
+ "id": "gNLRO0EwvcGP"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### BLIP"
+ ],
+ "metadata": {
+ "id": "OXz7TuK4W_ZN"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "MAX_CAPTION_LENGTH = 80\n",
+ "MIN_CAPTION_LENGTH = 10\n",
+ "# Increasing the number of beams in beam search can improve the quality of the captions but\n",
+ "# increases compute time\n",
+ "NUM_BEAMS = 1\n"
+ ],
+ "metadata": {
+ "id": "0npmJ7uSZN7w"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class BLIPWrapper(torch.nn.Module):\n",
+ " \"\"\"\n",
+ " Wrapper around the BLIP model that overrides the default \"forward\" method with \"generate\", since BLIP uses the\n",
+ " \"generate\" method to produce the image captions.\n",
+ " \"\"\"\n",
+ " \n",
+ " def __init__(self, base_model: blip_decoder, num_beams: int, max_length: int,\n",
+ " min_length: int):\n",
+ " super().__init__()\n",
+ " self._model = base_model()\n",
+ " self._num_beams = num_beams\n",
+ " self._max_length = max_length\n",
+ " self._min_length = min_length\n",
+ "\n",
+ " def forward(self, inputs: torch.Tensor):\n",
+ " # Squeeze because RunInference adds an extra leading dimension of size 1\n",
+ " # This should be changed when this ticket is resolved https://github.com/apache/beam/issues/21863\n",
+ " inputs = inputs.squeeze(0)\n",
+ " captions = self._model.generate(inputs,\n",
+ " sample=True,\n",
+ " num_beams=self._num_beams,\n",
+ " max_length=self._max_length,\n",
+ " min_length=self._min_length)\n",
+ " return [captions]\n",
+ "\n",
+ " def load_state_dict(self, state_dict: dict):\n",
+ " self._model.load_state_dict(state_dict)\n",
+ "\n",
+ "\n",
+ "BLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(\n",
+ " state_dict_path=blip_state_dict_path,\n",
+ " model_class=BLIPWrapper,\n",
+ " model_params={'base_model': blip_decoder, 'num_beams': NUM_BEAMS,\n",
+ " 'max_length': MAX_CAPTION_LENGTH, 'min_length': MIN_CAPTION_LENGTH},\n",
+ " device='GPU')\n",
+ "\n",
+ "BLIP_keyed_model_handler = KeyedModelHandler(BLIP_model_handler)"
+ ],
+ "metadata": {
+ "id": "RCKBJjujVw4q"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### CLIP"
+ ],
+ "metadata": {
+ "id": "-8PG_0txMiYA"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class CLIPWrapper(CLIPModel):\n",
+ "\n",
+ " def forward(self, **kwargs: Dict[str, torch.Tensor]):\n",
+ " # Squeeze because RunInference adds an extra leading dimension of size 1\n",
+ " # This should be changed when this ticket is resolved https://github.com/apache/beam/issues/21863\n",
+ " kwargs = {key: tensor.squeeze(0) for key, tensor in kwargs.items()}\n",
+ " output = super().forward(**kwargs)\n",
+ " logits = output.logits_per_image\n",
+ " return logits\n",
+ "\n",
+ "\n",
+ "CLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(\n",
+ " state_dict_path=clip_state_dict_path,\n",
+ " model_class=CLIPWrapper,\n",
+ " model_params={'config': CLIPConfig.from_pretrained(clip_model_config_path)},\n",
+ " device='GPU')\n",
+ "\n",
+ "CLIP_keyed_model_handler = KeyedModelHandler(CLIP_model_handler)\n"
+ ],
+ "metadata": {
+ "id": "EJw_OnZ1ZfuH"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Specify the images to caption"
+ ],
+ "metadata": {
+ "id": "azC12uqDn0bq"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "images_url = ['https://storage.googleapis.com/apache-beam-samples/image_captioning/hamster_tea.jpg',\n",
+ " 'https://storage.googleapis.com/apache-beam-samples/image_captioning/potato_field.jpg',\n",
+ " 'https://storage.googleapis.com/apache-beam-samples/image_captioning/eiffel_tower_ballet_dancer.jpg']"
+ ],
+ "metadata": {
+ "id": "VJwE0bquoXOf"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's visualize the images that we will use for captioning."
+ ],
+ "metadata": {
+ "id": "c3fpgc15hzcq"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "for image_url in images_url:\n",
+ " response = requests.get(image_url)\n",
+ " image = Image.open(BytesIO(response.content)).convert('RGB')\n",
+ " fig = plt.figure()\n",
+ " title = os.path.basename(image_url).rsplit('.', 1)[0]\n",
+ " fig.suptitle(title, fontsize=12)\n",
+ " plt.axis('off')\n",
+ " plt.imshow(image)\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 797
+ },
+ "id": "qTDcWfND9r1Y",
+ "outputId": "f0b096f8-a228-4b74-d335-c9fdb89746f3"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ],
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAEECAYAAAB3HMxCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9aawkyXbf9zsRkVlV996+3T3rm5k3b+G+iYsMigZsUYBt2LQsw98sAV7gD4Y/GYbgFSBgmDZkGbBhyxZkCJAsSzakD4IWCKBtElpoAxREgSYEk5RNkXzvccR5y7zZevpuVZUZEccfIrKWzLyVVXfp7nkv/4PpW5WVGREZGXninP85cUJUlREjRowYcXuY592AESNGjPhWwShQR4wYMeKOMArUESNGjLgjjAJ1xIgRI+4Io0AdMWLEiDvCKFBHjBgx4o4wCtQRI0 [...]
+ },
+ "metadata": {
+ "needs_background": "light"
+ }
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ],
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAEECAYAAAB3HMxCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9W6wtS3ff9RtV1T3nXGvvfc7Z57vZjv194CRI8IJ4ASICEQJxSSIegISrEFJASEi8ICIShPBDAjyAFMAvCARBsYgABSFkRRAiYRmBQJBIji0H29/9nO/c9tmXdZuX7qoaPIxR3T3nXmuf73KE7XgOae09Lz27q6tGjfEf1xZV5UxnOtOZzvSjU/jNHsCZznSmM/2NQmeBeqYznelMnxOdBeqZznSmM31OdBaoZzrTmc70OdFZoJ7pTGc60+dEZ4F6pjOd6UyfE5 [...]
+ },
+ "metadata": {
+ "needs_background": "light"
+ }
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ],
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAAEECAYAAADTUyO4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9e7BvyXXX91nd+/c75z5mrjQPyRppJGNZlmzZPIyEy4aAeYWExElcFIQEkkCgKn+ESlUqFBQpAq6EAKmCokioxASncAFJUeZVBSmCEcFWGRuwsbELHMuWrJE0mpFG87pzn+f8frt75Y+1VnfvfX7n3DuSXbqY0zPnnvP77b1792M9vmv16tWiqlyWy3JZHr2SvtwNuCyX5bIcLpfMeVkuyyNaLpnzslyWR7RcMudluSyPaLlkzstyWR7Rcsmcl+WyPKLlkjkvy2 [...]
+ },
+ "metadata": {
+ "needs_background": "light"
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Initialize pipeline run parameters "
+ ],
+ "metadata": {
+ "id": "m8S8VQHvoEZf"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Number of generated captions per image\n",
+ "NUM_CAPTIONS_PER_IMAGE = 10\n",
+ "\n",
+ "# Top captions to display\n",
+ "NUM_TOP_CAPTIONS_TO_DISPLAY = 3\n"
+ ],
+ "metadata": {
+ "id": "Dcz_M9GW0Kan"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run pipeline"
+ ],
+ "metadata": {
+ "id": "5T9Pcdp7oNb8"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "> ℹ️ We use the raw images from the `read_images` PCollection as input to both models, because each model preprocesses the raw images differently: image captioning and image-caption pair ranking each require their own embedding representation.\n",
+ "\n",
+ "> ℹ️ We use `CoGroupByKey` to join the raw images with the generated captions by their key (the image URL). For each image URL, this produces the raw image grouped with its candidate captions, which is then passed to the CLIP transforms for ranking."
+ ],
+ "metadata": {
+ "id": "G4a2ACIYeJyj"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with beam.Pipeline() as pipeline:\n",
+ "\n",
+ " read_images = (\n",
+ " pipeline \n",
+ " | \"ReadUrl\" >> beam.Create(images_url)\n",
+ " | \"ReadImages\" >> beam.ParDo(ReadImagesFromUrl()))\n",
+ "\n",
+ " blip_caption_generation = (\n",
+ " read_images\n",
+ " | \"PreprocessBlipInput\" >> beam.ParDo(PreprocessBLIPInput(NUM_CAPTIONS_PER_IMAGE)) \n",
+ " | \"GenerateCaptions\" >> RunInference(BLIP_keyed_model_handler)\n",
+ " | \"PostprocessCaptions\" >> beam.ParDo(PostprocessBLIPOutput()))\n",
+ "\n",
+ " clip_captions_ranking = (\n",
+ " ({'image' : read_images, 'captions': blip_caption_generation})\n",
+ " | \"CreateImageCaptionPair\" >> beam.CoGroupByKey()\n",
+ " | \"PreprocessClipInput\" >> beam.ParDo(\n",
+ " PreprocessCLIPInput(\n",
+ " clip_feature_extractor_config_path,\n",
+ " clip_tokenizer_vocab_config_path,\n",
+ " clip_merges_config_path))\n",
+ " | \"GetRankingLogits\" >> RunInference(CLIP_keyed_model_handler)\n",
+ " | \"RankClipOutput\" >> beam.ParDo(RankCLIPOutput()))\n",
+ "\n",
+ " clip_captions_ranking | \"FormatCaptions\" >> beam.ParDo(FormatCaptions(NUM_TOP_CAPTIONS_TO_DISPLAY))\n",
+ " "
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "002e-FNbmuB8",
+ "outputId": "43e260f0-8419-447e-888b-c61e500ef391"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Image: hamster_tea\n",
+ "\tTop 3 captions ranked by CLIP:\n",
+ "\t\t1: a small white hamster inside of a mug. (Caption probability: 0.31)\n",
+ "\t\t2: the hamster is in a white mug on a blue surface. (Caption probability: 0.24)\n",
+ "\t\t3: a small hamster in a coffee cup. (Caption probability: 0.22)\n",
+ "\n",
+ "\n",
+ "Image: potato_field\n",
+ "\tTop 3 captions ranked by CLIP:\n",
+ "\t\t1: several potatoes are in the ground with a blue sky. (Caption probability: 0.64)\n",
+ "\t\t2: a bunch of potatoes are sitting in the dirt. (Caption probability: 0.09)\n",
+ "\t\t3: potato plants sprouts on muddy ground near forest. (Caption probability: 0.09)\n",
+ "\n",
+ "\n",
+ "Image: eiffel_tower_ballet_dancer\n",
+ "\tTop 3 captions ranked by CLIP:\n",
+ "\t\t1: a woman practices ballet in front of the eiffel tower. (Caption probability: 0.48)\n",
+ "\t\t2: a dancer is practicing in front of the eiffel tower. (Caption probability: 0.26)\n",
+ "\t\t3: a woman is doing ballet next to the eiffel tower. (Caption probability: 0.17)\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# References\n",
+ "\n",
+ "* [RunInference API](https://beam.apache.org/documentation/sdks/python-machine-learning/) -- an official guide to the RunInference API.\n",
+ "* [RunInference Demo](https://colab.research.google.com/drive/10iPQTCmaLJL4_OohS00R9Wmor6d57JkS#scrollTo=ZVtBsKDgW1dl) -- a demo of an ensemble model in Colab.\n",
+ "* [The advantages of having a DAG and what it unlocks for you](https://beam.apache.org/documentation/dsls/dataframes/differences-from-pandas) -- a guide on the advantages of using a Beam DAG for ML workflow orchestration and inference."
+ ],
+ "metadata": {
+ "id": "HMH_ldJsrJoz"
+ }
+ }
+ ]
+}
\ No newline at end of file
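The notebook joins the raw images and the BLIP captions with `CoGroupByKey`, using a dict input `{'image': ..., 'captions': ...}`. As a hedged, pure-Python illustration of the element shape that join produces (not Beam's implementation), the hypothetical helper `co_group_by_key` below groups a dict of tagged `(key, value)` collections by key, yielding one `{tag: [values]}` dict per key. It assumes, as in the notebook, that both inputs are keyed by the image URL.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Tuple


def co_group_by_key(
    tagged: Dict[str, Iterable[Tuple[str, Any]]],
) -> Dict[str, Dict[str, List[Any]]]:
    """Group several keyed collections by key, one list of values per tag.

    Mirrors the shape beam.CoGroupByKey emits for a dict input:
    key -> {tag: [values from that tag's collection]}.
    """
    # Every key gets an entry for every tag, even if that tag has no values.
    grouped: Dict[str, Dict[str, List[Any]]] = defaultdict(
        lambda: {tag: [] for tag in tagged})
    for tag, pairs in tagged.items():
        for key, value in pairs:
            grouped[key][tag].append(value)
    return dict(grouped)
```

In the notebook, each grouped element would then carry one raw image under `'image'` and the generated captions under `'captions'`, which is what the CLIP preprocessing step consumes.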
diff --git a/website/www/site/content/en/documentation/ml/multi-model-pipelines.md b/website/www/site/content/en/documentation/ml/multi-model-pipelines.md
new file mode 100644
index 00000000000..be614e4b500
--- /dev/null
+++ b/website/www/site/content/en/documentation/ml/multi-model-pipelines.md
@@ -0,0 +1,101 @@
+---
+title: "Multi-model pipelines"
+---
+<!--
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+# Multi-model pipelines
+
+Apache Beam allows you to develop multi-model pipelines. For example, you can ingest and transform
+input data, run it through one model, and then pass the output of that first model into a second
+model. This page explains how multi-model pipelines work and gives an overview of what you need to
+know to build one.
+
+Before reading this section, it is recommended that you become familiar with the information in
+the [Pipeline development lifecycle](https://beam.apache.org/documentation/pipelines/design-your-pipeline/).
+
+## How to build a multi-model pipeline with Beam
+
+A typical machine learning workflow involves a series of data transformation steps such as data
+ingestion, data processing tasks, inference, and post-processing. Beam enables you to orchestrate
+all of those steps together by encapsulating them in a single Beam DAG. This allows you to build
+resilient and scalable end-to-end machine learning systems.
+
+To deploy your machine learning model in a Beam pipeline, you can use
+the [RunInference API](https://beam.apache.org/documentation/sdks/python-machine-learning/), which
+integrates your model into your DAG as a `PTransform` step. Composing
+multiple `RunInference` transforms within a single DAG allows you to build a pipeline that consists
+of multiple ML models. In this way, Beam supports the development of complex ML systems.
+
+There are different patterns that can be used to build multi-model pipelines in Beam. Let’s have a
+look at a few of them.
+
+### A/B Pattern
+
+The A/B pattern describes a framework where multiple ML models run in parallel. One
+application for this pattern is to test the performance of different machine learning models and
+decide whether a new model is an improvement over an existing one. This is also known as the
+“Champion/Challenger” method. Typically, you define a business metric to compare the performance
+of the current (champion) model with that of the new (challenger) model.
+
+For example, consider a recommendation engine with an existing model that recommends
+ads based on the user’s preferences and activity history. When deciding whether to deploy a new model, you
+could split the incoming user traffic into two branches, where half of the users are exposed to the
+new model and the other half to the current one.
+
+Afterwards, you could measure the average click-through rate (CTR) of ads for both sets of
+users over a defined period of time to determine whether the new model performs better than the
+existing one.
+
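+```
+import apache_beam as beam
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.options.pipeline_options import PipelineOptions
+
+# Placeholders: STREAM_SUBSCRIPTION is your streaming source, split_dataset is a
+# partition function (element, num_partitions, ratio) -> int, and model_handler_A
+# and model_handler_B are the model handlers for the two models under test.
+with beam.Pipeline(options=PipelineOptions(streaming=True)) as pipeline:
+    userset_a_traffic, userset_b_traffic = (
+        pipeline
+        | 'ReadFromStream' >> beam.io.ReadFromPubSub(subscription=STREAM_SUBSCRIPTION)
+        | 'Partition' >> beam.Partition(split_dataset, 2, ratio=[5, 5]))
+
+    # Distinct labels are required: two unlabeled RunInference steps would collide.
+    model_a_predictions = userset_a_traffic | 'RunInferenceA' >> RunInference(model_handler_A)
+    model_b_predictions = userset_b_traffic | 'RunInferenceB' >> RunInference(model_handler_B)
+```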
+
+Here, `beam.Partition` splits the data source into two partitions according to `ratio=[5, 5]`,
+that is, a 50/50 split. For more information on data partitioning,
+see [Partition](https://beam.apache.org/documentation/transforms/python/elementwise/partition/).
+
+### Cascade Pattern
+
+The cascade pattern applies to use cases where the solution involves a series of ML models. In
+this scenario, the output of one model is typically transformed into a suitable format by
+a `PTransform` before it is passed to the next model.
+
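+```
+import apache_beam as beam
+from apache_beam.ml.inference.base import RunInference
+
+# Placeholders: 'a_source', model_handler_A, model_handler_B and the PostProcessing DoFn.
+with beam.Pipeline() as p:
+    data = p | 'Read' >> beam.io.ReadFromText('a_source')
+    model_a_predictions = data | 'RunInferenceA' >> RunInference(model_handler_A)
+    model_b_predictions = (model_a_predictions
+                           | 'PostProcess' >> beam.ParDo(PostProcessing())
+                           | 'RunInferenceB' >> RunInference(model_handler_B))
+```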
+
+In this [notebook](https://github.com/apache/beam/tree/master/examples/notebooks/beam-ml/run_inference_multi_model.ipynb),
+we show an end-to-end example of a cascade pipeline used for generating and ranking image
+captions. The solution consists of two open-source models:
+
+1. **A caption generation model ([BLIP](https://github.com/salesforce/BLIP))** that generates
+ candidate image captions from an input image.
+2. **A caption ranking model ([CLIP](https://github.com/openai/CLIP))** that uses the image and
+   candidate captions to rank the captions by how well they describe the image.
+
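In the cascade example, CLIP's output has to be turned into a ranking of the candidate captions. The notebook's `RankCLIPOutput` and `FormatCaptions` transforms are defined in cells outside this excerpt, so the following is only a sketch under an assumption: that the reported "caption probability" is a softmax over the per-caption logits (`logits_per_image`) for a single image. The helper name `rank_captions` is hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def rank_captions(
    captions: Sequence[str], logits: Sequence[float], top_k: int
) -> List[Tuple[str, float]]:
    """Convert per-caption logits to probabilities with a softmax and return
    the top_k (caption, probability) pairs, highest probability first."""
    if len(captions) != len(logits):
        raise ValueError('captions and logits must have the same length')
    # Subtract the largest logit before exponentiating for numerical stability.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(captions, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_k]
```

Calling `rank_captions(captions, logits, top_k=3)` for each image would yield the kind of "Top 3 captions ranked by CLIP" listing that the notebook prints.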
diff --git a/website/www/site/content/en/documentation/ml/overview.md b/website/www/site/content/en/documentation/ml/overview.md
index 1c6e4e03fd0..0a0a83961b9 100755
--- a/website/www/site/content/en/documentation/ml/overview.md
+++ b/website/www/site/content/en/documentation/ml/overview.md
@@ -58,4 +58,5 @@ In order to automate and track the AI/ML workflows throughout your project, you
## Examples
You can find examples of end-to-end AI/ML pipelines for several use cases:
+* [Multi-model pipelines in Beam](/documentation/ml/multi-model-pipelines)
* [Online Clustering in Beam](/documentation/ml/online-clustering)
diff --git a/website/www/site/content/en/documentation/sdks/python-machine-learning.md b/website/www/site/content/en/documentation/sdks/python-machine-learning.md
index cce9853990e..bcd430d0072 100644
--- a/website/www/site/content/en/documentation/sdks/python-machine-learning.md
+++ b/website/www/site/content/en/documentation/sdks/python-machine-learning.md
@@ -41,7 +41,7 @@ Using the `Shared` class within the RunInference implementation makes it possibl
### Multi-model pipelines
-The RunInference API can be composed into multi-model pipelines. Multi-model pipelines can be useful for A/B testing or for building out ensembles made up of models that perform tokenization, sentence segmentation, part-of-speech tagging, named entity extraction, language detection, coreference resolution, and more.
+The RunInference API can be composed into multi-model pipelines. Multi-model pipelines can be useful for A/B testing or for building cascades of models that perform tokenization, sentence segmentation, part-of-speech tagging, named entity extraction, language detection, coreference resolution, and more.
## Modify a pipeline to use an ML model
@@ -107,7 +107,7 @@ with pipeline as p:
Where `model_handler_A` and `model_handler_B` are the model handler setup code.
-#### Ensemble Pattern
+#### Cascade Pattern
```
with pipeline as p:
@@ -124,7 +124,7 @@ When using multiple models in a single pipeline, different models may have diffe
Resource hints allow you to provide information to a runner about the compute resource requirements for each step in your
pipeline.
-For example, the following snippet extends the previous ensemble pattern with hints for each RunInference call
+For example, the following snippet extends the previous cascade pattern with hints for each RunInference call
to specify RAM and hardware accelerator requirements:
```
diff --git a/website/www/site/layouts/partials/section-menu/en/documentation.html b/website/www/site/layouts/partials/section-menu/en/documentation.html
index f55e30f6c4d..ebfc97f3c12 100644
--- a/website/www/site/layouts/partials/section-menu/en/documentation.html
+++ b/website/www/site/layouts/partials/section-menu/en/documentation.html
@@ -211,8 +211,10 @@
</li>
<li class="section-nav-item--collapsible">
<span class="section-nav-list-title">AI/ML pipelines</span>
+
<ul class="section-nav-list">
<li><a href="/documentation/ml/overview/">Overview</a></li>
+ <li><a href="/documentation/ml/multi-model-pipelines/">Multi-model pipelines</a></li>
<li><a href="/documentation/ml/online-clustering/">Online Clustering</a></li>
</ul>
</li>
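The A/B snippet on the new multi-model page passes a `split_dataset` function to the Partition transform without defining it. `beam.Partition` calls its function as `fn(element, num_partitions, *args, **kwargs)` and expects a partition index back. A minimal sketch, assuming each element is a dict with a `user_id` field (a hypothetical schema), hashes that ID so each user consistently sees the same model for the duration of the test:

```python
import hashlib
from typing import Any, Dict, Sequence


def split_dataset(element: Dict[str, Any], num_partitions: int,
                  ratio: Sequence[int]) -> int:
    """Return the partition index for `element`, weighted by `ratio`.

    Assumes a hypothetical 'user_id' field so that a given user is always
    routed to the same model.
    """
    if num_partitions != len(ratio):
        raise ValueError('ratio must have one entry per partition')
    digest = hashlib.sha256(str(element['user_id']).encode('utf-8')).digest()
    bucket = int.from_bytes(digest[:8], 'big') % sum(ratio)
    upper = 0
    for index, share in enumerate(ratio):
        upper += share
        if bucket < upper:
            return index
    return num_partitions - 1  # Not reached when the shares sum to sum(ratio).
```

It can then be plugged in as `beam.Partition(split_dataset, 2, ratio=[5, 5])`. A stable hash such as SHA-256 is used instead of Python's built-in `hash()`, which is randomized per process for strings and could route the same user to different models on different workers.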