Posted to commits@beam.apache.org by da...@apache.org on 2023/06/29 14:33:45 UTC

[beam] branch users/damccorm/multimodel-notebook-refresh created (now 46eaaa7f76d)

This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a change to branch users/damccorm/multimodel-notebook-refresh
in repository https://gitbox.apache.org/repos/asf/beam.git


      at 46eaaa7f76d Update multi model notebook to remove workarounds

This branch includes the following new commits:

     new 46eaaa7f76d Update multi model notebook to remove workarounds

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[beam] 01/01: Update multi model notebook to remove workarounds

Posted by da...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/multimodel-notebook-refresh
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 46eaaa7f76d2846a8b7f63ab98f81450787f5b87
Author: Danny McCormick <da...@google.com>
AuthorDate: Thu Jun 29 10:33:25 2023 -0400

    Update multi model notebook to remove workarounds
---
 .../beam-ml/run_inference_multi_model.ipynb        | 281 ++++++++-------------
 1 file changed, 106 insertions(+), 175 deletions(-)

diff --git a/examples/notebooks/beam-ml/run_inference_multi_model.ipynb b/examples/notebooks/beam-ml/run_inference_multi_model.ipynb
index 9a99ad2cf47..430c22b9a2b 100644
--- a/examples/notebooks/beam-ml/run_inference_multi_model.ipynb
+++ b/examples/notebooks/beam-ml/run_inference_multi_model.ipynb
@@ -47,8 +47,7 @@
     {
       "cell_type": "markdown",
       "source": [
-        "# Ensemble model using an image captioning and ranking example",
-        "\n",
+        "# Ensemble model using an image captioning and ranking example\n",
         "<table align=\"left\">\n",
         "  <td>\n",
         "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_multi_model.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
@@ -65,8 +64,8 @@
     {
       "cell_type": "markdown",
       "source": [
-        "A single machine learning model might not be the right solution for your task. Often, machine learning model tasks involve aggregating mutliple models together to produce one optimal predictive model and to boost performance. \n",
-        " \n",
+        "A single machine learning model might not be the right solution for your task. Often, machine learning tasks involve aggregating multiple models to produce one optimal predictive model and to boost performance.\n",
+        "\n",
         "\n",
         "This notebook shows how to implement a cascade model in Apache Beam using the [RunInference API](https://beam.apache.org/documentation/sdks/python-machine-learning/). The RunInference API enables you to run your Beam transforms as part of your pipeline for optimal machine learning inference.\n",
         "\n",
@@ -94,7 +93,7 @@
         "\n",
         "This example shows how to generate captions on a large set of images. Apache Beam is the ideal tool to handle this workflow. We use two models for this task:\n",
         "\n",
-        "* [BLIP](https://github.com/salesforce/BLIP): Generates a set of candidate captions for a given image. \n",
+        "* [BLIP](https://github.com/salesforce/BLIP): Generates a set of candidate captions for a given image.\n",
         "* [CLIP](https://github.com/openai/CLIP): Ranks the generated captions based on accuracy."
       ],
       "metadata": {
@@ -119,7 +118,7 @@
         "* Run inference with BLIP to generate a list of caption candidates.\n",
         "* Aggregate the generated captions with their source image.\n",
         "* Preprocess the aggregated image-caption pairs to rank them with CLIP.\n",
-        "* Run inference with CLIP to generate the caption ranking. \n",
+        "* Run inference with CLIP to generate the caption ranking.\n",
         "* Print the image names and the captions sorted according to their ranking.\n",
         "\n",
         "\n",
@@ -139,13 +138,13 @@
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
-          "height": 440
+          "height": 460
         },
         "id": "3suC5woJLW_N",
-        "outputId": "d2f9f67b-361b-4ae9-f9db-ce2ff9abd509",
+        "outputId": "2b5e78bf-f212-4a77-9325-8808ef024c2e",
         "cellView": "form"
       },
-      "execution_count": null,
+      "execution_count": 1,
       "outputs": [
         {
           "output_type": "execute_result",
@@ -158,7 +157,7 @@
             ]
           },
           "metadata": {},
-          "execution_count": 3
+          "execution_count": 1
         }
       ]
     },
@@ -184,68 +183,34 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 32,
+      "execution_count": 1,
       "metadata": {
-        "id": "tTUZpG9_q-OW",
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "outputId": "9ee6407a-8e4b-4520-fe5d-54a886b6e0b1"
+        "id": "tTUZpG9_q-OW"
       },
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "\u001b[K     |████████████████████████████████| 2.1 MB 7.0 MB/s \n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m47.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.3/3.3 MB\u001b[0m \u001b[31m90.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m182.4/182.4 kB\u001b[0m \u001b[31m21.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m880.6/880.6 kB\u001b[0m \u001b[31m69.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
-            "  Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
-            "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m377.0/377.0 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.5/6.5 MB\u001b[0m \u001b[31m60.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.8/12.8 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m235.4/235.4 kB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
-            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
-            "  Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n",
-            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
-            "  Building wheel for fairscale (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
-            "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
-            "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
-            "\u001b[0m"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "!pip install --upgrade pip --quiet\n",
-        "!pip install transformers==4.15.0 --quiet\n",
+        "!pip install transformers==4.30.2 --quiet\n",
         "!pip install timm==0.4.12 --quiet\n",
         "!pip install ftfy==6.1.1 --quiet\n",
         "!pip install spacy==3.4.1 --quiet\n",
         "!pip install fairscale==0.4.4 --quiet\n",
-        "!pip install apache_beam[gcp]>=2.40.0  \n",
+        "!pip install apache_beam[gcp]>=2.48.0\n",
         "\n",
         "# To use the newly installed versions, restart the runtime.\n",
-        "exit() "
+        "exit()"
       ]
     },
     {
       "cell_type": "code",
       "source": [
         "import requests\n",
-        "import os \n",
+        "import os\n",
         "import urllib\n",
-        "import json  \n",
+        "import json\n",
         "import io\n",
         "from io import BytesIO\n",
+        "from typing import Sequence\n",
         "from typing import Iterator\n",
         "from typing import Iterable\n",
         "from typing import Tuple\n",
@@ -303,7 +268,7 @@
           "base_uri": "https://localhost:8080/"
         },
         "id": "Ud4sUXV2x8LO",
-        "outputId": "9e12ea04-a347-426f-8145-280a5676e78b"
+        "outputId": "cc814ff8-d424-4880-e006-56803e0508aa"
       },
       "execution_count": 2,
       "outputs": [
@@ -311,7 +276,7 @@
           "output_type": "stream",
           "name": "stdout",
           "text": [
-            "Error: Failed to call git rev-parse --git-dir --show-toplevel: \"fatal: not a git repository (or any of the parent directories): .git\\n\"\n",
+            "Error: Failed to call git rev-parse --git-dir: exit status 128 \n",
             "Git LFS initialized.\n",
             "Cloning into 'clip-vit-base-patch32'...\n",
             "remote: Enumerating objects: 51, done.\u001b[K\n",
@@ -362,7 +327,7 @@
           "base_uri": "https://localhost:8080/"
         },
         "id": "g4-6WwqUtxea",
-        "outputId": "3b04b933-aab0-4f5b-c967-ed784125bc6a"
+        "outputId": "29112ca0-f111-48b7-d8cc-a4e04fb7a02b"
       },
       "execution_count": 4,
       "outputs": [
@@ -388,8 +353,8 @@
         "from BLIP.models.blip import blip_decoder\n",
         "\n",
         "!gdown 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'\n",
-        "# The blip model is saved as a checkoint, load it and save it as a state dict since RunInference required \n",
-        "# a state dict for model instantiation \n",
+        "# The BLIP model is saved as a checkpoint. Load it and save it as a state dict, since RunInference requires\n",
+        "# a state dict for model instantiation\n",
         "blip_state_dict_path = '/content/BLIP/blip_state_dict.pth'\n",
         "torch.save(torch.load('/content/BLIP/model*_base_caption.pth')['model'], blip_state_dict_path)"
       ],
@@ -398,7 +363,7 @@
           "base_uri": "https://localhost:8080/"
         },
         "id": "GCvOP_iZh41c",
-        "outputId": "224c22b1-eda6-463c-c926-1341ec9edef8"
+        "outputId": "a96f0ff5-cdf7-4394-be6e-d5bfca2f3a1f"
       },
       "execution_count": 5,
       "outputs": [
@@ -409,7 +374,7 @@
             "Downloading...\n",
             "From: https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth\n",
             "To: /content/BLIP/model*_base_caption.pth\n",
-            "100% 896M/896M [00:04<00:00, 198MB/s] \n"
+            "100% 896M/896M [00:04<00:00, 198MB/s]\n"
           ]
         }
       ]
@@ -500,9 +465,9 @@
         "\n",
         "  \"\"\"\n",
         "  Process the raw image input to a format suitable for BLIP inference. The processed\n",
-        "  images are duplicated to the number of desired captions per image. \n",
+        "  images are duplicated to the number of desired captions per image.\n",
         "\n",
-        "  Preprocessing transformation taken from: \n",
+        "  Preprocessing transformation taken from:\n",
         "  https://github.com/salesforce/BLIP/blob/d10be550b2974e17ea72e74edc7948c9e5eab884/predict.py\n",
         "  \"\"\"\n",
         "\n",
@@ -510,7 +475,7 @@
         "    self._captions_per_image = captions_per_image\n",
         "\n",
         "  def setup(self):\n",
-        "    \n",
+        "\n",
         "    # Initialize the image transformer.\n",
         "    self._transform = transforms.Compose([\n",
         "      transforms.Resize((384, 384),interpolation=InterpolationMode.BICUBIC),\n",
@@ -519,7 +484,7 @@
         "    ])\n",
         "\n",
         "  def process(self, element):\n",
-        "    image_url, image = element \n",
+        "    image_url, image = element\n",
         "    # The following lines provide a workaround to turn off BatchElements.\n",
         "    preprocessed_img = self._transform(image).unsqueeze(0)\n",
         "    preprocessed_img = preprocessed_img.repeat(self._captions_per_image, 1, 1, 1)\n",
@@ -533,7 +498,7 @@
         "  Process the PredictionResult to get the generated image captions\n",
         "  \"\"\"\n",
         "  def process(self, element : Tuple[str, Iterable[PredictionResult]]):\n",
-        "    image_url, prediction = element \n",
+        "    image_url, prediction = element\n",
         "\n",
         "    return [(image_url, prediction.inference)]"
       ],
@@ -546,7 +511,7 @@
     {
       "cell_type": "markdown",
       "source": [
-        "### Define CLIP functions \n",
+        "### Define CLIP functions\n",
         "\n",
         "Define the preprocessing and postprocessing functions for CLIP."
       ],
@@ -560,9 +525,9 @@
         "class PreprocessCLIPInput(beam.DoFn):\n",
         "\n",
         "  \"\"\"\n",
-        "  Process the image-caption pair to a format suitable for CLIP inference. \n",
+        "  Process the image-caption pair to a format suitable for CLIP inference.\n",
         "\n",
-        "  After grouping the raw images with the generated captions, we need to \n",
+        "  After grouping the raw images with the generated captions, we need to\n",
         "  preprocess them before passing them to the ranking stage (CLIP model).\n",
         "  \"\"\"\n",
         "\n",
@@ -572,12 +537,12 @@
         "               merges_file_config_path: str):\n",
         "\n",
         "    self._feature_extractor_config_path = feature_extractor_config_path\n",
-        "    self._tokenizer_vocab_config_path = tokenizer_vocab_config_path \n",
+        "    self._tokenizer_vocab_config_path = tokenizer_vocab_config_path\n",
         "    self._merges_file_config_path = merges_file_config_path\n",
         "\n",
         "\n",
         "  def setup(self):\n",
-        "    \n",
+        "\n",
         "    # Initialize the CLIP feature extractor.\n",
         "    feature_extractor_config = CLIPConfig.from_pretrained(self._feature_extractor_config_path)\n",
         "    feature_extractor = CLIPFeatureExtractor(feature_extractor_config)\n",
@@ -585,14 +550,14 @@
         "    # Initialize the CLIP tokenizer.\n",
         "    tokenizer = CLIPTokenizer(self._tokenizer_vocab_config_path,\n",
         "                              self._merges_file_config_path)\n",
-        "    \n",
+        "\n",
         "    # Initialize the CLIP processor used to process the image-caption pair.\n",
         "    self._processor = CLIPProcessor(feature_extractor=feature_extractor,\n",
         "                                    tokenizer=tokenizer)\n",
         "\n",
         "  def process(self, element: Tuple[str, Dict[str, List[Any]]]):\n",
         "\n",
-        "    image_url, image_captions_pair = element \n",
+        "    image_url, image_captions_pair = element\n",
         "    # Unpack the image and captions after grouping them with 'CoGroupByKey()'.\n",
         "    image = image_captions_pair['image'][0]\n",
         "    captions = image_captions_pair['captions'][0]\n",
@@ -600,7 +565,7 @@
         "                                              text = captions,\n",
         "                                              return_tensors=\"pt\",\n",
         "                                              padding=True)\n",
-        "    \n",
+        "\n",
         "    image_url_caption_pair = (image_url, captions)\n",
         "    return [(image_url_caption_pair, preprocessed_clip_input)]\n",
         "\n",
@@ -612,7 +577,7 @@
         "  The logits are the output of the CLIP model. Here, we apply a softmax activation\n",
         "  function to the logits to get the probabilistic distribution of the relevance\n",
         "  of each caption to the target image. After that, we sort the captions in descending\n",
-        "  order with respect to the probabilities as a caption-probability pair. \n",
+        "  order with respect to the probabilities as a caption-probability pair.\n",
         "  \"\"\"\n",
         "\n",
         "  def process(self, element : Tuple[Tuple[str, List[str]], Iterable[PredictionResult]]):\n",
@@ -642,7 +607,9 @@
     {
       "cell_type": "markdown",
       "source": [
-        "Use a `KeyedModelHandler` for both models to attach a key to the general `ModelHandler`.\n",
+        "A `ModelHandler` is Beam's method for defining the configuration needed to load and invoke your model. Since both the BLIP and CLIP models use PyTorch and take keyed tensors as inputs, we will use `PytorchModelHandlerKeyedTensor` for both.\n",
+        "\n",
+        "We will use a `KeyedModelHandler` for both models to attach a key to the general `ModelHandler`.\n",
         "The key is used for the following purposes:\n",
         "* To keep a reference to the image that the inference is associated with.\n",
         "* To aggregate transforms of different inputs.\n",
@@ -654,36 +621,6 @@
         "id": "BTmSPnjj8M2m"
       }
     },
-    {
-      "cell_type": "code",
-      "source": [
-        "class PytorchNoBatchModelHandlerKeyedTensor(PytorchModelHandlerKeyedTensor):\n",
-        "      \"\"\"Wrapper to PytorchModelHandler to limit batch size to 1.\n",
-        "    The caption strings generated from the BLIP tokenizer might have different\n",
-        "    lengths. Different length strings don't work with torch.stack() in the current RunInference\n",
-        "    implementation, because stack() requires tensors to be the same size.\n",
-        "    Restricting max_batch_size to 1 means there is only 1 example per `batch`\n",
-        "    in the run_inference() call.\n",
-        "    \"\"\"\n",
-        "      # The following lines provide a workaround to turn off BatchElements.\n",
-        "      def batch_elements_kwargs(self):\n",
-        "          return {'max_batch_size': 1}"
-      ],
-      "metadata": {
-        "id": "OaR02_wxTMpc"
-      },
-      "execution_count": 9,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "source": [
-        "Note that we use a `KeyedModelHandler` for both models to attach a key to the general `ModelHandler`. The key is used for aggregation transforms of different inputs."
-      ],
-      "metadata": {
-        "id": "gNLRO0EwvcGP"
-      }
-    },
     {
       "cell_type": "markdown",
       "source": [
@@ -713,48 +650,36 @@
     {
       "cell_type": "code",
       "source": [
-        "class BLIPWrapper(torch.nn.Module):\n",
-        "  \"\"\"\n",
-        "   Wrapper around the BLIP model to overwrite the default \"forward\" method with the \"generate\" method, because BLIP uses the \n",
-        "  \"generate\" method to produce the image captions.\n",
-        "  \"\"\"\n",
-        "  \n",
-        "  def __init__(self, base_model: blip_decoder, num_beams: int, max_length: int,\n",
-        "                min_length: int):\n",
-        "    super().__init__()\n",
-        "    self._model = base_model()\n",
-        "    self._num_beams = num_beams\n",
-        "    self._max_length = max_length\n",
-        "    self._min_length = min_length\n",
-        "\n",
-        "  def forward(self, inputs: torch.Tensor):\n",
-        "    # Squeeze because RunInference adds an extra dimension, which is empty.\n",
-        "    # The following lines provide a workaround to turn off BatchElements.\n",
-        "    inputs = inputs.squeeze(0)\n",
-        "    captions = self._model.generate(inputs,\n",
-        "                                    sample=True,\n",
-        "                                    num_beams=self._num_beams,\n",
-        "                                    max_length=self._max_length,\n",
-        "                                    min_length=self._min_length)\n",
-        "    return [captions]\n",
-        "\n",
-        "  def load_state_dict(self, state_dict: dict):\n",
-        "    self._model.load_state_dict(state_dict)\n",
-        "\n",
-        "\n",
-        "BLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(\n",
+        "def blip_keyed_tensor_inference_fn(\n",
+        "    batch: Sequence[Dict[str, torch.Tensor]],\n",
+        "    model: torch.nn.Module,\n",
+        "    device: str,\n",
+        "    inference_args: Optional[Dict[str, Any]] = None,\n",
+        "    model_id: Optional[str] = None,\n",
+        ") -> Iterable[PredictionResult]:\n",
+        "  # By default, Beam batches inputs for bulk inference and calls model(batch)\n",
+        "  # Since we want to call model.generate on a single unbatched input (BLIP/CLIP\n",
+        "  # don't handle batched inputs), we define a custom inference function.\n",
+        "  captions = model.generate(batch[0]['inputs'],\n",
+        "                            sample=True,\n",
+        "                            num_beams=NUM_BEAMS,\n",
+        "                            max_length=MAX_CAPTION_LENGTH,\n",
+        "                            min_length=MIN_CAPTION_LENGTH)\n",
+        "  return [PredictionResult(batch[0], captions, model_id)]\n",
+        "\n",
+        "\n",
+        "BLIP_model_handler = PytorchModelHandlerKeyedTensor(\n",
         "    state_dict_path=blip_state_dict_path,\n",
-        "    model_class=BLIPWrapper,\n",
-        "    model_params={'base_model': blip_decoder, 'num_beams': NUM_BEAMS,\n",
-        "                  'max_length': MAX_CAPTION_LENGTH, 'min_length': MIN_CAPTION_LENGTH},\n",
-        "    device='GPU')\n",
+        "    model_class=blip_decoder,\n",
+        "    inference_fn=blip_keyed_tensor_inference_fn,\n",
+        "    max_batch_size=1)\n",
         "\n",
         "BLIP_keyed_model_handler = KeyedModelHandler(BLIP_model_handler)"
       ],
       "metadata": {
         "id": "RCKBJjujVw4q"
       },
-      "execution_count": 11,
+      "execution_count": 10,
       "outputs": []
     },
     {
@@ -771,29 +696,33 @@
     {
       "cell_type": "code",
       "source": [
-        "class CLIPWrapper(CLIPModel):\n",
-        "\n",
-        "  def forward(self, **kwargs: Dict[str, torch.Tensor]):\n",
-        "    # Squeeze because RunInference adds an extra dimension, which is empty.\n",
-        "    # The following lines provide a workaround to turn off BatchElements.\n",
-        "    kwargs = {key: tensor.squeeze(0) for key, tensor in kwargs.items()}\n",
-        "    output = super().forward(**kwargs)\n",
-        "    logits = output.logits_per_image\n",
-        "    return logits\n",
-        "\n",
-        "\n",
-        "CLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(\n",
+        "def clip_keyed_tensor_inference_fn(\n",
+        "    batch: Sequence[Dict[str, torch.Tensor]],\n",
+        "    model: torch.nn.Module,\n",
+        "    device: str,\n",
+        "    inference_args: Optional[Dict[str, Any]] = None,\n",
+        "    model_id: Optional[str] = None,\n",
+        ") -> Iterable[PredictionResult]:\n",
+        "  # By default, Beam batches inputs for bulk inference and calls model(batch)\n",
+        "  # Since we want to call model on a single unbatched input (BLIP/CLIP don't\n",
+        "  # handle batched inputs), we define a custom inference function.\n",
+        "  output = model(**batch[0], **inference_args)\n",
+        "  return [PredictionResult(batch[0], output.logits_per_image[0], model_id)]\n",
+        "\n",
+        "\n",
+        "CLIP_model_handler = PytorchModelHandlerKeyedTensor(\n",
         "    state_dict_path=clip_state_dict_path,\n",
-        "    model_class=CLIPWrapper,\n",
+        "    model_class=CLIPModel,\n",
         "    model_params={'config': CLIPConfig.from_pretrained(clip_model_config_path)},\n",
-        "    device='GPU')\n",
+        "    inference_fn=clip_keyed_tensor_inference_fn,\n",
+        "    max_batch_size=1)\n",
         "\n",
         "CLIP_keyed_model_handler = KeyedModelHandler(CLIP_model_handler)\n"
       ],
       "metadata": {
         "id": "EJw_OnZ1ZfuH"
       },
-      "execution_count": 12,
+      "execution_count": 11,
       "outputs": []
     },
     {
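
The handler wiring added above follows a general RunInference pattern: a custom
inference_fn that invokes the model the way it needs to be called, a
PytorchModelHandlerKeyedTensor configured with max_batch_size=1, and a
KeyedModelHandler wrapper so every prediction stays paired with its source image
URL.  A minimal sketch of that pattern, using a placeholder model class and
state-dict path rather than the notebook's actual BLIP/CLIP code:

    import torch
    from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult
    from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor

    class PlaceholderModel(torch.nn.Module):
        # Stand-in for blip_decoder / CLIPModel; only here to keep the sketch self-contained.
        def forward(self, **kwargs):
            raise NotImplementedError

    def placeholder_inference_fn(batch, model, device, inference_args=None, model_id=None):
        # RunInference hands this function at most one element per batch
        # (max_batch_size=1 below), so the model is invoked on a single
        # unbatched keyed tensor instead of a stacked batch.
        output = model(**batch[0])
        return [PredictionResult(batch[0], output, model_id)]

    unkeyed_handler = PytorchModelHandlerKeyedTensor(
        state_dict_path='placeholder_state_dict.pth',  # hypothetical path
        model_class=PlaceholderModel,
        inference_fn=placeholder_inference_fn,
        max_batch_size=1)

    # KeyedModelHandler routes (key, example) pairs through RunInference and
    # emits (key, PredictionResult), keeping each result tied to its image URL.
    keyed_handler = KeyedModelHandler(unkeyed_handler)

Passing max_batch_size=1 directly to the handler is what lets this commit drop the
earlier PytorchNoBatchModelHandlerKeyedTensor workaround that overrode
batch_elements_kwargs().
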
@@ -817,7 +746,7 @@
       "metadata": {
         "id": "VJwE0bquoXOf"
       },
-      "execution_count": 13,
+      "execution_count": 12,
       "outputs": []
     },
     {
@@ -834,7 +763,7 @@
       "source": [
         "#@title\n",
         "license_txt_url = 'https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt'\n",
-        "license_dict = json.loads(urllib.request.urlopen(license_txt_url).read().decode(\"utf-8\")) \n",
+        "license_dict = json.loads(urllib.request.urlopen(license_txt_url).read().decode(\"utf-8\"))\n",
         "\n",
         "for image_url in images_url:\n",
         "  response = requests.get(image_url)\n",
@@ -855,7 +784,7 @@
         "outputId": "6e771e4e-a76a-4855-b466-976cdf35b506",
         "cellView": "form"
       },
-      "execution_count": 16,
+      "execution_count": null,
       "outputs": [
         {
           "output_type": "display_data",
@@ -918,7 +847,7 @@
       "metadata": {
         "id": "Dcz_M9GW0Kan"
       },
-      "execution_count": 14,
+      "execution_count": 13,
       "outputs": []
     },
     {
@@ -947,13 +876,13 @@
         "with beam.Pipeline() as pipeline:\n",
         "\n",
         "  read_images = (\n",
-        "            pipeline \n",
+        "            pipeline\n",
         "            | \"ReadUrl\" >> beam.Create(images_url)\n",
         "            | \"ReadImages\" >> beam.ParDo(ReadImagesFromUrl()))\n",
         "\n",
         "  blip_caption_generation = (\n",
         "            read_images\n",
-        "            | \"PreprocessBlipInput\" >> beam.ParDo(PreprocessBLIPInput(NUM_CAPTIONS_PER_IMAGE)) \n",
+        "            | \"PreprocessBlipInput\" >> beam.ParDo(PreprocessBLIPInput(NUM_CAPTIONS_PER_IMAGE))\n",
         "            | \"GenerateCaptions\" >> RunInference(BLIP_keyed_model_handler)\n",
         "            | \"PostprocessCaptions\" >> beam.ParDo(PostprocessBLIPOutput()))\n",
         "\n",
@@ -966,19 +895,21 @@
         "                    clip_tokenizer_vocab_config_path,\n",
         "                    clip_merges_config_path))\n",
         "            | \"GetRankingLogits\" >> RunInference(CLIP_keyed_model_handler)\n",
-        "            | \"RankClipOutput\" >> beam.ParDo(RankCLIPOutput()))\n",
+        "            | \"RankClipOutput\" >> beam.ParDo(RankCLIPOutput())\n",
+        "            )\n",
         "\n",
         "  clip_captions_ranking | \"FormatCaptions\" >> beam.ParDo(FormatCaptions(NUM_TOP_CAPTIONS_TO_DISPLAY))\n",
-        "  "
+        ""
       ],
       "metadata": {
         "colab": {
-          "base_uri": "https://localhost:8080/"
+          "base_uri": "https://localhost:8080/",
+          "height": 428
         },
         "id": "002e-FNbmuB8",
-        "outputId": "49c646f1-8612-433f-b134-ea8af0ff5591"
+        "outputId": "1b540b1e-b146-45d6-f8d3-ccaf461a87b7"
       },
-      "execution_count": 18,
+      "execution_count": 14,
       "outputs": [
         {
           "output_type": "stream",
@@ -986,23 +917,23 @@
           "text": [
             "Image: Paris-sunset\n",
             "\tTop 3 captions ranked by CLIP:\n",
-            "\t\t1: the eiffel tower in paris is silhouetted at sunset. (Caption probability: 0.23)\n",
-            "\t\t2: the sun sets over the city of paris, with the eiffel tower in the distance. (Caption probability: 0.19)\n",
-            "\t\t3: the sun sets over the eiffel tower in paris. (Caption probability: 0.17)\n",
+            "\t\t1: the setting sun is reflected in an orange setting sky over paris. (Caption probability: 0.28)\n",
+            "\t\t2: the sun rising above the eiffel tower over paris. (Caption probability: 0.23)\n",
+            "\t\t3: the sun setting over the eiffel tower and rooftops. (Caption probability: 0.15)\n",
             "\n",
             "\n",
             "Image: Wedges\n",
             "\tTop 3 captions ranked by CLIP:\n",
-            "\t\t1: a basket of baked fries with a sauce in it. (Caption probability: 0.60)\n",
-            "\t\t2: cooked french fries with ketchup and dip sitting in napkin. (Caption probability: 0.16)\n",
-            "\t\t3: some french fries with dipping sauce on the side. (Caption probability: 0.08)\n",
+            "\t\t1: sweet potato fries with ketchup served in bowl. (Caption probability: 0.73)\n",
+            "\t\t2: this is a plate of sweet potato fries with ketchup. (Caption probability: 0.16)\n",
+            "\t\t3: sweet potato fries and a dipping sauce are on the tray. (Caption probability: 0.06)\n",
             "\n",
             "\n",
             "Image: Hamsters\n",
             "\tTop 3 captions ranked by CLIP:\n",
-            "\t\t1: a person petting two small hamsters while in their home. (Caption probability: 0.51)\n",
-            "\t\t2: a woman holding two small white baby animals. (Caption probability: 0.23)\n",
-            "\t\t3: a hand holding a small mouse that looks tiny. (Caption probability: 0.09)\n",
+            "\t\t1: person holding two small animals in their hands. (Caption probability: 0.62)\n",
+            "\t\t2: a person's hand holding a small hamster in front of them. (Caption probability: 0.20)\n",
+            "\t\t3: a person holding a small animal in their hands. (Caption probability: 0.09)\n",
             "\n",
             "\n"
           ]